/bangkokhotel/lib/python2.5/site-packages/whoosh/support/dawg.py
Python | 1549 lines | 1466 code | 27 blank | 56 comment | 22 complexity | 6233492f368ddc4a5249095d8fbb480d MD5 | raw file
1# Copyright 2009 Matt Chaput. All rights reserved.
2#
3# Redistribution and use in source and binary forms, with or without
4# modification, are permitted provided that the following conditions are met:
5#
6# 1. Redistributions of source code must retain the above copyright notice,
7# this list of conditions and the following disclaimer.
8#
9# 2. Redistributions in binary form must reproduce the above copyright
10# notice, this list of conditions and the following disclaimer in the
11# documentation and/or other materials provided with the distribution.
12#
13# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
14# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
15# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
16# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
17# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
19# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
20# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
21# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
22# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23#
24# The views and conclusions contained in the software and documentation are
25# those of the authors and should not be interpreted as representing official
26# policies, either expressed or implied, of Matt Chaput.
27
28"""
29This module implements an FST/FSA writer and reader. An FST (Finite State
30Transducer) stores a directed acyclic graph with values associated with the
31leaves. Common elements of the values are pushed inside the tree. An FST that
32does not store values is a regular FSA.
33
34The format of the leaf values is pluggable using subclasses of the Values
35class.
36
37Whoosh uses these structures to store a directed acyclic word graph (DAWG) for
38use in (at least) spell checking.
39"""
40
41
42import sys, copy
43from array import array
44from hashlib import sha1 # @UnresolvedImport
45
46from whoosh.compat import (b, u, BytesIO, xrange, iteritems, iterkeys,
47 bytes_type, text_type, izip, array_tobytes)
48from whoosh.filedb.structfile import StructFile
49from whoosh.system import (_INT_SIZE, pack_byte, pack_int, pack_uint,
50 pack_long, emptybytes)
51from whoosh.util import utf8encode, utf8decode, varint
52
53
class FileVersionError(Exception):
    """Raised by the graph reader when a file does not begin with the
    expected ``GRPH`` magic bytes.
    """
56
57
class InactiveCursor(Exception):
    """Raised when a cursor operation needs a current arc but the cursor has
    no arcs left on its stack.
    """
60
61
# Bit flags combined into the single flags byte written in front of every
# serialized arc.
ARC_LAST = 0x01            # arc is the final outgoing arc of its node
ARC_ACCEPT = 0x02          # arc leads to an accept state
ARC_STOP = 0x04            # arc has no target node
ARC_HAS_VAL = 0x08         # an arc value follows
ARC_HAS_ACCEPT_VAL = 0x10  # an accept value follows
MULTIBYTE_LABEL = 0x20     # label is longer than one byte (length-prefixed)
68
69
70# FST Value types
71
class Values(object):
    """Base for classes that describe how to encode and decode FST values.

    Every method except ``skip()`` is abstract and raises
    ``NotImplementedError``; concrete value types override them.
    """

    @staticmethod
    def is_valid(v):
        """Returns True if ``v`` is an object this value type can store.
        """

        raise NotImplementedError()

    @staticmethod
    def common(v1, v2):
        """Returns the part shared by ``v1`` and ``v2``, for whatever
        "shared" means for the value type: a string type would return the
        common prefix, an integer type the smaller of the two numbers.

        Implementations return None when the values have nothing in common.
        """

        raise NotImplementedError()

    @staticmethod
    def add(prefix, v):
        """Combines ``prefix`` (a result of ``common()``) with the value
        ``v`` and returns the result.
        """

        raise NotImplementedError()

    @staticmethod
    def subtract(v, prefix):
        """Returns ``v`` with the shared part ``prefix`` removed.
        """

        raise NotImplementedError()

    @staticmethod
    def write(dbfile, v):
        """Serializes the value ``v`` to ``dbfile``.
        """

        raise NotImplementedError()

    @staticmethod
    def read(dbfile):
        """Deserializes and returns one value from ``dbfile``.
        """

        raise NotImplementedError()

    @classmethod
    def skip(cls, dbfile):
        """Moves past one value in ``dbfile``. The default implementation
        simply reads the value and discards it.
        """

        cls.read(dbfile)

    @staticmethod
    def to_bytes(v):
        """Returns a str (Python 2.x) or bytes (Python 3) form of ``v``.
        The result feeds node digests, so it should be unique and cheap to
        compute; it does not need to be parseable.
        """

        raise NotImplementedError()

    @staticmethod
    def merge(v1, v2):
        """Abstract hook for combining two values into one.
        """

        raise NotImplementedError()
145
146
class IntValues(Values):
    """Stores non-negative integer values in an FST.

    The shared part of two integers is their minimum; ``add`` and
    ``subtract`` are ordinary arithmetic, with None treated as "no value".
    """

    @staticmethod
    def is_valid(v):
        """Returns True if ``v`` is an ``int`` that is zero or greater.
        """

        return isinstance(v, int) and v >= 0

    @staticmethod
    def common(v1, v2):
        """Returns the smaller of the two values, or None if either one is
        None.
        """

        if v1 is None or v2 is None:
            return None
        if v1 == v2:
            return v1
        return min(v1, v2)

    @staticmethod
    def add(base, v):
        """Returns ``base + v``; a None operand is ignored.
        """

        if base is None:
            return v
        if v is None:
            return base
        return base + v

    @staticmethod
    def subtract(v, base):
        """Returns ``v - base``. Returns None if ``v`` is None, and ``v``
        unchanged if ``base`` is None.
        """

        if v is None:
            return None
        if base is None:
            return v
        return v - base

    @staticmethod
    def write(dbfile, v):
        """Writes ``v`` to ``dbfile`` with ``write_uint``.
        """

        dbfile.write_uint(v)

    @staticmethod
    def read(dbfile):
        """Reads one value from ``dbfile`` with ``read_uint``.
        """

        return dbfile.read_uint()

    @staticmethod
    def skip(dbfile):
        """Seeks ``_INT_SIZE`` bytes forward, past one stored value.
        """

        dbfile.seek(_INT_SIZE, 1)

    @staticmethod
    def to_bytes(v):
        """Returns the packed bytes of ``v`` for use in node digests.
        """

        # Values are stored with write_uint(), so hash them with the
        # unsigned packer as well; the signed one previously used here can
        # reject large values that the on-disk format accepts.
        # NOTE(review): assumes pack_int/pack_uint are the signed/unsigned
        # packers of the same width -- confirm in whoosh.system.
        return pack_uint(v)
