PageRenderTime 89ms CodeModel.GetById 12ms app.highlight 66ms RepoModel.GetById 1ms app.codeStats 0ms

/bangkokhotel/lib/python2.5/site-packages/whoosh/support/dawg.py

https://bitbucket.org/luisrodriguez/bangkokhotel
Python | 1549 lines | 1466 code | 27 blank | 56 comment | 22 complexity | 6233492f368ddc4a5249095d8fbb480d MD5 | raw file
   1# Copyright 2009 Matt Chaput. All rights reserved.
   2#
   3# Redistribution and use in source and binary forms, with or without
   4# modification, are permitted provided that the following conditions are met:
   5#
   6#    1. Redistributions of source code must retain the above copyright notice,
   7#       this list of conditions and the following disclaimer.
   8#
   9#    2. Redistributions in binary form must reproduce the above copyright
  10#       notice, this list of conditions and the following disclaimer in the
  11#       documentation and/or other materials provided with the distribution.
  12#
  13# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
  14# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  15# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
  16# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
  17# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  18# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
  19# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  20# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  21# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
  22# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  23#
  24# The views and conclusions contained in the software and documentation are
  25# those of the authors and should not be interpreted as representing official
  26# policies, either expressed or implied, of Matt Chaput.
  27
  28"""
  29This module implements an FST/FSA writer and reader. An FST (Finite State
  30Transducer) stores a directed acyclic graph with values associated with the
  31leaves. Common elements of the values are pushed inside the tree. An FST that
  32does not store values is a regular FSA.
  33
  34The format of the leaf values is pluggable using subclasses of the Values
  35class.
  36
  37Whoosh uses these structures to store a directed acyclic word graph (DAWG) for
  38use in (at least) spell checking.
  39"""
  40
  41
  42import sys, copy
  43from array import array
  44from hashlib import sha1  # @UnresolvedImport
  45
  46from whoosh.compat import (b, u, BytesIO, xrange, iteritems, iterkeys,
  47                           bytes_type, text_type, izip, array_tobytes)
  48from whoosh.filedb.structfile import StructFile
  49from whoosh.system import (_INT_SIZE, pack_byte, pack_int, pack_uint,
  50                           pack_long, emptybytes)
  51from whoosh.util import utf8encode, utf8decode, varint
  52
  53
  54class FileVersionError(Exception):
  55    pass
  56
  57
  58class InactiveCursor(Exception):
  59    pass
  60
  61
  62ARC_LAST = 1
  63ARC_ACCEPT = 2
  64ARC_STOP = 4
  65ARC_HAS_VAL = 8
  66ARC_HAS_ACCEPT_VAL = 16
  67MULTIBYTE_LABEL = 32
  68
  69
  70# FST Value types
  71
  72class Values(object):
  73    """Base for classes the describe how to encode and decode FST values.
  74    """
  75
  76    @staticmethod
  77    def is_valid(v):
  78        """Returns True if v is a valid object that can be stored by this
  79        class.
  80        """
  81
  82        raise NotImplementedError
  83
  84    @staticmethod
  85    def common(v1, v2):
  86        """Returns the "common" part of the two values, for whatever "common"
  87        means for this class. For example, a string implementation would return
  88        the common shared prefix, for an int implementation it would return
  89        the minimum of the two numbers.
  90        
  91        If there is no common part, this method should return None.
  92        """
  93
  94        raise NotImplementedError
  95
  96    @staticmethod
  97    def add(prefix, v):
  98        """Adds the given prefix (the result of a call to common()) to the
  99        given value.
 100        """
 101
 102        raise NotImplementedError
 103
 104    @staticmethod
 105    def subtract(v, prefix):
 106        """Subtracts the "common" part (the prefix) from the given value.
 107        """
 108
 109        raise NotImplementedError
 110
 111    @staticmethod
 112    def write(dbfile, v):
 113        """Writes value v to a file.
 114        """
 115
 116        raise NotImplementedError
 117
 118    @staticmethod
 119    def read(dbfile):
 120        """Reads a value from the given file.
 121        """
 122
 123        raise NotImplementedError
 124
 125    @classmethod
 126    def skip(cls, dbfile):
 127        """Skips over a value in the given file.
 128        """
 129
 130        cls.read(dbfile)
 131
 132    @staticmethod
 133    def to_bytes(v):
 134        """Returns a str (Python 2.x) or bytes (Python 3) representation of
 135        the given value. This is used for calculating node digests, so it
 136        should be unique but fast to calculate, and does not have to be
 137        parseable.
 138        """
 139
 140        raise NotImplementedError
 141
 142    @staticmethod
 143    def merge(v1, v2):
 144        raise NotImplementedError
 145
 146
 147class IntValues(Values):
 148    """Stores integer values in an FST.
 149    """
 150
 151    @staticmethod
 152    def is_valid(v):
 153        return isinstance(v, int) and v >= 0
 154
 155    @staticmethod
 156    def common(v1, v2):
 157        if v1 is None or v2 is None:
 158            return None
 159        if v1 == v2:
 160            return v1
 161        return min(v1, v2)
 162
 163    @staticmethod
 164    def add(base, v):
 165        if base is None:
 166            return v
 167        if v is None:
 168            return base
 169        return base + v
 170
 171    @staticmethod
 172    def subtract(v, base):
 173        if v is None:
 174            return None
 175        if base is None:
 176            return v
 177        return v - base
 178
 179    @staticmethod
 180    def write(dbfile, v):
 181        dbfile.write_uint(v)
 182
 183    @staticmethod
 184    def read(dbfile):
 185        return dbfile.read_uint()
 186
 187    @staticmethod
 188    def skip(dbfile):
 189        dbfile.seek(_INT_SIZE, 1)
 190
 191    @staticmethod
 192    def to_bytes(v):
 193        return pack_int(v)
 194
 195
 196class SequenceValues(Values):
 197    """Abstract base class for value types that store sequences.
 198    """
 199
 200    @staticmethod
 201    def is_valid(v):
 202        return isinstance(self, (list, tuple))
 203
 204    @staticmethod
 205    def common(v1, v2):
 206        if v1 is None or v2 is None:
 207            return None
 208
 209        i = 0
 210        while i < len(v1) and i < len(v2):
 211            if v1[i] != v2[i]:
 212                break
 213            i += 1
 214
 215        if i == 0:
 216            return None
 217        if i == len(v1):
 218            return v1
 219        if i == len(v2):
 220            return v2
 221        return v1[:i]
 222
 223    @staticmethod
 224    def add(prefix, v):
 225        if prefix is None:
 226            return v
 227        if v is None:
 228            return prefix
 229        return prefix + v
 230
 231    @staticmethod
 232    def subtract(v, prefix):
 233        if prefix is None:
 234            return v
 235        if v is None:
 236            return None
 237        if len(v) == len(prefix):
 238            return None
 239        if len(v) < len(prefix) or len(prefix) == 0:
 240            raise ValueError((v, prefix))
 241        return v[len(prefix):]
 242
 243    @staticmethod
 244    def write(dbfile, v):
 245        dbfile.write_pickle(v)
 246
 247    @staticmethod
 248    def read(dbfile):
 249        return dbfile.read_pickle()
 250
 251
 252class BytesValues(SequenceValues):
 253    """Stores bytes objects (str in Python 2.x) in an FST.
 254    """
 255
 256    @staticmethod
 257    def is_valid(v):
 258        return isinstance(v, bytes_type)
 259
 260    @staticmethod
 261    def write(dbfile, v):
 262        dbfile.write_int(len(v))
 263        dbfile.write(v)
 264
 265    @staticmethod
 266    def read(dbfile):
 267        length = dbfile.read_int()
 268        return dbfile.read(length)
 269
 270    @staticmethod
 271    def skip(dbfile):
 272        length = dbfile.read_int()
 273        dbfile.seek(length, 1)
 274
 275    @staticmethod
 276    def to_bytes(v):
 277        return v
 278
 279
 280class ArrayValues(SequenceValues):
 281    """Stores array.array objects in an FST.
 282    """
 283
 284    def __init__(self, typecode):
 285        self.typecode = typecode
 286        self.itemsize = array(self.typecode).itemsize
 287
 288    def is_valid(self, v):
 289        return isinstance(v, array) and v.typecode == self.typecode
 290
 291    @staticmethod
 292    def write(dbfile, v):
 293        dbfile.write(b(v.typecode))
 294        dbfile.write_int(len(v))
 295        dbfile.write_array(v)
 296
 297    def read(self, dbfile):
 298        typecode = u(dbfile.read(1))
 299        length = dbfile.read_int()
 300        return dbfile.read_array(self.typecode, length)
 301
 302    def skip(self, dbfile):
 303        length = dbfile.read_int()
 304        dbfile.seek(length * self.itemsize, 1)
 305
 306    @staticmethod
 307    def to_bytes(v):
 308        return array_tobytes(v)
 309
 310
 311class IntListValues(SequenceValues):
 312    """Stores lists of positive, increasing integers (that is, lists of
 313    integers where each number is >= 0 and each number is greater than or equal
 314    to the number that precedes it) in an FST.
 315    """
 316
 317    @staticmethod
 318    def is_valid(v):
 319        if isinstance(v, (list, tuple)):
 320            if len(v) < 2:
 321                return True
 322            for i in xrange(1, len(v)):
 323                if not isinstance(v[i], int) or v[i] < v[i - 1]:
 324                    return False
 325            return True
 326        return False
 327
 328    @staticmethod
 329    def write(dbfile, v):
 330        base = 0
 331        dbfile.write_varint(len(v))
 332        for x in v:
 333            delta = x - base
 334            assert delta >= 0
 335            dbfile.write_varint(delta)
 336            base = x
 337
 338    @staticmethod
 339    def read(dbfile):
 340        length = dbfile.read_varint()
 341        result = []
 342        if length > 0:
 343            base = 0
 344            for _ in xrange(length):
 345                base += dbfile.read_varint()
 346                result.append(base)
 347        return result
 348
 349    @staticmethod
 350    def to_bytes(v):
 351        return b(repr(v))
 352
 353
 354# Node-like interface wrappers
 355
 356class Node(object):
 357    """A slow but easier-to-use wrapper for FSA/DAWGs. Translates the low-level
 358    arc-based interface of GraphReader into Node objects with methods to follow
 359    edges.
 360    """
 361
 362    def __init__(self, owner, address, accept=False):
 363        self.owner = owner
 364        self.address = address
 365        self._edges = None
 366        self.accept = accept
 367
 368    def __iter__(self):
 369        if not self._edges:
 370            self._load()
 371        return iterkeys(self._edges)
 372
 373    def __contains__(self, key):
 374        if self._edges is None:
 375            self._load()
 376        return key in self._edges
 377
 378    def _load(self):
 379        owner = self.owner
 380        if self.address is None:
 381            d = {}
 382        else:
 383            d = dict((arc.label, Node(owner, arc.target, arc.accept))
 384                     for arc in self.owner.iter_arcs(self.address))
 385        self._edges = d
 386
 387    def keys(self):
 388        if self._edges is None:
 389            self._load()
 390        return self._edges.keys()
 391
 392    def all_edges(self):
 393        if self._edges is None:
 394            self._load()
 395        return self._edges
 396
 397    def edge(self, key):
 398        if self._edges is None:
 399            self._load()
 400        return self._edges[key]
 401
 402    def flatten(self, sofar=emptybytes):
 403        if self.accept:
 404            yield sofar
 405        for key in sorted(self):
 406            node = self.edge(key)
 407            for result in node.flatten(sofar + key):
 408                yield result
 409
 410    def flatten_strings(self):
 411        return (utf8decode(k)[0] for k in self.flatten())
 412
 413
 414class ComboNode(Node):
 415    """Base class for nodes that blend the nodes of two different graphs.
 416    
 417    Concrete subclasses need to implement the ``edge()`` method and possibly
 418    override the ``accept`` property.
 419    """
 420
 421    def __init__(self, a, b):
 422        self.a = a
 423        self.b = b
 424
 425    def __repr__(self):
 426        return "<%s %r %r>" % (self.__class__.__name__, self.a, self.b)
 427
 428    def __contains__(self, key):
 429        return key in self.a or key in self.b
 430
 431    def __iter__(self):
 432        return iter(set(self.a) | set(self.b))
 433
 434    @property
 435    def accept(self):
 436        return self.a.accept or self.b.accept
 437
 438
 439class UnionNode(ComboNode):
 440    """Makes two graphs appear to be the union of the two graphs.
 441    """
 442
 443    def edge(self, key):
 444        a = self.a
 445        b = self.b
 446        if key in a and key in b:
 447            return UnionNode(a.edge(key), b.edge(key))
 448        elif key in a:
 449            return a.edge(key)
 450        else:
 451            return b.edge(key)
 452
 453
 454class IntersectionNode(ComboNode):
 455    """Makes two graphs appear to be the intersection of the two graphs.
 456    """
 457
 458    def edge(self, key):
 459        a = self.a
 460        b = self.b
 461        if key in a and key in b:
 462            return IntersectionNode(a.edge(key), b.edge(key))
 463
 464
 465# Cursor
 466
 467class BaseCursor(object):
 468    """Base class for a cursor-type object for navigating an FST/word graph,
 469    represented by a :class:`GraphReader` object.
 470    
 471    >>> cur = GraphReader(dawgfile).cursor()
 472    >>> for key in cur.follow():
 473    ...   print(repr(key))
 474    
 475    The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.
 476    """
 477
 478    def is_active(self):
 479        """Returns True if this cursor is still active, that is it has not
 480        read past the last arc in the graph.
 481        """
 482
 483        raise NotImplementedError
 484
 485    def label(self):
 486        """Returns the label bytes of the current arc.
 487        """
 488
 489        raise NotImplementedError
 490
 491    def prefix(self):
 492        """Returns a sequence of the label bytes for the path from the root
 493        to the current arc.
 494        """
 495
 496        raise NotImplementedError
 497
 498    def prefix_bytes(self):
 499        """Returns the label bytes for the path from the root to the current
 500        arc as a single joined bytes object.
 501        """
 502
 503        return emptybytes.join(self.prefix())
 504
 505    def prefix_string(self):
 506        """Returns the labels of the path from the root to the current arc as
 507        a decoded unicode string.
 508        """
 509
 510        return utf8decode(self.prefix_bytes())[0]
 511
 512    def peek_key(self):
 513        """Returns a sequence of label bytes representing the next closest
 514        key in the graph.
 515        """
 516
 517        for label in self.prefix():
 518            yield label
 519        c = self.copy()
 520        while not c.stopped():
 521            c.follow()
 522            yield c.label()
 523
 524    def peek_key_bytes(self):
 525        """Returns the next closest key in the graph as a single bytes object.
 526        """
 527
 528        return emptybytes.join(self.peek_key())
 529
 530    def peek_key_string(self):
 531        """Returns the next closest key in the graph as a decoded unicode
 532        string.
 533        """
 534
 535        return utf8decode(self.peek_key_bytes())[0]
 536
 537    def stopped(self):
 538        """Returns True if the current arc leads to a stop state.
 539        """
 540
 541        raise NotImplementedError
 542
 543    def value(self):
 544        """Returns the value at the current arc, if reading an FST.
 545        """
 546
 547        raise NotImplementedError
 548
 549    def accept(self):
 550        """Returns True if the current arc leads to an accept state (the end
 551        of a valid key).
 552        """
 553
 554        raise NotImplementedError
 555
 556    def at_last_arc(self):
 557        """Returns True if the current arc is the last outgoing arc from the
 558        previous node.
 559        """
 560
 561        raise NotImplementedError
 562
 563    def next_arc(self):
 564        """Moves to the next outgoing arc from the previous node.
 565        """
 566
 567        raise NotImplementedError
 568
 569    def follow(self):
 570        """Follows the current arc.
 571        """
 572
 573        raise NotImplementedError
 574
 575    def switch_to(self, label):
 576        """Switch to the sibling arc with the given label bytes.
 577        """
 578
 579        _label = self.label
 580        _at_last_arc = self.at_last_arc
 581        _next_arc = self.next_arc
 582
 583        while True:
 584            thislabel = _label()
 585            if thislabel == label:
 586                return True
 587            if thislabel > label or _at_last_arc():
 588                return False
 589            _next_arc()
 590
 591    def skip_to(self, key):
 592        """Moves the cursor to the path represented by the given key bytes.
 593        """
 594
 595        _accept = self.accept
 596        _prefix = self.prefix
 597        _next_arc = self.next_arc
 598
 599        keylist = list(key)
 600        while True:
 601            if _accept():
 602                thiskey = list(_prefix())
 603                if keylist == thiskey:
 604                    return True
 605                elif keylist > thiskey:
 606                    return False
 607            _next_arc()
 608
 609    def flatten(self):
 610        """Yields the keys in the graph, starting at the current position.
 611        """
 612
 613        _is_active = self.is_active
 614        _accept = self.accept
 615        _stopped = self.stopped
 616        _follow = self.follow
 617        _next_arc = self.next_arc
 618        _prefix_bytes = self.prefix_bytes
 619
 620        if not _is_active():
 621            raise InactiveCursor
 622        while _is_active():
 623            if _accept():
 624                yield _prefix_bytes()
 625            if not _stopped():
 626                _follow()
 627                continue
 628            _next_arc()
 629
 630    def flatten_v(self):
 631        """Yields (key, value) tuples in an FST, starting at the current
 632        position.
 633        """
 634
 635        for key in self.flatten():
 636            yield key, self.value()
 637
 638    def flatten_strings(self):
 639        return (utf8decode(k)[0] for k in self.flatten())
 640
 641    def find_path(self, path):
 642        """Follows the labels in the given path, starting at the current
 643        position.
 644        """
 645
 646        path = to_labels(path)
 647        _switch_to = self.switch_to
 648        _follow = self.follow
 649        _stopped = self.stopped
 650
 651        first = True
 652        for i, label in enumerate(path):
 653            if not first:
 654                _follow()
 655            if not _switch_to(label):
 656                return False
 657            if _stopped():
 658                if i < len(path) - 1:
 659                    return False
 660            first = False
 661        return True
 662
 663
 664class Cursor(BaseCursor):
 665    def __init__(self, graph, root=None, stack=None):
 666        self.graph = graph
 667        self.vtype = graph.vtype
 668        self.root = root if root is not None else graph.default_root()
 669        if stack:
 670            self.stack = stack
 671        else:
 672            self.reset()
 673
 674    def _current_attr(self, name):
 675        stack = self.stack
 676        if not stack:
 677            raise InactiveCursor
 678        return getattr(stack[-1], name)
 679
 680    def is_active(self):
 681        return bool(self.stack)
 682
 683    def stopped(self):
 684        return self._current_attr("target") is None
 685
 686    def accept(self):
 687        return self._current_attr("accept")
 688
 689    def at_last_arc(self):
 690        return self._current_attr("lastarc")
 691
 692    def label(self):
 693        return self._current_attr("label")
 694
 695    def reset(self):
 696        self.stack = []
 697        self.sums = [None]
 698        self._push(self.graph.arc_at(self.root))
 699
 700    def copy(self):
 701        return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))
 702
 703    def prefix(self):
 704        stack = self.stack
 705        if not stack:
 706            raise InactiveCursor
 707        return (arc.label for arc in stack)
 708
 709    # Override: more efficient implementation using graph methods directly
 710    def peek_key(self):
 711        if not self.stack:
 712            raise InactiveCursor
 713
 714        for label in self.prefix():
 715            yield label
 716        arc = copy.copy(self.stack[-1])
 717        graph = self.graph
 718        while not arc.accept and arc.target is not None:
 719            graph.arc_at(arc.target, arc)
 720            yield arc.label
 721
 722    def value(self):
 723        stack = self.stack
 724        if not stack:
 725            raise InactiveCursor
 726        vtype = self.vtype
 727        if not vtype:
 728            raise Exception("No value type")
 729
 730        v = self.sums[-1]
 731        current = stack[-1]
 732        if current.value:
 733            v = vtype.add(v, current.value)
 734        if current.accept and current.acceptval is not None:
 735            v = vtype.add(v, current.acceptval)
 736        return v
 737
 738    def next_arc(self):
 739        stack = self.stack
 740        if not stack:
 741            raise InactiveCursor
 742
 743        while stack and stack[-1].lastarc:
 744            self.pop()
 745        if stack:
 746            current = stack[-1]
 747            self.graph.arc_at(current.endpos, current)
 748            return current
 749
 750    def follow(self):
 751        address = self._current_attr("target")
 752        if address is None:
 753            raise Exception("Can't follow a stop arc")
 754        self._push(self.graph.arc_at(address))
 755        return self
 756
 757    # Override: more efficient implementation manipulating the stack
 758    def skip_to(self, key):
 759        key = to_labels(key)
 760        stack = self.stack
 761        if not stack:
 762            raise InactiveCursor
 763
 764        _follow = self.follow
 765        _next_arc = self.next_arc
 766
 767        i = self._pop_to_prefix(key)
 768        while stack and i < len(key):
 769            curlabel = stack[-1].label
 770            keylabel = key[i]
 771            if curlabel == keylabel:
 772                _follow()
 773                i += 1
 774            elif curlabel > keylabel:
 775                return
 776            else:
 777                _next_arc()
 778
 779    # Override: more efficient implementation using find_arc
 780    def switch_to(self, label):
 781        stack = self.stack
 782        if not stack:
 783            raise InactiveCursor
 784
 785        current = stack[-1]
 786        if label == current.label:
 787            return True
 788        else:
 789            arc = self.graph.find_arc(current.endpos, label, current)
 790            return arc
 791
 792    def _push(self, arc):
 793        if self.vtype and self.stack:
 794            sums = self.sums
 795            sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
 796        self.stack.append(arc)
 797
 798    def pop(self):
 799        self.stack.pop()
 800        if self.vtype:
 801            self.sums.pop()
 802
 803    def _pop_to_prefix(self, key):
 804        stack = self.stack
 805        if not stack:
 806            raise InactiveCursor
 807
 808        i = 0
 809        maxpre = min(len(stack), len(key))
 810        while i < maxpre and key[i] == stack[i].label:
 811            i += 1
 812        if stack[i].label > key[i]:
 813            self.current = None
 814            return
 815        while len(stack) > i + 1:
 816            self.pop()
 817        self.next_arc()
 818        return i
 819
 820
 821class UncompiledNode(object):
 822    # Represents an "in-memory" node used by the GraphWriter before it is
 823    # written to disk.
 824
 825    compiled = False
 826
 827    def __init__(self, owner):
 828        self.owner = owner
 829        self._digest = None
 830        self.clear()
 831
 832    def clear(self):
 833        self.arcs = []
 834        self.value = None
 835        self.accept = False
 836        self.inputcount = 0
 837
 838    def __repr__(self):
 839        return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
 840
 841    def digest(self):
 842        if self._digest is None:
 843            d = sha1()
 844            vtype = self.owner.vtype
 845            for arc in self.arcs:
 846                d.update(arc.label)
 847                if arc.target:
 848                    d.update(pack_long(arc.target))
 849                else:
 850                    d.update(b("z"))
 851                if arc.value:
 852                    d.update(vtype.to_bytes(arc.value))
 853                if arc.accept:
 854                    d.update(b("T"))
 855            self._digest = d.digest()
 856        return self._digest
 857
 858    def edges(self):
 859        return self.arcs
 860
 861    def last_value(self, label):
 862        assert self.arcs[-1].label == label
 863        return self.arcs[-1].value
 864
 865    def add_arc(self, label, target):
 866        self.arcs.append(Arc(label, target))
 867
 868    def replace_last(self, label, target, accept, acceptval=None):
 869        arc = self.arcs[-1]
 870        assert arc.label == label, "%r != %r" % (arc.label, label)
 871        arc.target = target
 872        arc.accept = accept
 873        arc.acceptval = acceptval
 874
 875    def delete_last(self, label, target):
 876        arc = self.arcs.pop()
 877        assert arc.label == label
 878        assert arc.target == target
 879
 880    def set_last_value(self, label, value):
 881        arc = self.arcs[-1]
 882        assert arc.label == label, "%r->%r" % (arc.label, label)
 883        arc.value = value
 884
 885    def prepend_value(self, prefix):
 886        add = self.owner.vtype.add
 887        for arc in self.arcs:
 888            arc.value = add(prefix, arc.value)
 889        if self.accept:
 890            self.value = add(prefix, self.value)
 891
 892
 893class Arc(object):
 894    """
 895    Represents a directed arc between two nodes in an FSA/FST graph.
 896    
 897    The ``lastarc`` attribute is True if this is the last outgoing arc from the
 898    previous node.
 899    """
 900
 901    __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
 902                 "endpos")
 903
 904    def __init__(self, label=None, target=None, value=None, accept=False,
 905                 acceptval=None):
 906        """
 907        :param label:The label bytes for this arc. For a word graph, this will
 908            be a character.
 909        :param target: The address of the node at the endpoint of this arc.
 910        :param value: The inner FST value at the endpoint of this arc.
 911        :param accept: Whether the endpoint of this arc is an accept state
 912            (eg the end of a valid word).
 913        :param acceptval: If the endpoint of this arc is an accept state, the
 914            final FST value for that accepted state.
 915        """
 916
 917        self.label = label
 918        self.target = target
 919        self.value = value
 920        self.accept = accept
 921        self.lastarc = None
 922        self.acceptval = acceptval
 923        self.endpos = None
 924
 925    def __repr__(self):
 926        return "<%r-%s %s%s>" % (self.label, self.target,
 927                                 "." if self.accept else "",
 928                                 (" %r" % self.value) if self.value else "")
 929
 930    def __eq__(self, other):
 931        if (isinstance(other, self.__class__) and self.accept == other.accept
 932            and self.lastarc == other.lastarc and self.target == other.target
 933            and self.value == other.value and self.label == other.label):
 934            return True
 935        return False
 936
 937
 938# Graph writer
 939
 940class GraphWriter(object):
 941    """Writes an FSA/FST graph to disk.
 942    
 943    Call ``insert(key)`` to insert keys into the graph. You must
 944    insert keys in sorted order. Call ``close()`` to finish the graph and close
 945    the file.
 946    
 947    >>> gw = GraphWriter(my_file)
 948    >>> gw.insert("alfa")
 949    >>> gw.insert("bravo")
 950    >>> gw.insert("charlie")
 951    >>> gw.close()
 952    
 953    The graph writer can write separate graphs for multiple fields. Use
 954    ``start_field(name)`` and ``finish_field()`` to separate fields.
 955    
 956    >>> gw = GraphWriter(my_file)
 957    >>> gw.start_field("content")
 958    >>> gw.insert("alfalfa")
 959    >>> gw.insert("apple")
 960    >>> gw.finish_field()
 961    >>> gw.start_field("title")
 962    >>> gw.insert("artichoke")
 963    >>> gw.finish_field()
 964    >>> gw.close()
 965    """
 966
 967    version = 1
 968
 969    def __init__(self, dbfile, vtype=None, merge=None):
 970        """
 971        :param dbfile: the file to write to.
 972        :param vtype: a :class:`Values` class to use for storing values. This
 973            is only necessary if you will be storing values for the keys.
 974        :param merge: a function that takes two values and returns a single
 975            value. This is called if you insert two identical keys with values.
 976        """
 977
 978        self.dbfile = dbfile
 979        self.vtype = vtype
 980        self.merge = merge
 981        self.fieldroots = {}
 982        self.arc_count = 0
 983        self.node_count = 0
 984        self.fixed_count = 0
 985
 986        dbfile.write(b("GRPH"))
 987        dbfile.write_int(self.version)
 988        dbfile.write_uint(0)
 989
 990        self._infield = False
 991
 992    def start_field(self, fieldname):
 993        """Starts a new graph for the given field.
 994        """
 995
 996        if not fieldname:
 997            raise ValueError("Field name cannot be equivalent to False")
 998        if self._infield:
 999            self.finish_field()
