/src/whoosh/fst.py
Python | 1550 lines | 1469 code | 25 blank | 56 comment | 22 complexity | ec47613e30de35ccb738d6548d55ebc7 MD5 | raw file
Possible License(s): Apache-2.0
- # Copyright 2009 Matt Chaput. All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- #
- # 1. Redistributions of source code must retain the above copyright notice,
- # this list of conditions and the following disclaimer.
- #
- # 2. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- #
- # THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
- # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
- # MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
- # EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
- # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
- # OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
- # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
- # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
- # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- #
- # The views and conclusions contained in the software and documentation are
- # those of the authors and should not be interpreted as representing official
- # policies, either expressed or implied, of Matt Chaput.
- """
- This module implements an FST/FSA writer and reader. An FST (Finite State
- Transducer) stores a directed acyclic graph with values associated with the
- leaves. Common elements of the values are pushed inside the tree. An FST that
- does not store values is a regular FSA.
- The format of the leaf values is pluggable using subclasses of the Values
- class.
- Whoosh uses these structures to store a directed acyclic word graph (DAWG) for
- use in (at least) spell checking.
- """
- import sys, copy
- from array import array
- from hashlib import sha1 # @UnresolvedImport
- from whoosh.compat import b, u, BytesIO
- from whoosh.compat import xrange, iteritems, iterkeys, izip, array_tobytes
- from whoosh.compat import bytes_type, text_type
- from whoosh.filedb.structfile import StructFile
- from whoosh.system import _INT_SIZE
- from whoosh.system import pack_byte, pack_int, pack_uint, pack_long
- from whoosh.system import emptybytes
- from whoosh.util.text import utf8encode, utf8decode
- from whoosh.util.varints import varint
- class FileVersionError(Exception):
- pass
- class InactiveCursor(Exception):
- pass
- ARC_LAST = 1
- ARC_ACCEPT = 2
- ARC_STOP = 4
- ARC_HAS_VAL = 8
- ARC_HAS_ACCEPT_VAL = 16
- MULTIBYTE_LABEL = 32
- # FST Value types
- class Values(object):
- """Base for classes the describe how to encode and decode FST values.
- """
- @staticmethod
- def is_valid(v):
- """Returns True if v is a valid object that can be stored by this
- class.
- """
- raise NotImplementedError
- @staticmethod
- def common(v1, v2):
- """Returns the "common" part of the two values, for whatever "common"
- means for this class. For example, a string implementation would return
- the common shared prefix, for an int implementation it would return
- the minimum of the two numbers.
- If there is no common part, this method should return None.
- """
- raise NotImplementedError
- @staticmethod
- def add(prefix, v):
- """Adds the given prefix (the result of a call to common()) to the
- given value.
- """
- raise NotImplementedError
- @staticmethod
- def subtract(v, prefix):
- """Subtracts the "common" part (the prefix) from the given value.
- """
- raise NotImplementedError
- @staticmethod
- def write(dbfile, v):
- """Writes value v to a file.
- """
- raise NotImplementedError
- @staticmethod
- def read(dbfile):
- """Reads a value from the given file.
- """
- raise NotImplementedError
- @classmethod
- def skip(cls, dbfile):
- """Skips over a value in the given file.
- """
- cls.read(dbfile)
- @staticmethod
- def to_bytes(v):
- """Returns a str (Python 2.x) or bytes (Python 3) representation of
- the given value. This is used for calculating node digests, so it
- should be unique but fast to calculate, and does not have to be
- parseable.
- """
- raise NotImplementedError
- @staticmethod
- def merge(v1, v2):
- raise NotImplementedError
- class IntValues(Values):
- """Stores integer values in an FST.
- """
- @staticmethod
- def is_valid(v):
- return isinstance(v, int) and v >= 0
- @staticmethod
- def common(v1, v2):
- if v1 is None or v2 is None:
- return None
- if v1 == v2:
- return v1
- return min(v1, v2)
- @staticmethod
- def add(base, v):
- if base is None:
- return v
- if v is None:
- return base
- return base + v
- @staticmethod
- def subtract(v, base):
- if v is None:
- return None
- if base is None:
- return v
- return v - base
- @staticmethod
- def write(dbfile, v):
- dbfile.write_uint(v)
- @staticmethod
- def read(dbfile):
- return dbfile.read_uint()
- @staticmethod
- def skip(dbfile):
- dbfile.seek(_INT_SIZE, 1)
- @staticmethod
- def to_bytes(v):
- return pack_int(v)
- class SequenceValues(Values):
- """Abstract base class for value types that store sequences.
- """
- @staticmethod
- def is_valid(v):
- return isinstance(self, (list, tuple))
- @staticmethod
- def common(v1, v2):
- if v1 is None or v2 is None:
- return None
- i = 0
- while i < len(v1) and i < len(v2):
- if v1[i] != v2[i]:
- break
- i += 1
- if i == 0:
- return None
- if i == len(v1):
- return v1
- if i == len(v2):
- return v2
- return v1[:i]
- @staticmethod
- def add(prefix, v):
- if prefix is None:
- return v
- if v is None:
- return prefix
- return prefix + v
- @staticmethod
- def subtract(v, prefix):
- if prefix is None:
- return v
- if v is None:
- return None
- if len(v) == len(prefix):
- return None
- if len(v) < len(prefix) or len(prefix) == 0:
- raise ValueError((v, prefix))
- return v[len(prefix):]
- @staticmethod
- def write(dbfile, v):
- dbfile.write_pickle(v)
- @staticmethod
- def read(dbfile):
- return dbfile.read_pickle()
- class BytesValues(SequenceValues):
- """Stores bytes objects (str in Python 2.x) in an FST.
- """
- @staticmethod
- def is_valid(v):
- return isinstance(v, bytes_type)
- @staticmethod
- def write(dbfile, v):
- dbfile.write_int(len(v))
- dbfile.write(v)
- @staticmethod
- def read(dbfile):
- length = dbfile.read_int()
- return dbfile.read(length)
- @staticmethod
- def skip(dbfile):
- length = dbfile.read_int()
- dbfile.seek(length, 1)
- @staticmethod
- def to_bytes(v):
- return v
- class ArrayValues(SequenceValues):
- """Stores array.array objects in an FST.
- """
- def __init__(self, typecode):
- self.typecode = typecode
- self.itemsize = array(self.typecode).itemsize
- def is_valid(self, v):
- return isinstance(v, array) and v.typecode == self.typecode
- @staticmethod
- def write(dbfile, v):
- dbfile.write(b(v.typecode))
- dbfile.write_int(len(v))
- dbfile.write_array(v)
- def read(self, dbfile):
- typecode = u(dbfile.read(1))
- length = dbfile.read_int()
- return dbfile.read_array(self.typecode, length)
- def skip(self, dbfile):
- length = dbfile.read_int()
- dbfile.seek(length * self.itemsize, 1)
- @staticmethod
- def to_bytes(v):
- return array_tobytes(v)
- class IntListValues(SequenceValues):
- """Stores lists of positive, increasing integers (that is, lists of
- integers where each number is >= 0 and each number is greater than or equal
- to the number that precedes it) in an FST.
- """
- @staticmethod
- def is_valid(v):
- if isinstance(v, (list, tuple)):
- if len(v) < 2:
- return True
- for i in xrange(1, len(v)):
- if not isinstance(v[i], int) or v[i] < v[i - 1]:
- return False
- return True
- return False
- @staticmethod
- def write(dbfile, v):
- base = 0
- dbfile.write_varint(len(v))
- for x in v:
- delta = x - base
- assert delta >= 0
- dbfile.write_varint(delta)
- base = x
- @staticmethod
- def read(dbfile):
- length = dbfile.read_varint()
- result = []
- if length > 0:
- base = 0
- for _ in xrange(length):
- base += dbfile.read_varint()
- result.append(base)
- return result
- @staticmethod
- def to_bytes(v):
- return b(repr(v))
- # Node-like interface wrappers
- class Node(object):
- """A slow but easier-to-use wrapper for FSA/DAWGs. Translates the low-level
- arc-based interface of GraphReader into Node objects with methods to follow
- edges.
- """
- def __init__(self, owner, address, accept=False):
- self.owner = owner
- self.address = address
- self._edges = None
- self.accept = accept
- def __iter__(self):
- if not self._edges:
- self._load()
- return iterkeys(self._edges)
- def __contains__(self, key):
- if self._edges is None:
- self._load()
- return key in self._edges
- def _load(self):
- owner = self.owner
- if self.address is None:
- d = {}
- else:
- d = dict((arc.label, Node(owner, arc.target, arc.accept))
- for arc in self.owner.iter_arcs(self.address))
- self._edges = d
- def keys(self):
- if self._edges is None:
- self._load()
- return self._edges.keys()
- def all_edges(self):
- if self._edges is None:
- self._load()
- return self._edges
- def edge(self, key):
- if self._edges is None:
- self._load()
- return self._edges[key]
- def flatten(self, sofar=emptybytes):
- if self.accept:
- yield sofar
- for key in sorted(self):
- node = self.edge(key)
- for result in node.flatten(sofar + key):
- yield result
- def flatten_strings(self):
- return (utf8decode(k)[0] for k in self.flatten())
- class ComboNode(Node):
- """Base class for nodes that blend the nodes of two different graphs.
- Concrete subclasses need to implement the ``edge()`` method and possibly
- override the ``accept`` property.
- """
- def __init__(self, a, b):
- self.a = a
- self.b = b
- def __repr__(self):
- return "<%s %r %r>" % (self.__class__.__name__, self.a, self.b)
- def __contains__(self, key):
- return key in self.a or key in self.b
- def __iter__(self):
- return iter(set(self.a) | set(self.b))
- @property
- def accept(self):
- return self.a.accept or self.b.accept
- class UnionNode(ComboNode):
- """Makes two graphs appear to be the union of the two graphs.
- """
- def edge(self, key):
- a = self.a
- b = self.b
- if key in a and key in b:
- return UnionNode(a.edge(key), b.edge(key))
- elif key in a:
- return a.edge(key)
- else:
- return b.edge(key)
- class IntersectionNode(ComboNode):
- """Makes two graphs appear to be the intersection of the two graphs.
- """
- def edge(self, key):
- a = self.a
- b = self.b
- if key in a and key in b:
- return IntersectionNode(a.edge(key), b.edge(key))
- # Cursor
- class BaseCursor(object):
- """Base class for a cursor-type object for navigating an FST/word graph,
- represented by a :class:`GraphReader` object.
- >>> cur = GraphReader(dawgfile).cursor()
- >>> for key in cur.follow():
- ... print(repr(key))
- The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.
- """
- def is_active(self):
- """Returns True if this cursor is still active, that is it has not
- read past the last arc in the graph.
- """
- raise NotImplementedError
- def label(self):
- """Returns the label bytes of the current arc.
- """
- raise NotImplementedError
- def prefix(self):
- """Returns a sequence of the label bytes for the path from the root
- to the current arc.
- """
- raise NotImplementedError
- def prefix_bytes(self):
- """Returns the label bytes for the path from the root to the current
- arc as a single joined bytes object.
- """
- return emptybytes.join(self.prefix())
- def prefix_string(self):
- """Returns the labels of the path from the root to the current arc as
- a decoded unicode string.
- """
- return utf8decode(self.prefix_bytes())[0]
- def peek_key(self):
- """Returns a sequence of label bytes representing the next closest
- key in the graph.
- """
- for label in self.prefix():
- yield label
- c = self.copy()
- while not c.stopped():
- c.follow()
- yield c.label()
- def peek_key_bytes(self):
- """Returns the next closest key in the graph as a single bytes object.
- """
- return emptybytes.join(self.peek_key())
- def peek_key_string(self):
- """Returns the next closest key in the graph as a decoded unicode
- string.
- """
- return utf8decode(self.peek_key_bytes())[0]
- def stopped(self):
- """Returns True if the current arc leads to a stop state.
- """
- raise NotImplementedError
- def value(self):
- """Returns the value at the current arc, if reading an FST.
- """
- raise NotImplementedError
- def accept(self):
- """Returns True if the current arc leads to an accept state (the end
- of a valid key).
- """
- raise NotImplementedError
- def at_last_arc(self):
- """Returns True if the current arc is the last outgoing arc from the
- previous node.
- """
- raise NotImplementedError
- def next_arc(self):
- """Moves to the next outgoing arc from the previous node.
- """
- raise NotImplementedError
- def follow(self):
- """Follows the current arc.
- """
- raise NotImplementedError
- def switch_to(self, label):
- """Switch to the sibling arc with the given label bytes.
- """
- _label = self.label
- _at_last_arc = self.at_last_arc
- _next_arc = self.next_arc
- while True:
- thislabel = _label()
- if thislabel == label:
- return True
- if thislabel > label or _at_last_arc():
- return False
- _next_arc()
- def skip_to(self, key):
- """Moves the cursor to the path represented by the given key bytes.
- """
- _accept = self.accept
- _prefix = self.prefix
- _next_arc = self.next_arc
- keylist = list(key)
- while True:
- if _accept():
- thiskey = list(_prefix())
- if keylist == thiskey:
- return True
- elif keylist > thiskey:
- return False
- _next_arc()
- def flatten(self):
- """Yields the keys in the graph, starting at the current position.
- """
- _is_active = self.is_active
- _accept = self.accept
- _stopped = self.stopped
- _follow = self.follow
- _next_arc = self.next_arc
- _prefix_bytes = self.prefix_bytes
- if not _is_active():
- raise InactiveCursor
- while _is_active():
- if _accept():
- yield _prefix_bytes()
- if not _stopped():
- _follow()
- continue
- _next_arc()
- def flatten_v(self):
- """Yields (key, value) tuples in an FST, starting at the current
- position.
- """
- for key in self.flatten():
- yield key, self.value()
- def flatten_strings(self):
- return (utf8decode(k)[0] for k in self.flatten())
- def find_path(self, path):
- """Follows the labels in the given path, starting at the current
- position.
- """
- path = to_labels(path)
- _switch_to = self.switch_to
- _follow = self.follow
- _stopped = self.stopped
- first = True
- for i, label in enumerate(path):
- if not first:
- _follow()
- if not _switch_to(label):
- return False
- if _stopped():
- if i < len(path) - 1:
- return False
- first = False
- return True
- class Cursor(BaseCursor):
- def __init__(self, graph, root=None, stack=None):
- self.graph = graph
- self.vtype = graph.vtype
- self.root = root if root is not None else graph.default_root()
- if stack:
- self.stack = stack
- else:
- self.reset()
- def _current_attr(self, name):
- stack = self.stack
- if not stack:
- raise InactiveCursor
- return getattr(stack[-1], name)
- def is_active(self):
- return bool(self.stack)
- def stopped(self):
- return self._current_attr("target") is None
- def accept(self):
- return self._current_attr("accept")
- def at_last_arc(self):
- return self._current_attr("lastarc")
- def label(self):
- return self._current_attr("label")
- def reset(self):
- self.stack = []
- self.sums = [None]
- self._push(self.graph.arc_at(self.root))
- def copy(self):
- return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))
- def prefix(self):
- stack = self.stack
- if not stack:
- raise InactiveCursor
- return (arc.label for arc in stack)
- # Override: more efficient implementation using graph methods directly
- def peek_key(self):
- if not self.stack:
- raise InactiveCursor
- for label in self.prefix():
- yield label
- arc = copy.copy(self.stack[-1])
- graph = self.graph
- while not arc.accept and arc.target is not None:
- graph.arc_at(arc.target, arc)
- yield arc.label
- def value(self):
- stack = self.stack
- if not stack:
- raise InactiveCursor
- vtype = self.vtype
- if not vtype:
- raise Exception("No value type")
- v = self.sums[-1]
- current = stack[-1]
- if current.value:
- v = vtype.add(v, current.value)
- if current.accept and current.acceptval is not None:
- v = vtype.add(v, current.acceptval)
- return v
- def next_arc(self):
- stack = self.stack
- if not stack:
- raise InactiveCursor
- while stack and stack[-1].lastarc:
- self.pop()
- if stack:
- current = stack[-1]
- self.graph.arc_at(current.endpos, current)
- return current
- def follow(self):
- address = self._current_attr("target")
- if address is None:
- raise Exception("Can't follow a stop arc")
- self._push(self.graph.arc_at(address))
- return self
- # Override: more efficient implementation manipulating the stack
- def skip_to(self, key):
- key = to_labels(key)
- stack = self.stack
- if not stack:
- raise InactiveCursor
- _follow = self.follow
- _next_arc = self.next_arc
- i = self._pop_to_prefix(key)
- while stack and i < len(key):
- curlabel = stack[-1].label
- keylabel = key[i]
- if curlabel == keylabel:
- _follow()
- i += 1
- elif curlabel > keylabel:
- return
- else:
- _next_arc()
- # Override: more efficient implementation using find_arc
- def switch_to(self, label):
- stack = self.stack
- if not stack:
- raise InactiveCursor
- current = stack[-1]
- if label == current.label:
- return True
- else:
- arc = self.graph.find_arc(current.endpos, label, current)
- return arc
- def _push(self, arc):
- if self.vtype and self.stack:
- sums = self.sums
- sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
- self.stack.append(arc)
- def pop(self):
- self.stack.pop()
- if self.vtype:
- self.sums.pop()
- def _pop_to_prefix(self, key):
- stack = self.stack
- if not stack:
- raise InactiveCursor
- i = 0
- maxpre = min(len(stack), len(key))
- while i < maxpre and key[i] == stack[i].label:
- i += 1
- if stack[i].label > key[i]:
- self.current = None
- return
- while len(stack) > i + 1:
- self.pop()
- self.next_arc()
- return i
- class UncompiledNode(object):
- # Represents an "in-memory" node used by the GraphWriter before it is
- # written to disk.
- compiled = False
- def __init__(self, owner):
- self.owner = owner
- self._digest = None
- self.clear()
- def clear(self):
- self.arcs = []
- self.value = None
- self.accept = False
- self.inputcount = 0
- def __repr__(self):
- return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
- def digest(self):
- if self._digest is None:
- d = sha1()
- vtype = self.owner.vtype
- for arc in self.arcs:
- d.update(arc.label)
- if arc.target:
- d.update(pack_long(arc.target))
- else:
- d.update(b("z"))
- if arc.value:
- d.update(vtype.to_bytes(arc.value))
- if arc.accept:
- d.update(b("T"))
- self._digest = d.digest()
- return self._digest
- def edges(self):
- return self.arcs
- def last_value(self, label):
- assert self.arcs[-1].label == label
- return self.arcs[-1].value
- def add_arc(self, label, target):
- self.arcs.append(Arc(label, target))
- def replace_last(self, label, target, accept, acceptval=None):
- arc = self.arcs[-1]
- assert arc.label == label, "%r != %r" % (arc.label, label)
- arc.target = target
- arc.accept = accept
- arc.acceptval = acceptval
- def delete_last(self, label, target):
- arc = self.arcs.pop()
- assert arc.label == label
- assert arc.target == target
- def set_last_value(self, label, value):
- arc = self.arcs[-1]
- assert arc.label == label, "%r->%r" % (arc.label, label)
- arc.value = value
- def prepend_value(self, prefix):
- add = self.owner.vtype.add
- for arc in self.arcs:
- arc.value = add(prefix, arc.value)
- if self.accept:
- self.value = add(prefix, self.value)
- class Arc(object):
- """
- Represents a directed arc between two nodes in an FSA/FST graph.
- The ``lastarc`` attribute is True if this is the last outgoing arc from the
- previous node.
- """
- __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
- "endpos")
- def __init__(self, label=None, target=None, value=None, accept=False,
- acceptval=None):
- """
- :param label: The label bytes for this arc. For a word graph, this will
- be a character.
- :param target: The address of the node at the endpoint of this arc.
- :param value: The inner FST value at the endpoint of this arc.
- :param accept: Whether the endpoint of this arc is an accept state
- (e.g. the end of a valid word).
- :param acceptval: If the endpoint of this arc is an accept state, the
- final FST value for that accepted state.
- """
- self.label = label
- self.target = target
- self.value = value
- self.accept = accept
- self.lastarc = None
- self.acceptval = acceptval
- self.endpos = None
- def __repr__(self):
- return "<%r-%s %s%s>" % (self.label, self.target,
- "." if self.accept else "",
- (" %r" % self.value) if self.value else "")
- def __eq__(self, other):
- if (isinstance(other, self.__class__) and self.accept == other.accept
- and self.lastarc == other.lastarc and self.target == other.target
- and self.value == other.value and self.label == other.label):
- return True
- return False
- # Graph writer
- class GraphWriter(object):
- """Writes an FSA/FST graph to disk.
- Call ``insert(key)`` to insert keys into the graph. You must
- insert keys in sorted order. Call ``close()`` to finish the graph and close
- the file.
- >>> gw = GraphWriter(my_file)
- >>> gw.insert("alfa")
- >>> gw.insert("bravo")
- >>> gw.insert("charlie")
- >>> gw.close()
- The graph writer can write separate graphs for multiple fields. Use
- ``start_field(name)`` and ``finish_field()`` to separate fields.
- >>> gw = GraphWriter(my_file)
- >>> gw.start_field("content")
- >>> gw.insert("alfalfa")
- >>> gw.insert("apple")
- >>> gw.finish_field()
- >>> gw.start_field("title")
- >>> gw.insert("artichoke")
- >>> gw.finish_field()
- >>> gw.close()
- """
- version = 1
- def __init__(self, dbfile, vtype=None, merge=None):
- """
- :param dbfile: the file to write to.
- :param vtype: a :class:`Values` class to use for storing values. This
- is only necessary if you will be storing values for the keys.
- :param merge: a function that takes two values and returns a single
- value. This is called if you insert two identical keys with values.
- """
- self.dbfile = dbfile
- self.vtype = vtype
- self.merge = merge
- self.fieldroots = {}
- self.arc_count = 0
- self.node_count = 0
- self.fixed_count = 0
- dbfile.write(b("GRPH"))
- dbfile.write_int(self.version)
- dbfile.write_uint(0)
- self._infield = False
- def start_field(self, fieldname):
- """Starts a new graph for the given field.
- """
- if not fieldname:
- raise ValueError("Field name cannot be equivalent to False")
- if self._infield:
- self.finish_field()
- self.fieldname = fieldname
- self.seen = {}
- self.nodes = [UncompiledNode(self)]
- self.lastkey = ''
- self._inserted = False
- self._infield = True
- def finish_field(self):
- """Finishes the graph for the current field.
- """
- if not self._infield:
- raise Exception("Called finish_field before start_field")
- self._infield = False
- if self._inserted:
- self.fieldroots[self.fieldname] = self._finish()
- self.fieldname = None
- def close(self):
- """Finishes the current graph and closes the underlying file.
- """
- if self.fieldname is not None:
- self.finish_field()
- dbfile = self.dbfile
- here = dbfile.tell()
- dbfile.write_pickle(self.fieldroots)
- dbfile.flush()
- dbfile.seek(4 + _INT_SIZE) # Seek past magic and version number
- dbfile.write_uint(here)
- dbfile.close()
- def insert(self, key, value=None):
- """Inserts the given key into the graph.
- :param key: a sequence of bytes objects, a bytes object, or a string.
- :param value: an optional value to encode in the graph along with the
- key. If the writer was not instantiated with a value type, passing
- a value here will raise an error.
- """
- if not self._infield:
- raise Exception("Inserted %r before starting a field" % key)
- self._inserted = True
- key = to_labels(key) # Python 3 sucks
- vtype = self.vtype
- lastkey = self.lastkey
- nodes = self.nodes
- if len(key) < 1:
- raise KeyError("Can't store a null key %r" % (key,))
- if lastkey and lastkey > key:
- raise KeyError("Keys out of order %r..%r" % (lastkey, key))
- # Find the common prefix shared by this key and the previous one
- prefixlen = 0
- for i in xrange(min(len(lastkey), len(key))):
- if lastkey[i] != key[i]:
- break
- prefixlen += 1
- # Compile the nodes after the prefix, since they're not shared
- self._freeze_tail(prefixlen + 1)
- # Create new nodes for the parts of this key after the shared prefix
- for char in key[prefixlen:]:
- node = UncompiledNode(self)
- # Create an arc to this node on the previous node
- nodes[-1].add_arc(char, node)
- nodes.append(node)
- # Mark the last node as an accept state
- lastnode = nodes[-1]
- lastnode.accept = True
- if vtype:
- if value is not None and not vtype.is_valid(value):
- raise ValueError("%r is not valid for %s" % (value, vtype))
- # Push value commonalities through the tree
- common = None
- for i in xrange(1, prefixlen + 1):
- node = nodes[i]
- parent = nodes[i - 1]
- lastvalue = parent.last_value(key[i - 1])
- if lastvalue is not None:
- common = vtype.common(value, lastvalue)
- suffix = vtype.subtract(lastvalue, common)
- parent.set_last_value(key[i - 1], common)
- node.prepend_value(suffix)
- else:
- common = suffix = None
- value = vtype.subtract(value, common)
- if key == lastkey:
- # If this key is a duplicate, merge its value with the value of
- # the previous (same) key
- lastnode.value = self.merge(lastnode.value, value)
- else:
- nodes[prefixlen].set_last_value(key[prefixlen], value)
- elif value:
- raise Exception("Value %r but no value type" % value)
- self.lastkey = key
- def _freeze_tail(self, prefixlen):
- nodes = self.nodes
- lastkey = self.lastkey
- downto = max(1, prefixlen)
- while len(nodes) > downto:
- node = nodes.pop()
- parent = nodes[-1]
- inlabel = lastkey[len(nodes) - 1]
- self._compile_targets(node)
- accept = node.accept or len(node.arcs) == 0
- address = self._compile_node(node)
- parent.replace_last(inlabel, address, accept, node.value)
- def _finish(self):
- nodes = self.nodes
- root = nodes[0]
- # Minimize nodes in the last word's suffix
- self._freeze_tail(0)
- # Compile remaining targets
- self._compile_targets(root)
- return self._compile_node(root)
- def _compile_targets(self, node):
- for arc in node.arcs:
- if isinstance(arc.target, UncompiledNode):
- n = arc.target
- if len(n.arcs) == 0:
- arc.accept = n.accept = True
- arc.target = self._compile_node(n)
- def _compile_node(self, uncnode):
- seen = self.seen
- if len(uncnode.arcs) == 0:
- # Leaf node
- address = self._write_node(uncnode)
- else:
- d = uncnode.digest()
- address = seen.get(d)
- if address is None:
- address = self._write_node(uncnode)
- seen[d] = address
- return address
- def _write_node(self, uncnode):
- vtype = self.vtype
- dbfile = self.dbfile
- arcs = uncnode.arcs
- numarcs = len(arcs)
- if not numarcs:
- if uncnode.accept:
- return None
- else:
- # What does it mean for an arc to stop but not be accepted?
- raise Exception
- self.node_count += 1
- buf = StructFile(BytesIO())
- nodestart = dbfile.tell()
- #self.count += 1
- #self.arccount += numarcs
- fixedsize = -1
- arcstart = buf.tell()
- for i, arc in enumerate(arcs):
- self.arc_count += 1
- target = arc.target
- label = arc.label
- flags = 0
- if len(label) > 1:
- flags += MULTIBYTE_LABEL
- if i == numarcs - 1:
- flags += ARC_LAST
- if arc.accept:
- flags += ARC_ACCEPT
- if target is None:
- flags += ARC_STOP
- if arc.value is not None:
- flags += ARC_HAS_VAL
- if arc.acceptval is not None:
- flags += ARC_HAS_ACCEPT_VAL
- buf.write(pack_byte(flags))
- if len(label) > 1:
- buf.write(varint(len(label)))
- buf.write(label)
- if target is not None:
- buf.write(pack_uint(target))
- if arc.value is not None:
- vtype.write(buf, arc.value)
- if arc.acceptval is not None:
- vtype.write(buf, arc.acceptval)
- here = buf.tell()
- thissize = here - arcstart
- arcstart = here
- if fixedsize == -1:
- fixedsize = thissize
- elif fixedsize > 0 and thissize != fixedsize:
- fixedsize = 0
- if fixedsize > 0:
- # Write a fake arc containing the fixed size and number of arcs
- dbfile.write_byte(255) # FIXED_SIZE
- dbfile.write_int(fixedsize)
- dbfile.write_int(numarcs)
- self.fixed_count += 1
- dbfile.write(buf.file.getvalue())
- return nodestart
- # Graph reader
- class BaseGraphReader(object):
- def cursor(self, rootname=None):
- return Cursor(self, self.root(rootname))
- def has_root(self, rootname):
- raise NotImplementedError
- def root(self, rootname=None):
- raise NotImplementedError
- # Low level methods
- def arc_at(self, address, arc):
- raise NotImplementedError
- def iter_arcs(self, address, arc=None):
- raise NotImplementedError
- def find_arc(self, address, label, arc=None):
- arc = arc or Arc()
- for arc in self.iter_arcs(address, arc):
- thislabel = arc.label
- if thislabel == label:
- return arc
- elif thislabel > label:
- return None
- # Convenience methods
- def list_arcs(self, address):
- return list(copy.copy(arc) for arc in self.iter_arcs(address))
- def arc_dict(self, address):
- return dict((arc.label, copy.copy(arc))
- for arc in self.iter_arcs(address))
- def find_path(self, path, arc=None, address=None):
- path = to_labels(path)
- if arc:
- address = arc.target
- else:
- arc = Arc()
- if address is None:
- address = self._root
- for label in path:
- if address is None:
- return None
- if not self.find_arc(address, label, arc):
- return None
- address = arc.target
- return arc
- class GraphReader(BaseGraphReader):
- def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
- self.dbfile = dbfile
- self.vtype = vtype
- self.filebase = filebase
- dbfile.seek(filebase)
- magic = dbfile.read(4)
- if magic != b("GRPH"):
- raise FileVersionError
- self.version = dbfile.read_int()
- dbfile.seek(dbfile.read_uint())
- self.roots = dbfile.read_pickle()
- self._root = None
- if rootname is None and len(self.roots) == 1:
- # If there's only one root, just use it. Have to wrap a list around
- # the keys() method here because of Python 3.
- rootname = list(self.roots.keys())[0]
- if rootname is not None:
- self._root = self.root(rootname)
- def close(self):
- self.dbfile.close()
- # Overrides
- def has_root(self, rootname):
- return rootname in self.roots
- def root(self, rootname=None):
- if rootname is None:
- return self._root
- else:
- return self.roots[rootname]
- def default_root(self):
- return self._root
- def arc_at(self, address, arc=None):
- arc = arc or Arc()
- self.dbfile.seek(address)
- return self._read_arc(arc)
- def iter_arcs(self, address, arc=None):
- arc = arc or Arc()
- _read_arc = self._read_arc
- self.dbfile.seek(address)
- while True:
- _read_arc(arc)
- yield arc
- if arc.lastarc:
- break
- def find_arc(self, address, label, arc=None):
- arc = arc or Arc()
- dbfile = self.dbfile
- dbfile.seek(address)
- # If records are fixed size, we can do a binary search
- finfo = self._read_fixed_info()
- if finfo:
- size, count = finfo
- address = dbfile.tell()
- if count > 2:
- return self._binary_search(address, size, count, label, arc)
- # If records aren't fixed size, fall back to the parent's linear
- # search method
- return BaseGraphReader.find_arc(self, address, label, arc)
- # Implementations
- def _read_arc(self, toarc=None):
- toarc = toarc or Arc()
- dbfile = self.dbfile
- flags = dbfile.read_byte()
- if flags == 255:
- # This is a fake arc containing fixed size information; skip it
- # and read the next arc
- dbfile.seek(_INT_SIZE * 2, 1)
- flags = dbfile.read_byte()
- toarc.label = self._read_label(flags)
- return self._read_arc_data(flags, toarc)
- def _read_label(self, flags):
- dbfile = self.dbfile
- if flags & MULTIBYTE_LABEL:
- length = dbfile.read_varint()
- else:
- length = 1
- label = dbfile.read(length)
- return label
- def _read_fixed_info(self):
- dbfile = self.dbfile
- flags = dbfile.read_byte()
- if flags == 255:
- size = dbfile.read_int()
- count = dbfile.read_int()
- return (size, count)
- else:
- return None
- def _read_arc_data(self, flags, arc):
- dbfile = self.dbfile
- accept = arc.accept = bool(flags & ARC_ACCEPT)
- arc.lastarc = flags & ARC_LAST
- if flags & ARC_STOP:
- arc.target = None
- else:
- arc.target = dbfile.read_uint()
- if flags & ARC_HAS_VAL:
- arc.value = self.vtype.read(dbfile)
- else:
- arc.value = None
- if accept and flags & ARC_HAS_ACCEPT_VAL:
- arc.acceptval = self.vtype.read(dbfile)
- arc.endpos = dbfile.tell()
- return arc
- def _binary_search(self, address, size, count, label, arc):
- dbfile = self.dbfile
- _read_label = self._read_label
- lo = 0
- hi = count
- while lo < hi:
- mid = (lo + hi) // 2
- midaddr = address + mid * size
- dbfile.seek(midaddr)
- flags = dbfile.read_byte()
- midlabel = self._read_label(flags)
- if midlabel == label:
- arc.label = midlabel
- return self._read_arc_data(flags, arc)
- elif midlabel < label:
- lo = mid + 1
- else:
- hi = mid
- if lo == count:
- return None
- def to_labels(key):
- """Takes a string and returns a list of bytestrings, suitable for use as
- a key or path in an FSA/FST graph.
- """
- # Convert to tuples of bytestrings (must be tuples so they can be hashed)
- keytype = type(key)
- # I hate the Python 3 bytes object so friggin much
- if keytype is tuple or keytype is list:
- if not all(isinstance(e, bytes_type) for e in key):
- raise TypeError("%r contains a non-bytestring" % key)
- if keytype is list:
- key = tuple(key)
- elif isinstance(key, bytes_type):
- key = tuple(key[i:i + 1] for i in xrange(len(key)))
- elif isinstance(key, text_type):
- key = tuple(utf8encode(key[i:i + 1])[0] for i in xrange(len(key)))
- else:
- raise TypeError("Don't know how to convert %r" % key)
- return key
- # Within edit distance function
- def within(graph, text, k=1, prefix=0, address=None):
- """Yields a series of keys in the given graph within ``k`` edit distance of
- ``text``. If ``prefix`` is greater than 0, all keys must match the first
- ``prefix`` characters of ``text``.
- """
- text = to_labels(text)
- if address is None:
- address = graph._root
- sofar = emptybytes
- accept = False
- if prefix:
- prefixchars = text[:prefix]
- arc = graph.find_path(prefixchars, address=address)
- if arc is None:
- return
- sofar = emptybytes.join(prefixchars)
- address = arc.target
- accept = arc.accept
- stack = [(address, k, prefix, sofar, accept)]
- seen = set()
- while stack:
- state = stack.pop()
- # Have we already tried this state?
- if state in seen:
- continue
- seen.add(state)
- address, k, i, sofar, accept = state
- # If we're at the end of the text (or deleting enough chars would get
- # us to the end and still within K), and we're in the accept state,
- # yield the current result
- if (len(text) - i <= k) and accept:
- yield utf8decode(sofar)[0]
- # If we're in the stop state, give up
- if address is None:
- continue
- # Exact match
- if i < len(text):
- arc = graph.find_arc(address, text[i])
- if arc:
- stack.append((arc.target, k, i + 1, sofar + text[i],
- arc.accept))
- # If K is already 0, can't do any more edits
- if k < 1:
- continue
- k -= 1
- arcs = graph.arc_dict(address)
- # Insertions
- stack.extend((arc.target, k, i, sofar + char, arc.accept)
- for char, arc in iteritems(arcs))
- # Deletion, replacement, and transpo only work before the end
- if i >= len(text):
- continue
- char = text[i]
- # Deletion
- stack.append((address, k, i + 1, sofar, False))
- # Replacement
- for char2, arc in iteritems(arcs):
- if char2 != char:
- stack.append((arc.target, k, i + 1, sofar + char2, arc.accept))
- # Transposition
- if i < len(text) - 1:
- char2 = text[i + 1]
- if char != char2 and char2 in arcs:
- # Find arc from next char to this char
- target = arcs[char2].target
- if target:
- arc = graph.find_arc(target, char)
- if arc:
- stack.append((arc.target, k, i + 2,
- sofar + char2 + char, arc.accept))
- # Utility functions
- def dump_graph(graph, address=None, tab=0, out=None):
- if address is None:
- address = graph._root
- if out is None:
- out = sys.stdout
- here = "%06d" % address
- for i, arc in enumerate(graph.list_arcs(address)):
- if i == 0:
- out.write(here)
- else:
- out.write(" " * 6)
- out.write(" " * tab)
- out.write("%r %r %s %r\n"
- % (arc.label, arc.target, arc.accept, arc.value))
- if arc.target is not None:
- dump_graph(graph, arc.target, tab + 1, out=out)