194
195
class SequenceValues(Values):
    """Abstract base class for value types that store sequences.

    The shared part of two sequences is their common prefix.
    """

    @staticmethod
    def is_valid(v):
        """Returns True if ``v`` is a list or a tuple.
        """

        # This is a static method, so the value under test is ``v``; the
        # previous code tested the undefined name ``self`` and raised
        # NameError whenever it was called.
        return isinstance(v, (list, tuple))

    @staticmethod
    def common(v1, v2):
        """Returns the longest common prefix of ``v1`` and ``v2``, or None
        if either is None or they share no leading items.
        """

        if v1 is None or v2 is None:
            return None

        i = 0
        while i < len(v1) and i < len(v2):
            if v1[i] != v2[i]:
                break
            i += 1

        if i == 0:
            return None
        # When one value is entirely a prefix of the other, return that
        # object itself rather than a sliced copy
        if i == len(v1):
            return v1
        if i == len(v2):
            return v2
        return v1[:i]

    @staticmethod
    def add(prefix, v):
        """Returns ``prefix`` concatenated with ``v``; a None operand is
        ignored.
        """

        if prefix is None:
            return v
        if v is None:
            return prefix
        return prefix + v

    @staticmethod
    def subtract(v, prefix):
        """Returns ``v`` with the leading ``prefix`` removed, or None when
        nothing remains.

        :raises ValueError: if ``prefix`` is empty or longer than ``v``.
        """

        if prefix is None:
            return v
        if v is None:
            return None
        if len(v) == len(prefix):
            return None
        if len(v) < len(prefix) or len(prefix) == 0:
            raise ValueError((v, prefix))
        return v[len(prefix):]

    @staticmethod
    def write(dbfile, v):
        """Writes ``v`` to ``dbfile`` as a pickle.
        """

        dbfile.write_pickle(v)

    @staticmethod
    def read(dbfile):
        """Reads one pickled value from ``dbfile``.
        """

        return dbfile.read_pickle()
250
251
class BytesValues(SequenceValues):
    """Stores bytes objects (str in Python 2.x) in an FST.

    A value is serialized as an int length followed by the raw bytes.
    """

    @staticmethod
    def is_valid(v):
        """Returns True if ``v`` is an instance of ``bytes_type``.
        """

        return isinstance(v, bytes_type)

    @staticmethod
    def write(dbfile, v):
        """Writes the length of ``v`` and then ``v`` itself.
        """

        size = len(v)
        dbfile.write_int(size)
        dbfile.write(v)

    @staticmethod
    def read(dbfile):
        """Reads the length prefix and returns that many bytes.
        """

        size = dbfile.read_int()
        return dbfile.read(size)

    @staticmethod
    def skip(dbfile):
        """Reads the length prefix and seeks past the payload without
        reading it.
        """

        size = dbfile.read_int()
        dbfile.seek(size, 1)

    @staticmethod
    def to_bytes(v):
        """Returns ``v`` unchanged, since it is already bytes.
        """

        return v
278
279
class ArrayValues(SequenceValues):
    """Stores array.array objects in an FST.

    Unlike the other value types this one is instantiated, because it needs
    to know the array typecode. On disk a value is one typecode byte, an int
    item count, and then the array data.
    """

    def __init__(self, typecode):
        """
        :param typecode: the ``array`` typecode of the values to be stored.
        """

        self.typecode = typecode
        self.itemsize = array(self.typecode).itemsize

    def is_valid(self, v):
        """Returns True if ``v`` is an array with this object's typecode.
        """

        return isinstance(v, array) and v.typecode == self.typecode

    @staticmethod
    def write(dbfile, v):
        """Writes the typecode, the item count, and the array contents.
        """

        dbfile.write(b(v.typecode))
        dbfile.write_int(len(v))
        dbfile.write_array(v)

    def read(self, dbfile):
        """Reads one array. The stored typecode byte is consumed but the
        array is decoded using ``self.typecode``.
        """

        dbfile.read(1)  # Step over the stored typecode byte
        length = dbfile.read_int()
        return dbfile.read_array(self.typecode, length)

    def skip(self, dbfile):
        """Seeks past one stored array without decoding it.
        """

        # write() puts a one-byte typecode in front of the length; it has to
        # be stepped over first, otherwise the length is read from the wrong
        # offset and the seek lands in the middle of the data.
        dbfile.seek(1, 1)
        length = dbfile.read_int()
        dbfile.seek(length * self.itemsize, 1)

    @staticmethod
    def to_bytes(v):
        """Returns the raw bytes of the array, for node digests.
        """

        return array_tobytes(v)
309
310
class IntListValues(SequenceValues):
    """Stores lists of positive, increasing integers (that is, lists of
    integers where each number is >= 0 and each number is greater than or equal
    to the number that precedes it) in an FST.

    Lists are serialized as a varint length followed by varint deltas
    between consecutive items.
    """

    @staticmethod
    def is_valid(v):
        """Returns True if ``v`` is a list or tuple of ints, each >= 0 and
        none smaller than the item before it. An empty sequence is valid.
        """

        if not isinstance(v, (list, tuple)):
            return False
        # Starting the comparison at 0 also validates the first item (and
        # the only item of a one-element list), which the previous check
        # skipped entirely; write() would then fail on such input.
        previous = 0
        for x in v:
            if not isinstance(x, int) or x < previous:
                return False
            previous = x
        return True

    @staticmethod
    def write(dbfile, v):
        """Writes the length of ``v`` and then each item as the difference
        from the item before it (the first item is written relative to 0).
        """

        base = 0
        dbfile.write_varint(len(v))
        for x in v:
            delta = x - base
            assert delta >= 0
            dbfile.write_varint(delta)
            base = x

    @staticmethod
    def read(dbfile):
        """Reads a list written by ``write()`` and returns it as a list of
        absolute numbers.
        """

        length = dbfile.read_varint()
        result = []
        if length > 0:
            base = 0
            for _ in xrange(length):
                # Each stored number is a delta; accumulate to recover the
                # original values
                base += dbfile.read_varint()
                result.append(base)
        return result

    @staticmethod
    def to_bytes(v):
        """Returns the ``repr()`` of the list as bytes, for node digests.
        """

        return b(repr(v))
352
353
354# Node-like interface wrappers
355
class Node(object):
    """A slow but easier-to-use wrapper for FSA/DAWGs. Translates the low-level
    arc-based interface of GraphReader into Node objects with methods to follow
    edges.

    The outgoing edges are read from the owner lazily, the first time they
    are needed, and cached in ``_edges`` as a dict of label -> Node.
    """

    def __init__(self, owner, address, accept=False):
        """
        :param owner: the object that supplies arcs through
            ``iter_arcs(address)``.
        :param address: the address of this node, or None for a node with no
            outgoing edges.
        :param accept: whether this node is an accept state.
        """

        self.owner = owner
        self.address = address
        self._edges = None
        self.accept = accept

    def __iter__(self):
        # Compare against None like every other accessor. A truthiness test
        # here made nodes with no outgoing edges (an empty dict) reload from
        # the owner on every iteration.
        if self._edges is None:
            self._load()
        return iterkeys(self._edges)

    def __contains__(self, key):
        if self._edges is None:
            self._load()
        return key in self._edges

    def _load(self):
        """Reads this node's outgoing arcs from the owner and caches them.
        """

        owner = self.owner
        if self.address is None:
            d = {}
        else:
            d = dict((arc.label, Node(owner, arc.target, arc.accept))
                     for arc in owner.iter_arcs(self.address))
        self._edges = d

    def keys(self):
        """Returns the labels of the outgoing edges.
        """

        if self._edges is None:
            self._load()
        return self._edges.keys()

    def all_edges(self):
        """Returns the dict mapping edge labels to target nodes.
        """

        if self._edges is None:
            self._load()
        return self._edges

    def edge(self, key):
        """Returns the node at the end of the edge labelled ``key``.

        :raises KeyError: if there is no such edge.
        """

        if self._edges is None:
            self._load()
        return self._edges[key]

    def flatten(self, sofar=emptybytes):
        """Yields the keys reachable from this node in sorted order, each
        prefixed with ``sofar``.
        """

        if self.accept:
            yield sofar
        for key in sorted(self):
            node = self.edge(key)
            for result in node.flatten(sofar + key):
                yield result

    def flatten_strings(self):
        """Like ``flatten()`` but yields each key decoded from UTF-8.
        """

        return (utf8decode(k)[0] for k in self.flatten())