1000        self.fieldname = fieldname
1001        self.seen = {}
1002        self.nodes = [UncompiledNode(self)]
1003        self.lastkey = ''
1004        self._inserted = False
1005        self._infield = True
1006
1007    def finish_field(self):
1008        """Finishes the graph for the current field.
1009        """
1010
1011        if not self._infield:
1012            raise Exception("Called finish_field before start_field")
1013        self._infield = False
1014        if self._inserted:
1015            self.fieldroots[self.fieldname] = self._finish()
1016        self.fieldname = None
1017
1018    def close(self):
1019        """Finishes the current graph and closes the underlying file.
1020        """
1021
1022        if self.fieldname is not None:
1023            self.finish_field()
1024        dbfile = self.dbfile
1025        here = dbfile.tell()
1026        dbfile.write_pickle(self.fieldroots)
1027        dbfile.flush()
1028        dbfile.seek(4 + _INT_SIZE)  # Seek past magic and version number
1029        dbfile.write_uint(here)
1030        dbfile.close()
1031
1032    def insert(self, key, value=None):
1033        """Inserts the given key into the graph.
1034        
1035        :param key: a sequence of bytes objects, a bytes object, or a string.
1036        :param value: an optional value to encode in the graph along with the
1037            key. If the writer was not instantiated with a value type, passing
1038            a value here will raise an error.
1039        """
1040
1041        if not self._infield:
1042            raise Exception("Inserted %r before starting a field" % key)
1043        self._inserted = True
1044        key = to_labels(key)  # Python 3 sucks
1045
1046        vtype = self.vtype
1047        lastkey = self.lastkey
1048        nodes = self.nodes
1049        if len(key) < 1:
1050            raise KeyError("Can't store a null key %r" % (key,))
1051        if lastkey and lastkey > key:
1052            raise KeyError("Keys out of order %r..%r" % (lastkey, key))
1053
1054        # Find the common prefix shared by this key and the previous one
1055        prefixlen = 0
1056        for i in xrange(min(len(lastkey), len(key))):
1057            if lastkey[i] != key[i]:
1058                break
1059            prefixlen += 1
1060        # Compile the nodes after the prefix, since they're not shared
1061        self._freeze_tail(prefixlen + 1)
1062
1063        # Create new nodes for the parts of this key after the shared prefix
1064        for char in key[prefixlen:]:
1065            node = UncompiledNode(self)
1066            # Create an arc to this node on the previous node
1067            nodes[-1].add_arc(char, node)
1068            nodes.append(node)
1069        # Mark the last node as an accept state
1070        lastnode = nodes[-1]
1071        lastnode.accept = True
1072
1073        if vtype:
1074            if value is not None and not vtype.is_valid(value):
1075                raise ValueError("%r is not valid for %s" % (value, vtype))
1076
1077            # Push value commonalities through the tree
1078            common = None
1079            for i in xrange(1, prefixlen + 1):
1080                node = nodes[i]
1081                parent = nodes[i - 1]
1082                lastvalue = parent.last_value(key[i - 1])
1083                if lastvalue is not None:
1084                    common = vtype.common(value, lastvalue)
1085                    suffix = vtype.subtract(lastvalue, common)
1086                    parent.set_last_value(key[i - 1], common)
1087                    node.prepend_value(suffix)
1088                else:
1089                    common = suffix = None
1090                value = vtype.subtract(value, common)
1091
1092            if key == lastkey:
1093                # If this key is a duplicate, merge its value with the value of
1094                # the previous (same) key
1095                lastnode.value = self.merge(lastnode.value, value)
1096            else:
1097                nodes[prefixlen].set_last_value(key[prefixlen], value)
1098        elif value:
1099            raise Exception("Value %r but no value type" % value)
1100
1101        self.lastkey = key
1102
1103    def _freeze_tail(self, prefixlen):
1104        nodes = self.nodes
1105        lastkey = self.lastkey
1106        downto = max(1, prefixlen)
1107
1108        while len(nodes) > downto:
1109            node = nodes.pop()
1110            parent = nodes[-1]
1111            inlabel = lastkey[len(nodes) - 1]
1112
1113            self._compile_targets(node)
1114            accept = node.accept or len(node.arcs) == 0
1115            address = self._compile_node(node)
1116            parent.replace_last(inlabel, address, accept, node.value)
1117
1118    def _finish(self):
1119        nodes = self.nodes
1120        root = nodes[0]
1121        # Minimize nodes in the last word's suffix
1122        self._freeze_tail(0)
1123        # Compile remaining targets
1124        self._compile_targets(root)
1125        return self._compile_node(root)
1126
1127    def _compile_targets(self, node):
1128        for arc in node.arcs:
1129            if isinstance(arc.target, UncompiledNode):
1130                n = arc.target
1131                if len(n.arcs) == 0:
1132                    arc.accept = n.accept = True
1133                arc.target = self._compile_node(n)
1134
1135    def _compile_node(self, uncnode):
1136        seen = self.seen
1137
1138        if len(uncnode.arcs) == 0:
1139            # Leaf node
1140            address = self._write_node(uncnode)
1141        else:
1142            d = uncnode.digest()
1143            address = seen.get(d)
1144            if address is None:
1145                address = self._write_node(uncnode)
1146                seen[d] = address
1147        return address
1148
1149    def _write_node(self, uncnode):
1150        vtype = self.vtype
1151        dbfile = self.dbfile
1152        arcs = uncnode.arcs
1153        numarcs = len(arcs)
1154
1155        if not numarcs:
1156            if uncnode.accept:
1157                return None
1158            else:
1159                # What does it mean for an arc to stop but not be accepted?
1160                raise Exception
1161        self.node_count += 1
1162
1163        buf = StructFile(BytesIO())
1164        nodestart = dbfile.tell()
1165        #self.count += 1
1166        #self.arccount += numarcs
1167
1168        fixedsize = -1
1169        arcstart = buf.tell()
1170        for i, arc in enumerate(arcs):
1171            self.arc_count += 1
1172            target = arc.target
1173            label = arc.label
1174
1175            flags = 0
1176            if len(label) > 1:
1177                flags += MULTIBYTE_LABEL
1178            if i == numarcs - 1:
1179                flags += ARC_LAST
1180            if arc.accept:
1181                flags += ARC_ACCEPT
1182            if target is None:
1183                flags += ARC_STOP
1184            if arc.value is not None:
1185                flags += ARC_HAS_VAL
1186            if arc.acceptval is not None:
1187                flags += ARC_HAS_ACCEPT_VAL
1188
1189            buf.write(pack_byte(flags))
1190            if len(label) > 1:
1191                buf.write(varint(len(label)))
1192            buf.write(label)
1193            if target is not None:
1194                buf.write(pack_uint(target))
1195            if arc.value is not None:
1196                vtype.write(buf, arc.value)
1197            if arc.acceptval is not None:
1198                vtype.write(buf, arc.acceptval)
1199
1200            here = buf.tell()
1201            thissize = here - arcstart
1202            arcstart = here
1203            if fixedsize == -1:
1204                fixedsize = thissize
1205            elif fixedsize > 0 and thissize != fixedsize:
1206                fixedsize = 0
1207
1208        if fixedsize > 0:
1209            # Write a fake arc containing the fixed size and number of arcs
1210            dbfile.write_byte(255)  # FIXED_SIZE
1211            dbfile.write_int(fixedsize)
1212            dbfile.write_int(numarcs)
1213            self.fixed_count += 1
1214        dbfile.write(buf.file.getvalue())
1215
1216        return nodestart
1217
1218
1219# Graph reader
1220
class BaseGraphReader(object):
    """Base class for graph readers.

    Subclasses implement :meth:`has_root`, :meth:`root`, :meth:`arc_at` and
    :meth:`iter_arcs`. The search and convenience methods here are built on
    :meth:`iter_arcs`. :meth:`find_path` starts from ``self._root`` by
    default, which subclasses are expected to set.
    """

    def cursor(self, rootname=None):
        """Returns a :class:`Cursor` over the graph rooted at
        ``self.root(rootname)``.
        """
        return Cursor(self, self.root(rootname))

    def has_root(self, rootname):
        """Returns True if the graph has a root named ``rootname``."""
        raise NotImplementedError

    def root(self, rootname=None):
        """Returns the address of the named root, or of the default root if
        ``rootname`` is None.
        """
        raise NotImplementedError

    # Low level methods

    def arc_at(self, address, arc):
        """Reads the arc stored at ``address``."""
        raise NotImplementedError

    def iter_arcs(self, address, arc=None):
        """Yields the arcs leaving the node at ``address``, in label order."""
        raise NotImplementedError

    def find_arc(self, address, label, arc=None):
        """Linear search of the arcs leaving ``address`` for ``label``.

        Arcs are assumed to come in ascending label order, so the search
        stops as soon as it passes the position where ``label`` would be.

        :param arc: optional Arc passed through to :meth:`iter_arcs`, which
            may reuse it as a buffer.
        :returns: the matching arc, or None.
        """
        arc = arc or Arc()
        for arc in self.iter_arcs(address, arc):
            thislabel = arc.label
            if thislabel == label:
                return arc
            elif thislabel > label:
                return None
        return None

    # Convenience methods

    def list_arcs(self, address):
        """Returns a list of copies of the arcs leaving ``address``."""
        # Copy each arc: iter_arcs may yield the same (refilled) object
        return [copy.copy(arc) for arc in self.iter_arcs(address)]

    def arc_dict(self, address):
        """Returns a dict mapping label -> copy of the arc leaving
        ``address``.
        """
        return dict((arc.label, copy.copy(arc))
                    for arc in self.iter_arcs(address))

    def find_path(self, path, arc=None, address=None):
        """Follows ``path`` (anything accepted by ``to_labels``) through the
        graph.

        :param arc: if given, the search continues from this arc's target,
            and the arc object is reused to hold the results.
        :param address: the node to start from when ``arc`` is not given;
            defaults to ``self._root``.
        :returns: the arc for the last label of ``path`` (or the starting arc
            if ``path`` is empty), or None if the path is not in the graph.
        """

        path = to_labels(path)

        if arc:
            # Continue from the given arc. If it is a stop arc (target None)
            # there is nowhere left to go -- do NOT fall back to the root.
            address = arc.target
        else:
            arc = Arc()
            if address is None:
                address = self._root

        for label in path:
            if address is None:
                return None
            if not self.find_arc(address, label, arc):
                return None
            address = arc.target
        return arc