412
413
class ComboNode(Node):
    """Base class for nodes that blend the nodes of two different graphs.

    Concrete subclasses need to implement the ``edge()`` method and possibly
    override the ``accept`` property. As defined here, membership, iteration
    and ``accept`` all describe the union of the two wrapped nodes.
    """

    def __init__(self, a, b):
        """
        :param a: the node from the first graph.
        :param b: the node from the second graph.
        """

        self.a, self.b = a, b

    def __repr__(self):
        return "<%s %r %r>" % (type(self).__name__, self.a, self.b)

    def __contains__(self, key):
        if key in self.a:
            return True
        return key in self.b

    def __iter__(self):
        labels = set(self.a).union(set(self.b))
        return iter(labels)

    @property
    def accept(self):
        """The first truthy ``accept`` of the two wrapped nodes, otherwise
        the ``accept`` of the second node.
        """

        return self.a.accept or self.b.accept
437
438
class UnionNode(ComboNode):
    """Makes two graphs appear to be the union of the two graphs.
    """

    def edge(self, key):
        """Returns the node reached by ``key``: a new UnionNode when both
        graphs have the edge, otherwise the plain node from whichever graph
        has it (falling through to the second graph).
        """

        left, right = self.a, self.b
        in_left = key in left
        if in_left and key in right:
            return UnionNode(left.edge(key), right.edge(key))
        if in_left:
            return left.edge(key)
        return right.edge(key)
452
453
class IntersectionNode(ComboNode):
    """Makes two graphs appear to be the intersection of the two graphs.

    Membership, iteration and ``accept`` are overridden so that they only
    report what is present in *both* wrapped nodes. With the inherited
    union behaviour, ``flatten()`` would ask ``edge()`` for labels found in
    only one graph (getting None back) and would accept keys that end in
    only one of the graphs.
    """

    def __contains__(self, key):
        return key in self.a and key in self.b

    def __iter__(self):
        return iter(set(self.a) & set(self.b))

    @property
    def accept(self):
        """True only when both wrapped nodes are accept states.
        """

        return self.a.accept and self.b.accept

    def edge(self, key):
        """Returns an IntersectionNode for the edge ``key`` when both graphs
        have it. Returns None otherwise.
        """

        a = self.a
        b = self.b
        if key in a and key in b:
            return IntersectionNode(a.edge(key), b.edge(key))
463
464
465# Cursor
466
class BaseCursor(object):
    """Base class for a cursor-type object for navigating an FST/word graph,
    represented by a :class:`GraphReader` object.

    >>> cur = GraphReader(dawgfile).cursor()
    >>> for key in cur.flatten():
    ...     print(repr(key))

    The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.

    Subclasses implement the primitive operations (``is_active``, ``label``,
    ``prefix``, ``stopped``, ``value``, ``accept``, ``at_last_arc``,
    ``next_arc`` and ``follow``); the remaining methods are built on top of
    those. ``peek_key()`` additionally needs a ``copy()`` method, which this
    base class does not define.
    """

    def is_active(self):
        """Returns True if this cursor is still active, that is it has not
        read past the last arc in the graph.
        """

        raise NotImplementedError

    def label(self):
        """Returns the label bytes of the current arc.
        """

        raise NotImplementedError

    def prefix(self):
        """Returns a sequence of the label bytes for the path from the root
        to the current arc.
        """

        raise NotImplementedError

    def prefix_bytes(self):
        """Returns the label bytes for the path from the root to the current
        arc as a single joined bytes object.
        """

        return emptybytes.join(self.prefix())

    def prefix_string(self):
        """Returns the labels of the path from the root to the current arc as
        a decoded unicode string.
        """

        return utf8decode(self.prefix_bytes())[0]

    def peek_key(self):
        """Returns a sequence of label bytes representing the next closest
        key in the graph.

        The cursor itself is not moved; the lookahead runs on a copy.
        """

        for label in self.prefix():
            yield label
        c = self.copy()
        # Stop as soon as the lookahead reaches an accepting arc. Without
        # the accept() test this walked on to the first leaf and skipped
        # shorter keys, unlike Cursor's specialised peek_key().
        while not c.accept() and not c.stopped():
            c.follow()
            yield c.label()

    def peek_key_bytes(self):
        """Returns the next closest key in the graph as a single bytes object.
        """

        return emptybytes.join(self.peek_key())

    def peek_key_string(self):
        """Returns the next closest key in the graph as a decoded unicode
        string.
        """

        return utf8decode(self.peek_key_bytes())[0]

    def stopped(self):
        """Returns True if the current arc leads to a stop state.
        """

        raise NotImplementedError

    def value(self):
        """Returns the value at the current arc, if reading an FST.
        """

        raise NotImplementedError

    def accept(self):
        """Returns True if the current arc leads to an accept state (the end
        of a valid key).
        """

        raise NotImplementedError

    def at_last_arc(self):
        """Returns True if the current arc is the last outgoing arc from the
        previous node.
        """

        raise NotImplementedError

    def next_arc(self):
        """Moves to the next outgoing arc from the previous node.
        """

        raise NotImplementedError

    def follow(self):
        """Follows the current arc.
        """

        raise NotImplementedError

    def switch_to(self, label):
        """Switch to the sibling arc with the given label bytes.

        Scans forward from the current arc. Returns True if an arc with the
        label was found. Returns False once a greater label or the last
        sibling is reached, leaving the cursor on that arc.
        """

        _label = self.label
        _at_last_arc = self.at_last_arc
        _next_arc = self.next_arc

        while True:
            thislabel = _label()
            if thislabel == label:
                return True
            if thislabel > label or _at_last_arc():
                return False
            _next_arc()

    def skip_to(self, key):
        """Moves the cursor to the path represented by the given key bytes.

        Walks the keys in order from the current position. Returns True if
        the cursor ends up on ``key``, or False as soon as a key greater than
        ``key`` is reached. If the graph is exhausted first, whatever the
        primitive methods raise on an inactive cursor propagates.

        :param key: a bytes/str object, or a list/tuple of labels.
        """

        _accept = self.accept
        _prefix = self.prefix
        _stopped = self.stopped
        _follow = self.follow
        _next_arc = self.next_arc

        # Split the key into one-item slices so the pieces keep the type of
        # the key (list(key) on a Python 3 bytes object yields ints, which
        # never equal the label bytes).
        if isinstance(key, (list, tuple)):
            keylist = list(key)
        else:
            keylist = [key[i:i + 1] for i in range(len(key))]

        while True:
            if _accept():
                thiskey = list(_prefix())
                if thiskey == keylist:
                    return True
                # Keys are visited in increasing order, so once the current
                # key is past the target it cannot be found. (The test used
                # to be the other way around and gave up while still
                # *before* the target.)
                if thiskey > keylist:
                    return False
            # Descend before moving sideways, as flatten() does; otherwise
            # keys below the current arc can never be reached.
            if not _stopped():
                _follow()
            else:
                _next_arc()

    def flatten(self):
        """Yields the keys in the graph, starting at the current position.

        :raises InactiveCursor: if the cursor is not active when iteration
            starts.
        """

        _is_active = self.is_active
        _accept = self.accept
        _stopped = self.stopped
        _follow = self.follow
        _next_arc = self.next_arc
        _prefix_bytes = self.prefix_bytes

        if not _is_active():
            raise InactiveCursor
        while _is_active():
            if _accept():
                yield _prefix_bytes()
            if not _stopped():
                # Go depth-first: descend before visiting siblings
                _follow()
                continue
            _next_arc()

    def flatten_v(self):
        """Yields (key, value) tuples in an FST, starting at the current
        position.
        """

        for key in self.flatten():
            yield key, self.value()

    def flatten_strings(self):
        """Like ``flatten()`` but yields each key decoded from UTF-8.
        """

        return (utf8decode(k)[0] for k in self.flatten())

    def find_path(self, path):
        """Follows the labels in the given path, starting at the current
        position.

        Returns True if every label could be followed, otherwise False.
        """

        path = to_labels(path)
        _switch_to = self.switch_to
        _follow = self.follow
        _stopped = self.stopped

        first = True
        for i, label in enumerate(path):
            # The cursor already rests on an arc for the first label; every
            # later label requires descending one level first
            if not first:
                _follow()
            if not _switch_to(label):
                return False
            if _stopped():
                # A stop arc is only acceptable for the final label
                if i < len(path) - 1:
                    return False
            first = False
        return True
662
663
class Cursor(BaseCursor):
    """Cursor over a graph reader, implemented as a stack of arcs.

    ``stack`` holds one arc for each level of the current path. When the
    graph has a value type, ``sums`` holds for each level the combined value
    of the arcs above it, so ``value()`` does not have to re-walk the path.
    """

    def __init__(self, graph, root=None, stack=None):
        """
        :param graph: the graph reader to navigate. It must provide
            ``vtype``, ``default_root()``, ``arc_at()`` and ``find_arc()``.
        :param root: the address of the node to start from. If None, the
            graph's default root is used.
        :param stack: an existing list of arcs to adopt as the current path
            (this is how ``copy()`` builds the new cursor). If empty or
            None, the cursor starts at the first arc of the root.
        """

        self.graph = graph
        self.vtype = graph.vtype
        self.root = root if root is not None else graph.default_root()
        if stack:
            self.stack = stack
            # A supplied stack used to leave ``sums`` unset, so copies
            # failed in value(), follow() and pop() on graphs with values.
            self.sums = self._sums_for(stack)
        else:
            self.reset()

    def _sums_for(self, stack):
        # Recomputes the running value totals that _push() would have
        # accumulated while building the given stack.
        sums = [None]
        vtype = self.vtype
        if vtype:
            for arc in stack[:-1]:
                sums.append(vtype.add(sums[-1], arc.value))
        return sums

    def _current_attr(self, name):
        # Returns the named attribute of the arc on top of the stack.
        stack = self.stack
        if not stack:
            raise InactiveCursor
        return getattr(stack[-1], name)

    def is_active(self):
        return bool(self.stack)

    def stopped(self):
        return self._current_attr("target") is None

    def accept(self):
        return self._current_attr("accept")

    def at_last_arc(self):
        return self._current_attr("lastarc")

    def label(self):
        return self._current_attr("label")

    def reset(self):
        """Moves the cursor back to the first arc of the root node.
        """

        self.stack = []
        self.sums = [None]
        self._push(self.graph.arc_at(self.root))

    def copy(self):
        """Returns a new cursor at the same position, with its own deep copy
        of the arc stack.
        """

        return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))

    def prefix(self):
        stack = self.stack
        if not stack:
            raise InactiveCursor
        return (arc.label for arc in stack)

    # Override: more efficient implementation using graph methods directly
    def peek_key(self):
        if not self.stack:
            raise InactiveCursor

        for label in self.prefix():
            yield label
        # Work on a shallow copy of the current arc so the cursor itself
        # does not move; arc_at() overwrites the arc it is handed
        arc = copy.copy(self.stack[-1])
        graph = self.graph
        while not arc.accept and arc.target is not None:
            graph.arc_at(arc.target, arc)
            yield arc.label

    def value(self):
        """Returns the value for the current arc: the total of the arcs
        above it, plus the arc's own value, plus its accept value if the arc
        is accepting.

        :raises InactiveCursor: if the cursor has no current arc.
        :raises Exception: if the graph has no value type.
        """

        stack = self.stack
        if not stack:
            raise InactiveCursor
        vtype = self.vtype
        if not vtype:
            raise Exception("No value type")

        v = self.sums[-1]
        current = stack[-1]
        if current.value:
            v = vtype.add(v, current.value)
        if current.accept and current.acceptval is not None:
            v = vtype.add(v, current.acceptval)
        return v

    def next_arc(self):
        """Moves to the next arc in the graph, popping levels whose arcs are
        used up. Returns the new current arc, or None if the cursor ran out
        of arcs.
        """

        stack = self.stack
        if not stack:
            raise InactiveCursor

        while stack and stack[-1].lastarc:
            self.pop()
        if stack:
            current = stack[-1]
            # The next sibling starts where the current arc ended; it is
            # read into the same Arc object
            self.graph.arc_at(current.endpos, current)
            return current

    def follow(self):
        """Descends to the first arc of the current arc's target node and
        returns this cursor.

        :raises Exception: if the current arc is a stop arc.
        """

        address = self._current_attr("target")
        if address is None:
            raise Exception("Can't follow a stop arc")
        self._push(self.graph.arc_at(address))
        return self

    # Override: more efficient implementation manipulating the stack
    def skip_to(self, key):
        # NOTE(review): unlike BaseCursor.skip_to() this returns None on
        # every path, and it follows the arc for the final key label, which
        # raises when that arc is a stop arc -- confirm intended contract.
        key = to_labels(key)
        stack = self.stack
        if not stack:
            raise InactiveCursor

        _follow = self.follow
        _next_arc = self.next_arc

        i = self._pop_to_prefix(key)
        while stack and i < len(key):
            curlabel = stack[-1].label
            keylabel = key[i]
            if curlabel == keylabel:
                _follow()
                i += 1
            elif curlabel > keylabel:
                return
            else:
                _next_arc()

    # Override: more efficient implementation using find_arc
    def switch_to(self, label):
        """Returns True if already on the arc with the given label;
        otherwise returns the result of searching the following siblings
        (the arc if found, else None).
        """

        stack = self.stack
        if not stack:
            raise InactiveCursor

        current = stack[-1]
        if label == current.label:
            return True
        else:
            arc = self.graph.find_arc(current.endpos, label, current)
            return arc

    def _push(self, arc):
        # Before descending, record the running total including the arc
        # being left behind. Nothing is added for the very first arc.
        if self.vtype and self.stack:
            sums = self.sums
            sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
        self.stack.append(arc)

    def pop(self):
        """Drops the deepest arc, and its running total, from the path.
        """

        self.stack.pop()
        if self.vtype:
            self.sums.pop()

    def _pop_to_prefix(self, key):
        stack = self.stack
        if not stack:
            raise InactiveCursor

        # Count how many leading labels the current path shares with the key
        i = 0
        maxpre = min(len(stack), len(key))
        while i < maxpre and key[i] == stack[i].label:
            i += 1
        # NOTE(review): when the whole stack or the whole key matched, ``i``
        # equals len(stack) or len(key) and the indexing below raises
        # IndexError. This branch also returns None, which skip_to() then
        # compares against len(key). Confirm the intended handling.
        if stack[i].label > key[i]:
            self.current = None
            return
        while len(stack) > i + 1:
            self.pop()
        self.next_arc()
        return i