1275
1276
class GraphReader(BaseGraphReader):
    """Reads a graph file from ``dbfile``.

    Layout at ``filebase``: the magic bytes ``GRPH``, an int format version,
    and a uint offset of a pickled dict mapping root (field) names to root
    node addresses.
    """

    def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
        """
        :param dbfile: the file to read from.
        :param rootname: the root to use by default. If None and the file
            contains exactly one root, that root becomes the default.
        :param vtype: the values class used to decode arc values; only needed
            if the graph stores values.
        :param filebase: offset in ``dbfile`` where the graph header begins.
        :raises FileVersionError: if the magic bytes don't match.
        """

        self.dbfile = dbfile
        self.vtype = vtype
        self.filebase = filebase

        dbfile.seek(filebase)
        magic = dbfile.read(4)
        if magic != b("GRPH"):
            raise FileVersionError
        self.version = dbfile.read_int()
        # NOTE(review): the stored roots offset is used as an absolute file
        # position, not relative to filebase -- confirm for filebase != 0.
        dbfile.seek(dbfile.read_uint())
        # NOTE: read_pickle presumably unpickles; only open trusted files.
        self.roots = dbfile.read_pickle()

        self._root = None
        if rootname is None and len(self.roots) == 1:
            # If there's only one root, just use it. Have to wrap a list around
            # the keys() method here because of Python 3.
            rootname = list(self.roots.keys())[0]
        if rootname is not None:
            self._root = self.root(rootname)

    def close(self):
        """Closes the underlying file."""
        self.dbfile.close()

    # Overrides

    def has_root(self, rootname):
        """Returns True if the file contains a root named ``rootname``."""
        return rootname in self.roots

    def root(self, rootname=None):
        """Returns the address of the named root, or of the default root if
        ``rootname`` is None.

        :raises KeyError: if there is no root with that name.
        """
        if rootname is None:
            return self._root
        else:
            return self.roots[rootname]

    def default_root(self):
        """Returns the address of the default root (may be None)."""
        return self._root

    def arc_at(self, address, arc=None):
        """Reads the arc stored at ``address`` into ``arc`` (or a new Arc)."""
        arc = arc or Arc()
        self.dbfile.seek(address)
        return self._read_arc(arc)

    def iter_arcs(self, address, arc=None):
        """Yields the arcs leaving the node at ``address``.

        The same Arc object is refilled and yielded for every arc; copy it
        if you need to keep it.
        """
        arc = arc or Arc()
        _read_arc = self._read_arc

        self.dbfile.seek(address)
        while True:
            _read_arc(arc)
            yield arc
            if arc.lastarc:
                break

    def find_arc(self, address, label, arc=None):
        """Finds the arc labelled ``label`` leaving the node at ``address``.

        Uses a binary search when the node's arcs are stored as fixed-size
        records and there are more than two of them. Otherwise it does a
        linear search.

        :returns: ``arc`` (or a new Arc) filled in, or None if not found.
        """
        arc = arc or Arc()
        dbfile = self.dbfile
        dbfile.seek(address)

        # If records are fixed size, we can do a binary search
        finfo = self._read_fixed_info()
        if finfo:
            size, count = finfo
            # The arc records start right after the fixed-size header
            address = dbfile.tell()
            if count > 2:
                return self._binary_search(address, size, count, label, arc)

        # Variable-size records (or too few to bother with a binary search):
        # fall back to the parent's linear search method
        return BaseGraphReader.find_arc(self, address, label, arc)

    # Implementations

    def _read_arc(self, toarc=None):
        """Reads the arc at the current file position into ``toarc``."""
        toarc = toarc or Arc()
        dbfile = self.dbfile
        flags = dbfile.read_byte()
        if flags == 255:
            # This is a fake arc containing fixed size information; skip it
            # and read the next arc
            dbfile.seek(_INT_SIZE * 2, 1)
            flags = dbfile.read_byte()
        toarc.label = self._read_label(flags)
        return self._read_arc_data(flags, toarc)

    def _read_label(self, flags):
        """Reads an arc label (one byte unless MULTIBYTE_LABEL is set)."""
        dbfile = self.dbfile
        if flags & MULTIBYTE_LABEL:
            length = dbfile.read_varint()
        else:
            length = 1
        label = dbfile.read(length)
        return label

    def _read_fixed_info(self):
        """Reads the fixed-size header at the current position, if any.

        :returns: ``(record_size, arc_count)``, or None if the next byte is
            not the 255 marker (that byte is consumed either way).
        """
        dbfile = self.dbfile

        flags = dbfile.read_byte()
        if flags == 255:
            size = dbfile.read_int()
            count = dbfile.read_int()
            return (size, count)
        else:
            return None

    def _read_arc_data(self, flags, arc):
        """Fills in the fields of ``arc`` that follow its label.

        Sets ``accept``, ``lastarc``, ``target``, ``value``, ``acceptval``
        and ``endpos`` (the file position just past the arc record).
        """
        dbfile = self.dbfile
        arc.accept = bool(flags & ARC_ACCEPT)
        arc.lastarc = flags & ARC_LAST
        if flags & ARC_STOP:
            arc.target = None
        else:
            arc.target = dbfile.read_uint()
        if flags & ARC_HAS_VAL:
            arc.value = self.vtype.read(dbfile)
        else:
            arc.value = None
        # The writer stores an accept value whenever ARC_HAS_ACCEPT_VAL is
        # set, independently of ARC_ACCEPT, so that flag alone decides whether
        # to read it (otherwise the file position drifts). Reset it when it is
        # absent, so a reused Arc does not keep a stale value.
        if flags & ARC_HAS_ACCEPT_VAL:
            arc.acceptval = self.vtype.read(dbfile)
        else:
            arc.acceptval = None
        arc.endpos = dbfile.tell()
        return arc

    def _binary_search(self, address, size, count, label, arc):
        """Binary search of ``count`` fixed-size (``size`` bytes) arc records
        starting at ``address`` for ``label``.

        :returns: ``arc`` filled in with the matching record, or None.
        """
        dbfile = self.dbfile

        lo = 0
        hi = count
        while lo < hi:
            mid = (lo + hi) // 2
            dbfile.seek(address + mid * size)
            flags = dbfile.read_byte()
            midlabel = self._read_label(flags)
            if midlabel == label:
                arc.label = midlabel
                return self._read_arc_data(flags, arc)
            elif midlabel < label:
                lo = mid + 1
            else:
                hi = mid
        return None