819
820
class UncompiledNode(object):
    """Represents an "in-memory" node used by the GraphWriter before it is
    written to disk.
    """

    compiled = False

    def __init__(self, owner):
        """
        :param owner: the writer that owns this node; its ``vtype`` is used
            for value arithmetic and digests.
        """

        self.owner = owner
        self._digest = None
        self.clear()

    def clear(self):
        """Resets the node to an empty, non-accepting state.
        """

        self.arcs = []
        self.value = None
        self.accept = False
        self.inputcount = 0
        # A cached digest describes the old arcs, so drop it as well
        self._digest = None

    def __repr__(self):
        return "<%r>" % ([(a.label, a.value) for a in self.arcs],)

    def digest(self):
        """Returns a SHA-1 digest of this node's outgoing arcs, used by the
        writer to detect nodes it has already written. The digest is
        computed once and cached.
        """

        if self._digest is None:
            d = sha1()
            vtype = self.owner.vtype
            for arc in self.arcs:
                d.update(arc.label)
                if arc.target:
                    d.update(pack_long(arc.target))
                else:
                    d.update(b("z"))
                if arc.value:
                    d.update(vtype.to_bytes(arc.value))
                if arc.accept:
                    d.update(b("T"))
                # The accept value is written to disk with the arc, so it
                # has to be part of the node's identity too. Otherwise two
                # nodes that differ only in their accept values get the same
                # digest and are merged into one.
                if arc.acceptval is not None:
                    d.update(b("A"))
                    d.update(vtype.to_bytes(arc.acceptval))
            self._digest = d.digest()
        return self._digest

    def edges(self):
        """Returns the list of outgoing arcs.
        """

        return self.arcs

    def last_value(self, label):
        """Returns the value of the most recently added arc, which must
        carry the given label.
        """

        assert self.arcs[-1].label == label
        return self.arcs[-1].value

    def add_arc(self, label, target):
        """Appends a new arc with the given label and target.
        """

        self.arcs.append(Arc(label, target))

    def replace_last(self, label, target, accept, acceptval=None):
        """Overwrites the target, accept flag and accept value of the most
        recently added arc, which must carry the given label.
        """

        arc = self.arcs[-1]
        assert arc.label == label, "%r != %r" % (arc.label, label)
        arc.target = target
        arc.accept = accept
        arc.acceptval = acceptval

    def delete_last(self, label, target):
        """Removes the most recently added arc, which must have the given
        label and target.
        """

        arc = self.arcs.pop()
        assert arc.label == label
        assert arc.target == target

    def set_last_value(self, label, value):
        """Sets the value of the most recently added arc, which must carry
        the given label.
        """

        arc = self.arcs[-1]
        assert arc.label == label, "%r->%r" % (arc.label, label)
        arc.value = value

    def prepend_value(self, prefix):
        """Adds ``prefix`` in front of the value of every outgoing arc and,
        if this node is an accept state, of the node's own value.
        """

        add = self.owner.vtype.add
        for arc in self.arcs:
            arc.value = add(prefix, arc.value)
        if self.accept:
            self.value = add(prefix, self.value)
891
892
class Arc(object):
    """
    Represents a directed arc between two nodes in an FSA/FST graph.

    The ``lastarc`` attribute is True if this is the last outgoing arc from the
    previous node.
    """

    __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
                 "endpos")

    def __init__(self, label=None, target=None, value=None, accept=False,
                 acceptval=None):
        """
        :param label: The label bytes for this arc. For a word graph, this
            will be a character.
        :param target: The address of the node at the endpoint of this arc.
        :param value: The inner FST value at the endpoint of this arc.
        :param accept: Whether the endpoint of this arc is an accept state
            (eg the end of a valid word).
        :param acceptval: If the endpoint of this arc is an accept state, the
            final FST value for that accepted state.
        """

        self.label = label
        self.target = target
        self.value = value
        self.accept = accept
        self.lastarc = None
        self.acceptval = acceptval
        self.endpos = None

    def __repr__(self):
        return "<%r-%s %s%s>" % (self.label, self.target,
                                 "." if self.accept else "",
                                 (" %r" % self.value) if self.value else "")

    def __eq__(self, other):
        # Arcs are equal when they are the same class and agree on accept,
        # lastarc, target, value and label. ``acceptval`` and ``endpos`` are
        # not compared.
        if (isinstance(other, self.__class__) and self.accept == other.accept
            and self.lastarc == other.lastarc and self.target == other.target
            and self.value == other.value and self.label == other.label):
            return True
        return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, ``a != b``
        # compares identity and can disagree with ``a == b``.
        return not self.__eq__(other)
936
937
938# Graph writer
939
class GraphWriter(object):
    """Writes an FSA/FST graph to disk.

    Call ``start_field(name)`` to begin a graph, then ``insert(key)`` to
    insert keys into the graph. You must insert keys in sorted order. Call
    ``close()`` to finish the graph and close the file.

    >>> gw = GraphWriter(my_file)
    >>> gw.start_field("words")
    >>> gw.insert("alfa")
    >>> gw.insert("bravo")
    >>> gw.insert("charlie")
    >>> gw.close()

    The graph writer can write separate graphs for multiple fields. Use
    ``start_field(name)`` and ``finish_field()`` to separate fields.

    >>> gw = GraphWriter(my_file)
    >>> gw.start_field("content")
    >>> gw.insert("alfalfa")
    >>> gw.insert("apple")
    >>> gw.finish_field()
    >>> gw.start_field("title")
    >>> gw.insert("artichoke")
    >>> gw.finish_field()
    >>> gw.close()
    """

    version = 1

    def __init__(self, dbfile, vtype=None, merge=None):
        """
        :param dbfile: the file to write to.
        :param vtype: a :class:`Values` class to use for storing values. This
            is only necessary if you will be storing values for the keys.
        :param merge: a function that takes two values and returns a single
            value. This is called if you insert two identical keys with values.
        """

        self.dbfile = dbfile
        self.vtype = vtype
        self.merge = merge
        self.fieldroots = {}
        self.arc_count = 0
        self.node_count = 0
        self.fixed_count = 0
        # Defined up front so the writer can be closed even if no field was
        # ever started
        self.fieldname = None
        self._infield = False

        # File header: magic, version, and a placeholder that close()
        # overwrites with the position of the pickled field roots
        dbfile.write(b("GRPH"))
        dbfile.write_int(self.version)
        dbfile.write_uint(0)

    def start_field(self, fieldname):
        """Starts a new graph for the given field. If a field is already in
        progress it is finished first.

        :raises ValueError: if ``fieldname`` is empty or otherwise false.
        """

        if not fieldname:
            raise ValueError("Field name cannot be equivalent to False")
        if self._infield:
            self.finish_field()
        self.fieldname = fieldname
        self.seen = {}
        self.nodes = [UncompiledNode(self)]
        self.lastkey = ''
        self._inserted = False
        self._infield = True

    def finish_field(self):
        """Finishes the graph for the current field.

        The field's root address is only recorded if at least one key was
        inserted.

        :raises Exception: if no field is in progress.
        """

        if not self._infield:
            raise Exception("Called finish_field before start_field")
        self._infield = False
        if self._inserted:
            self.fieldroots[self.fieldname] = self._finish()
        self.fieldname = None

    def close(self):
        """Finishes the current graph and closes the underlying file.
        """

        # Use the same flag that start_field()/finish_field() maintain; the
        # previous test read ``self.fieldname``, which did not exist until
        # the first start_field() call.
        if self._infield:
            self.finish_field()
        dbfile = self.dbfile
        here = dbfile.tell()
        dbfile.write_pickle(self.fieldroots)
        dbfile.flush()
        dbfile.seek(4 + _INT_SIZE)  # Seek past magic and version number
        dbfile.write_uint(here)
        dbfile.close()

    def insert(self, key, value=None):
        """Inserts the given key into the graph.

        :param key: a sequence of bytes objects, a bytes object, or a string.
        :param value: an optional value to encode in the graph along with the
            key. If the writer was not instantiated with a value type, passing
            a value here will raise an error.
        :raises KeyError: if the key is empty or sorts before the previously
            inserted key.
        :raises ValueError: if the value is not valid for the value type.
        """

        if not self._infield:
            # Wrap the key in a tuple so a tuple/list key cannot break the
            # string formatting
            raise Exception("Inserted %r before starting a field" % (key,))
        self._inserted = True
        key = to_labels(key)  # Python 3 sucks

        vtype = self.vtype
        lastkey = self.lastkey
        nodes = self.nodes
        if len(key) < 1:
            raise KeyError("Can't store a null key %r" % (key,))
        if lastkey and lastkey > key:
            raise KeyError("Keys out of order %r..%r" % (lastkey, key))

        # Find the common prefix shared by this key and the previous one
        prefixlen = 0
        for i in xrange(min(len(lastkey), len(key))):
            if lastkey[i] != key[i]:
                break
            prefixlen += 1
        # Compile the nodes after the prefix, since they're not shared
        self._freeze_tail(prefixlen + 1)

        # Create new nodes for the parts of this key after the shared prefix
        for char in key[prefixlen:]:
            node = UncompiledNode(self)
            # Create an arc to this node on the previous node
            nodes[-1].add_arc(char, node)
            nodes.append(node)
        # Mark the last node as an accept state
        lastnode = nodes[-1]
        lastnode.accept = True

        if vtype:
            if value is not None and not vtype.is_valid(value):
                raise ValueError("%r is not valid for %s" % (value, vtype))

            # Push value commonalities through the tree
            common = None
            for i in xrange(1, prefixlen + 1):
                node = nodes[i]
                parent = nodes[i - 1]
                lastvalue = parent.last_value(key[i - 1])
                if lastvalue is not None:
                    # Keep only the shared part on the parent's arc and push
                    # the rest of the old value down into the child
                    common = vtype.common(value, lastvalue)
                    suffix = vtype.subtract(lastvalue, common)
                    parent.set_last_value(key[i - 1], common)
                    node.prepend_value(suffix)
                else:
                    common = suffix = None
                value = vtype.subtract(value, common)

            if key == lastkey:
                # If this key is a duplicate, merge its value with the value of
                # the previous (same) key
                lastnode.value = self.merge(lastnode.value, value)
            else:
                nodes[prefixlen].set_last_value(key[prefixlen], value)
        elif value:
            raise Exception("Value %r but no value type" % value)

        self.lastkey = key

    def _freeze_tail(self, prefixlen):
        # Compiles (writes out) the in-memory nodes deeper than ``prefixlen``
        # and points each parent's last arc at the compiled address.
        nodes = self.nodes
        lastkey = self.lastkey
        downto = max(1, prefixlen)

        while len(nodes) > downto:
            node = nodes.pop()
            parent = nodes[-1]
            inlabel = lastkey[len(nodes) - 1]

            self._compile_targets(node)
            accept = node.accept or len(node.arcs) == 0
            address = self._compile_node(node)
            parent.replace_last(inlabel, address, accept, node.value)

    def _finish(self):
        # Compiles everything still in memory and returns the address of the
        # root node.
        nodes = self.nodes
        root = nodes[0]
        # Minimize nodes in the last word's suffix
        self._freeze_tail(0)
        # Compile remaining targets
        self._compile_targets(root)
        return self._compile_node(root)

    def _compile_targets(self, node):
        # Replaces any arc target that is still an in-memory node with the
        # address of its compiled form.
        for arc in node.arcs:
            if isinstance(arc.target, UncompiledNode):
                n = arc.target
                if len(n.arcs) == 0:
                    arc.accept = n.accept = True
                arc.target = self._compile_node(n)

    def _compile_node(self, uncnode):
        # Returns the address of the node on disk, writing it only if a node
        # with the same digest has not been written for this field already.
        seen = self.seen

        if len(uncnode.arcs) == 0:
            # Leaf node
            address = self._write_node(uncnode)
        else:
            d = uncnode.digest()
            address = seen.get(d)
            if address is None:
                address = self._write_node(uncnode)
                seen[d] = address
        return address

    def _write_node(self, uncnode):
        # Serializes the node's arcs and returns the file position where the
        # node starts, or None for an accepting node with no arcs.
        vtype = self.vtype
        dbfile = self.dbfile
        arcs = uncnode.arcs
        numarcs = len(arcs)

        if not numarcs:
            if uncnode.accept:
                return None
            else:
                # What does it mean for an arc to stop but not be accepted?
                raise Exception("Node has no arcs but is not an accept state")
        self.node_count += 1

        # Arcs are assembled in a memory buffer first so their sizes can be
        # compared before anything is written to the real file
        buf = StructFile(BytesIO())
        nodestart = dbfile.tell()

        # fixedsize: -1 = no arc measured yet, 0 = arcs differ in size,
        # > 0 = size shared by every arc so far
        fixedsize = -1
        arcstart = buf.tell()
        for i, arc in enumerate(arcs):
            self.arc_count += 1
            target = arc.target
            label = arc.label

            flags = 0
            if len(label) > 1:
                flags += MULTIBYTE_LABEL
            if i == numarcs - 1:
                flags += ARC_LAST
            if arc.accept:
                flags += ARC_ACCEPT
            if target is None:
                flags += ARC_STOP
            if arc.value is not None:
                flags += ARC_HAS_VAL
            if arc.acceptval is not None:
                flags += ARC_HAS_ACCEPT_VAL

            buf.write(pack_byte(flags))
            if len(label) > 1:
                buf.write(varint(len(label)))
            buf.write(label)
            if target is not None:
                buf.write(pack_uint(target))
            if arc.value is not None:
                vtype.write(buf, arc.value)
            if arc.acceptval is not None:
                vtype.write(buf, arc.acceptval)

            here = buf.tell()
            thissize = here - arcstart
            arcstart = here
            if fixedsize == -1:
                fixedsize = thissize
            elif fixedsize > 0 and thissize != fixedsize:
                fixedsize = 0

        if fixedsize > 0:
            # Write a fake arc containing the fixed size and number of arcs
            dbfile.write_byte(255)  # FIXED_SIZE
            dbfile.write_int(fixedsize)
            dbfile.write_int(numarcs)
            self.fixed_count += 1
        dbfile.write(buf.file.getvalue())

        return nodestart