1421
1422
def to_labels(key):
    """Takes a string and returns a tuple of bytestrings, suitable for use as
    a key or path in an FSA/FST graph.

    * A tuple or list must already contain only bytestrings; a list is
      converted to a tuple.
    * A bytestring is split into single-byte labels.
    * A text string is split into characters, each encoded with
      ``utf8encode``, so a single label may be several bytes long.

    :raises TypeError: if ``key`` is of any other type, or is a tuple/list
        containing a non-bytestring.
    """

    # Convert to tuples of bytestrings (must be tuples so they can be hashed)
    keytype = type(key)

    # I hate the Python 3 bytes object so friggin much
    if keytype is tuple or keytype is list:
        if not all(isinstance(e, bytes_type) for e in key):
            raise TypeError("%r contains a non-bytestring" % (key,))
        if keytype is list:
            key = tuple(key)
    elif isinstance(key, bytes_type):
        # Slice rather than index: indexing bytes yields ints on Python 3
        key = tuple(key[i:i + 1] for i in xrange(len(key)))
    elif isinstance(key, text_type):
        key = tuple(utf8encode(key[i:i + 1])[0] for i in xrange(len(key)))
    else:
        raise TypeError("Don't know how to convert %r" % (key,))
    return key
1444
1445
1446# Within edit distance function
1447
def within(graph, text, k=1, prefix=0, address=None):
    """Yields a series of keys in the given graph within ``k`` edit distance of
    ``text``. If ``prefix`` is greater than 0, all keys must match the first
    ``prefix`` characters of ``text``.

    The edits considered are insertion, deletion, replacement and
    transposition of adjacent labels. Each matching key is yielded once,
    decoded with ``utf8decode``. Keys come in search order, not sorted order.

    :param graph: a graph reader providing ``find_path``, ``find_arc`` and
        ``arc_dict``.
    :param text: the string to compare keys against.
    :param k: the maximum number of edits.
    :param prefix: number of leading labels of ``text`` that must match
        exactly.
    :param address: node to search from; defaults to ``graph._root``.
    """

    text = to_labels(text)
    if address is None:
        address = graph._root

    sofar = emptybytes
    accept = False
    if prefix:
        prefixchars = text[:prefix]
        arc = graph.find_path(prefixchars, address=address)
        if arc is None:
            return
        sofar = emptybytes.join(prefixchars)
        address = arc.target
        accept = arc.accept

    # Depth-first search over states of the form (node address, edits left,
    # position in text, key built so far, whether that key is accepted)
    stack = [(address, k, prefix, sofar, accept)]
    seen = set()
    # Different edit sequences (e.g. insert a label, then delete the same
    # label from text) can build the same key, so remember what was yielded
    yielded = set()
    while stack:
        state = stack.pop()
        # Have we already tried this state?
        if state in seen:
            continue
        seen.add(state)

        address, k, i, sofar, accept = state
        # If we're at the end of the text (or deleting enough chars would get
        # us to the end and still within K), and we're in the accept state,
        # yield the current result (unless it was already yielded)
        if (len(text) - i <= k) and accept and sofar not in yielded:
            yielded.add(sofar)
            yield utf8decode(sofar)[0]

        # If we're in the stop state, give up
        if address is None:
            continue

        # Exact match
        if i < len(text):
            arc = graph.find_arc(address, text[i])
            if arc:
                stack.append((arc.target, k, i + 1, sofar + text[i],
                              arc.accept))
        # If K is already 0, can't do any more edits
        if k < 1:
            continue
        k -= 1

        arcs = graph.arc_dict(address)
        # Insertions
        stack.extend((arc.target, k, i, sofar + char, arc.accept)
                     for char, arc in iteritems(arcs))

        # Deletion, replacement, and transpo only work before the end
        if i >= len(text):
            continue
        char = text[i]

        # Deletion
        stack.append((address, k, i + 1, sofar, False))
        # Replacement
        for char2, arc in iteritems(arcs):
            if char2 != char:
                stack.append((arc.target, k, i + 1, sofar + char2, arc.accept))
        # Transposition
        if i < len(text) - 1:
            char2 = text[i + 1]
            if char != char2 and char2 in arcs:
                # Find arc from next char to this char
                target = arcs[char2].target
                if target:
                    arc = graph.find_arc(target, char)
                    if arc:
                        stack.append((arc.target, k, i + 2,
                                      sofar + char2 + char, arc.accept))
1527
1528
1529# Utility functions
1530
def dump_graph(graph, address=None, tab=0, out=None):
    """Write a human-readable listing of the graph to ``out``.

    Each arc is printed on its own line as ``label target accept value``.
    The first arc of a node is prefixed with the node's zero-padded address
    (later arcs with six spaces), and each level of nesting adds two spaces
    of indentation. The listing recurses into every arc that has a target.

    :param graph: the graph reader to dump.
    :param address: the node to start from; defaults to ``graph._root``.
    :param tab: the current nesting depth (used by the recursion).
    :param out: file-like object to write to; defaults to ``sys.stdout``.
    """

    if address is None:
        address = graph._root
    stream = sys.stdout if out is None else out

    node_tag = "%06d" % address
    indent = "  " * tab
    for position, arc in enumerate(graph.list_arcs(address)):
        stream.write(node_tag if position == 0 else " " * 6)
        stream.write(indent)
        stream.write("%r %r %s %r\n"
                     % (arc.label, arc.target, arc.accept, arc.value))
        if arc.target is not None:
            dump_graph(graph, arc.target, tab + 1, out=stream)
1548
1549