1217
1218
1219# Graph reader
1220
class BaseGraphReader(object):
    """Base class for graph readers. Subclasses supply the roots and the
    low-level arc access; path finding and convenience helpers are built on
    top of those here.
    """

    def cursor(self, rootname=None):
        """Returns a :class:`Cursor` positioned at the named root.
        """

        return Cursor(self, self.root(rootname))

    def has_root(self, rootname):
        """Returns True if the graph has a root with the given name.
        """

        raise NotImplementedError

    def root(self, rootname=None):
        """Returns the address of the named root.
        """

        raise NotImplementedError

    # Low level methods

    def arc_at(self, address, arc=None):
        """Returns the arc at the given address. Callers may pass an
        existing arc object as ``arc``.
        """

        # ``arc`` is optional here to match GraphReader.arc_at() and the
        # cursor code, which calls arc_at() with only an address.
        raise NotImplementedError

    def iter_arcs(self, address, arc=None):
        """Yields the arcs starting at the given address.
        """

        raise NotImplementedError

    def find_arc(self, address, label, arc=None):
        """Scans the arcs at ``address`` for the one with the given label.

        Returns the arc if found. Returns None if a greater label is reached
        first or the arcs run out.
        """

        arc = arc or Arc()
        for arc in self.iter_arcs(address, arc):
            thislabel = arc.label
            if thislabel == label:
                return arc
            elif thislabel > label:
                return None

    # Convenience methods

    def list_arcs(self, address):
        """Returns a list with a copy of every arc at the given address.
        """

        # Each arc is copied so the list holds independent objects
        return [copy.copy(arc) for arc in self.iter_arcs(address)]

    def arc_dict(self, address):
        """Returns a dict mapping each label at the given address to a copy
        of its arc.
        """

        return dict((arc.label, copy.copy(arc))
                    for arc in self.iter_arcs(address))

    def find_path(self, path, arc=None, address=None):
        """Follows the labels in ``path`` and returns the final arc, or None
        if the path does not exist.

        :param path: the key to look up; converted with ``to_labels()``.
        :param arc: an arc to continue from. If given, the search starts at
            its target and ``address`` is ignored.
        :param address: the node address to start from when no arc is given.
            Defaults to ``self._root``, which the subclass must set.
        """

        path = to_labels(path)

        if arc:
            address = arc.target
        else:
            arc = Arc()

        if address is None:
            address = self._root

        for label in path:
            if address is None:
                return None
            if not self.find_arc(address, label, arc):
                return None
            address = arc.target
        return arc
1275
1276
class GraphReader(BaseGraphReader):
    """Reads a graph out of a file.

    The file is expected to start (at ``filebase``) with the 4-byte magic
    string ``GRPH``, followed by an int version number and an unsigned int
    giving the position of a pickled dictionary of root addresses.
    """

    def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
        """
        :param dbfile: the file object to read from. It must provide
            ``seek``, ``tell``, ``read``, ``read_byte``, ``read_int``,
            ``read_uint``, ``read_varint`` and ``read_pickle``.
        :param rootname: the name of the root to use by default. If this is
            None and the file contains exactly one root, that root is used.
        :param vtype: object whose ``read(dbfile)`` method is used to read
            arc values and accept values.
        :param filebase: position in ``dbfile`` where the graph starts.
        :raises FileVersionError: if the magic string does not match.
        """

        self.dbfile = dbfile
        self.vtype = vtype
        self.filebase = filebase

        dbfile.seek(filebase)
        magic = dbfile.read(4)
        if magic != b("GRPH"):
            raise FileVersionError
        self.version = dbfile.read_int()
        dbfile.seek(dbfile.read_uint())
        self.roots = dbfile.read_pickle()

        self._root = None
        if rootname is None and len(self.roots) == 1:
            # If there's only one root, just use it. Have to wrap a list around
            # the keys() method here because of Python 3.
            rootname = list(self.roots.keys())[0]
        if rootname is not None:
            self._root = self.root(rootname)

    def close(self):
        """Closes the underlying file."""

        self.dbfile.close()

    # Overrides

    def has_root(self, rootname):
        """Returns True if the file contains a root called ``rootname``."""

        return rootname in self.roots

    def root(self, rootname=None):
        """Returns the address of the named root, or the default root
        address if ``rootname`` is None.

        :raises KeyError: if ``rootname`` is not a root in this file.
        """

        if rootname is None:
            return self._root
        else:
            return self.roots[rootname]

    def default_root(self):
        """Returns the default root address (may be None)."""

        return self._root

    def arc_at(self, address, arc=None):
        """Reads the arc stored at ``address`` into ``arc`` (or into a new
        ``Arc`` if none is given) and returns it.
        """

        arc = arc or Arc()
        self.dbfile.seek(address)
        return self._read_arc(arc)

    def iter_arcs(self, address, arc=None):
        """Yields the arcs stored sequentially from ``address`` up to and
        including the one flagged as the last arc.

        The SAME arc object is filled in and yielded on every iteration, so
        callers that want to keep an arc must copy it.
        """

        arc = arc or Arc()
        _read_arc = self._read_arc

        self.dbfile.seek(address)
        while True:
            _read_arc(arc)
            yield arc
            if arc.lastarc:
                break

    def find_arc(self, address, label, arc=None):
        """Returns the arc with the given label among the arcs stored at
        ``address``, or None if there is no such arc.
        """

        arc = arc or Arc()
        dbfile = self.dbfile
        dbfile.seek(address)

        # If records are fixed size, we can do a binary search
        finfo = self._read_fixed_info()
        if finfo:
            size, count = finfo
            # The real arcs start right after the fixed-size header
            address = dbfile.tell()
            if count > 2:
                return self._binary_search(address, size, count, label, arc)

        # If records aren't fixed size (or there are too few of them to be
        # worth a binary search), fall back to the parent's linear search
        # method
        return BaseGraphReader.find_arc(self, address, label, arc)

    # Implementations

    def _read_arc(self, toarc=None):
        """Reads one arc at the current file position into ``toarc`` (or a
        new ``Arc``) and returns it.
        """

        toarc = toarc or Arc()
        dbfile = self.dbfile
        flags = dbfile.read_byte()
        if flags == 255:
            # This is a fake arc containing fixed size information; skip it
            # and read the next arc
            dbfile.seek(_INT_SIZE * 2, 1)
            flags = dbfile.read_byte()
        toarc.label = self._read_label(flags)
        return self._read_arc_data(flags, toarc)

    def _read_label(self, flags):
        """Reads and returns the label at the current file position. Labels
        are a single byte unless ``flags`` has ``MULTIBYTE_LABEL`` set, in
        which case a varint length precedes the label bytes.
        """

        dbfile = self.dbfile
        if flags & MULTIBYTE_LABEL:
            length = dbfile.read_varint()
        else:
            length = 1
        return dbfile.read(length)

    def _read_fixed_info(self):
        """Reads one byte at the current file position. If it is the
        fixed-size marker (255), also reads and returns a ``(size, count)``
        tuple; otherwise returns None (the byte has still been consumed).
        """

        dbfile = self.dbfile

        flags = dbfile.read_byte()
        if flags == 255:
            size = dbfile.read_int()
            count = dbfile.read_int()
            return (size, count)
        else:
            return None

    def _read_arc_data(self, flags, arc):
        """Fills in every field of ``arc`` other than the label from
        ``flags`` and the data at the current file position, records the
        position after the arc in ``arc.endpos``, and returns ``arc``.
        """

        dbfile = self.dbfile
        arc.accept = bool(flags & ARC_ACCEPT)
        arc.lastarc = flags & ARC_LAST
        if flags & ARC_STOP:
            arc.target = None
        else:
            arc.target = dbfile.read_uint()
        if flags & ARC_HAS_VAL:
            arc.value = self.vtype.read(dbfile)
        else:
            arc.value = None
        # The accept value is on disk whenever its flag is set, so it must be
        # consumed whenever the flag is set to stay in sync with the file.
        # Arc objects are reused between reads, so clear the field otherwise
        # to avoid leaking the accept value of a previously read arc.
        if flags & ARC_HAS_ACCEPT_VAL:
            arc.acceptval = self.vtype.read(dbfile)
        else:
            arc.acceptval = None
        arc.endpos = dbfile.tell()
        return arc

    def _binary_search(self, address, size, count, label, arc):
        """Binary searches ``count`` arc records of ``size`` bytes each,
        starting at ``address``, for the one with the given label. Fills in
        and returns ``arc`` if found, otherwise returns None.
        """

        dbfile = self.dbfile
        _read_label = self._read_label

        lo = 0
        hi = count
        while lo < hi:
            mid = (lo + hi) // 2
            dbfile.seek(address + mid * size)
            flags = dbfile.read_byte()
            midlabel = _read_label(flags)
            if midlabel == label:
                arc.label = midlabel
                return self._read_arc_data(flags, arc)
            elif midlabel < label:
                lo = mid + 1
            else:
                hi = mid
        return None
1421
1422
def to_labels(key):
    """Takes a key and returns a tuple of bytestrings, suitable for use as a
    key or path in an FSA/FST graph.

    * A tuple or list must contain only bytestrings; a tuple is returned
      as-is and a list is converted to a tuple.
    * A bytestring is split into one label per byte.
    * A text string is split into one label per character, each passed
      through ``utf8encode``.

    :raises TypeError: if ``key`` is a tuple or list containing something
        other than bytestrings, or is of any other unsupported type.
    """

    # Convert to tuples of bytestrings (must be tuples so they can be hashed)
    keytype = type(key)

    if keytype is tuple or keytype is list:
        if not all(isinstance(e, bytes_type) for e in key):
            # Wrap key in a 1-tuple so a tuple key is formatted as a single
            # value rather than being unpacked as the format arguments
            raise TypeError("%r contains a non-bytestring" % (key,))
        if keytype is list:
            key = tuple(key)
    elif isinstance(key, bytes_type):
        # Slice instead of indexing: indexing a Python 3 bytes object gives
        # integers, slicing gives length-1 bytestrings on both 2 and 3
        key = tuple(key[i:i + 1] for i in xrange(len(key)))
    elif isinstance(key, text_type):
        key = tuple(utf8encode(key[i:i + 1])[0] for i in xrange(len(key)))
    else:
        raise TypeError("Don't know how to convert %r" % (key,))
    return key
1444
1445
1446# Within edit distance function
1447
def within(graph, text, k=1, prefix=0, address=None):
    """Yields a series of keys in the given graph within ``k`` edit distance of
    ``text``. If ``prefix`` is greater than 0, all keys must match the first
    ``prefix`` characters of ``text``.

    The edits considered are insertion, deletion, replacement, and
    transposition of two adjacent labels. Each result is the accumulated
    bytestring passed through ``utf8decode``.

    :param graph: a graph reader providing ``find_path``, ``find_arc``,
        ``arc_dict`` and a ``_root`` attribute.
    :param text: the key to search around; converted with ``to_labels``.
    :param k: the maximum number of edits allowed.
    :param prefix: the number of leading labels of ``text`` that must match
        exactly.
    :param address: the address to start searching from; defaults to
        ``graph._root``.
    """

    text = to_labels(text)
    textlen = len(text)
    if address is None:
        address = graph._root

    sofar = emptybytes
    accept = False
    if prefix:
        prefixchars = text[:prefix]
        arc = graph.find_path(prefixchars, address=address)
        if arc is None:
            return
        sofar = emptybytes.join(prefixchars)
        address = arc.target
        accept = arc.accept

    # Each state is (node address, edits remaining, position in text, key
    # built so far, whether the arc that led here was an accept arc)
    stack = [(address, k, prefix, sofar, accept)]
    seen = set()
    while stack:
        state = stack.pop()
        # Have we already tried this state?
        if state in seen:
            continue
        seen.add(state)

        address, k, i, sofar, accept = state
        # If we're at the end of the text (or deleting enough chars would get
        # us to the end and still within K), and we're in the accept state,
        # yield the current result
        if (textlen - i <= k) and accept:
            yield utf8decode(sofar)[0]

        # If we're in the stop state, give up
        if address is None:
            continue

        # Exact match
        if i < textlen:
            arc = graph.find_arc(address, text[i])
            if arc:
                stack.append((arc.target, k, i + 1, sofar + text[i],
                              arc.accept))
        # If K is already 0, can't do any more edits
        if k < 1:
            continue
        # Every state pushed below has used up one edit
        k -= 1

        arcs = graph.arc_dict(address)
        # Insertions
        stack.extend((arc.target, k, i, sofar + char, arc.accept)
                     for char, arc in iteritems(arcs))

        # Deletion, replacement, and transpo only work before the end
        if i >= textlen:
            continue
        char = text[i]

        # Deletion
        stack.append((address, k, i + 1, sofar, False))
        # Replacement
        for char2, arc in iteritems(arcs):
            if char2 != char:
                stack.append((arc.target, k, i + 1, sofar + char2,
                              arc.accept))
        # Transposition
        if i < textlen - 1:
            char2 = text[i + 1]
            if char != char2 and char2 in arcs:
                # Find arc from next char to this char
                target = arcs[char2].target
                # None marks the stop state, as in the check above; compare
                # against None explicitly rather than relying on truthiness
                if target is not None:
                    arc = graph.find_arc(target, char)
                    if arc:
                        stack.append((arc.target, k, i + 2,
                                      sofar + char2 + char, arc.accept))
1527
1528
1529# Utility functions
1530
def dump_graph(graph, address=None, tab=0, out=None):
    """Writes a text dump of the graph to ``out``, one line per arc, visiting
    each arc's target recursively right after the arc itself.

    Each line shows the arc's label, target, accept flag and value. The first
    arc of a node is prefixed with the node's address as six zero-padded
    digits; the node's other arcs get six spaces instead. Lines are indented
    according to ``tab``, which grows by one for every level of recursion.

    :param graph: object providing ``list_arcs(address)`` and a ``_root``
        attribute.
    :param address: the node to dump; defaults to ``graph._root``.
    :param tab: the current indentation level.
    :param out: object with a ``write`` method; defaults to ``sys.stdout``.
    """

    if out is None:
        out = sys.stdout
    if address is None:
        address = graph._root

    address_column = "%06d" % address
    blank_column = " " * 6
    indent = " " * tab

    first = True
    for arc in graph.list_arcs(address):
        out.write(address_column if first else blank_column)
        first = False
        out.write(indent)
        fields = (arc.label, arc.target, arc.accept, arc.value)
        out.write("%r %r %s %r\n" % fields)
        child = arc.target
        if child is not None:
            dump_graph(graph, child, tab + 1, out=out)
1548
1549