/bangkokhotel/lib/python2.5/site-packages/whoosh/support/dawg.py

https://bitbucket.org/luisrodriguez/bangkokhotel · Python · 1549 lines · 1019 code · 290 blank · 240 comment · 255 complexity · 6233492f368ddc4a5249095d8fbb480d MD5 · raw file

  1. # Copyright 2009 Matt Chaput. All rights reserved.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions are met:
  5. #
  6. # 1. Redistributions of source code must retain the above copyright notice,
  7. # this list of conditions and the following disclaimer.
  8. #
  9. # 2. Redistributions in binary form must reproduce the above copyright
  10. # notice, this list of conditions and the following disclaimer in the
  11. # documentation and/or other materials provided with the distribution.
  12. #
  13. # THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
  14. # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  15. # MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
  16. # EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
  17. # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  18. # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
  19. # OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  20. # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  21. # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
  22. # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  23. #
  24. # The views and conclusions contained in the software and documentation are
  25. # those of the authors and should not be interpreted as representing official
  26. # policies, either expressed or implied, of Matt Chaput.
  27. """
  28. This module implements an FST/FSA writer and reader. An FST (Finite State
  29. Transducer) stores a directed acyclic graph with values associated with the
  30. leaves. Common elements of the values are pushed inside the tree. An FST that
  31. does not store values is a regular FSA.
  32. The format of the leaf values is pluggable using subclasses of the Values
  33. class.
  34. Whoosh uses these structures to store a directed acyclic word graph (DAWG) for
  35. use in (at least) spell checking.
  36. """
  37. import sys, copy
  38. from array import array
  39. from hashlib import sha1 # @UnresolvedImport
  40. from whoosh.compat import (b, u, BytesIO, xrange, iteritems, iterkeys,
  41. bytes_type, text_type, izip, array_tobytes)
  42. from whoosh.filedb.structfile import StructFile
  43. from whoosh.system import (_INT_SIZE, pack_byte, pack_int, pack_uint,
  44. pack_long, emptybytes)
  45. from whoosh.util import utf8encode, utf8decode, varint
  46. class FileVersionError(Exception):
  47. pass
  48. class InactiveCursor(Exception):
  49. pass
  50. ARC_LAST = 1
  51. ARC_ACCEPT = 2
  52. ARC_STOP = 4
  53. ARC_HAS_VAL = 8
  54. ARC_HAS_ACCEPT_VAL = 16
  55. MULTIBYTE_LABEL = 32
  56. # FST Value types
  57. class Values(object):
  58. """Base for classes the describe how to encode and decode FST values.
  59. """
  60. @staticmethod
  61. def is_valid(v):
  62. """Returns True if v is a valid object that can be stored by this
  63. class.
  64. """
  65. raise NotImplementedError
  66. @staticmethod
  67. def common(v1, v2):
  68. """Returns the "common" part of the two values, for whatever "common"
  69. means for this class. For example, a string implementation would return
  70. the common shared prefix, for an int implementation it would return
  71. the minimum of the two numbers.
  72. If there is no common part, this method should return None.
  73. """
  74. raise NotImplementedError
  75. @staticmethod
  76. def add(prefix, v):
  77. """Adds the given prefix (the result of a call to common()) to the
  78. given value.
  79. """
  80. raise NotImplementedError
  81. @staticmethod
  82. def subtract(v, prefix):
  83. """Subtracts the "common" part (the prefix) from the given value.
  84. """
  85. raise NotImplementedError
  86. @staticmethod
  87. def write(dbfile, v):
  88. """Writes value v to a file.
  89. """
  90. raise NotImplementedError
  91. @staticmethod
  92. def read(dbfile):
  93. """Reads a value from the given file.
  94. """
  95. raise NotImplementedError
  96. @classmethod
  97. def skip(cls, dbfile):
  98. """Skips over a value in the given file.
  99. """
  100. cls.read(dbfile)
  101. @staticmethod
  102. def to_bytes(v):
  103. """Returns a str (Python 2.x) or bytes (Python 3) representation of
  104. the given value. This is used for calculating node digests, so it
  105. should be unique but fast to calculate, and does not have to be
  106. parseable.
  107. """
  108. raise NotImplementedError
  109. @staticmethod
  110. def merge(v1, v2):
  111. raise NotImplementedError
  112. class IntValues(Values):
  113. """Stores integer values in an FST.
  114. """
  115. @staticmethod
  116. def is_valid(v):
  117. return isinstance(v, int) and v >= 0
  118. @staticmethod
  119. def common(v1, v2):
  120. if v1 is None or v2 is None:
  121. return None
  122. if v1 == v2:
  123. return v1
  124. return min(v1, v2)
  125. @staticmethod
  126. def add(base, v):
  127. if base is None:
  128. return v
  129. if v is None:
  130. return base
  131. return base + v
  132. @staticmethod
  133. def subtract(v, base):
  134. if v is None:
  135. return None
  136. if base is None:
  137. return v
  138. return v - base
  139. @staticmethod
  140. def write(dbfile, v):
  141. dbfile.write_uint(v)
  142. @staticmethod
  143. def read(dbfile):
  144. return dbfile.read_uint()
  145. @staticmethod
  146. def skip(dbfile):
  147. dbfile.seek(_INT_SIZE, 1)
  148. @staticmethod
  149. def to_bytes(v):
  150. return pack_int(v)
  151. class SequenceValues(Values):
  152. """Abstract base class for value types that store sequences.
  153. """
  154. @staticmethod
  155. def is_valid(v):
  156. return isinstance(self, (list, tuple))
  157. @staticmethod
  158. def common(v1, v2):
  159. if v1 is None or v2 is None:
  160. return None
  161. i = 0
  162. while i < len(v1) and i < len(v2):
  163. if v1[i] != v2[i]:
  164. break
  165. i += 1
  166. if i == 0:
  167. return None
  168. if i == len(v1):
  169. return v1
  170. if i == len(v2):
  171. return v2
  172. return v1[:i]
  173. @staticmethod
  174. def add(prefix, v):
  175. if prefix is None:
  176. return v
  177. if v is None:
  178. return prefix
  179. return prefix + v
  180. @staticmethod
  181. def subtract(v, prefix):
  182. if prefix is None:
  183. return v
  184. if v is None:
  185. return None
  186. if len(v) == len(prefix):
  187. return None
  188. if len(v) < len(prefix) or len(prefix) == 0:
  189. raise ValueError((v, prefix))
  190. return v[len(prefix):]
  191. @staticmethod
  192. def write(dbfile, v):
  193. dbfile.write_pickle(v)
  194. @staticmethod
  195. def read(dbfile):
  196. return dbfile.read_pickle()
  197. class BytesValues(SequenceValues):
  198. """Stores bytes objects (str in Python 2.x) in an FST.
  199. """
  200. @staticmethod
  201. def is_valid(v):
  202. return isinstance(v, bytes_type)
  203. @staticmethod
  204. def write(dbfile, v):
  205. dbfile.write_int(len(v))
  206. dbfile.write(v)
  207. @staticmethod
  208. def read(dbfile):
  209. length = dbfile.read_int()
  210. return dbfile.read(length)
  211. @staticmethod
  212. def skip(dbfile):
  213. length = dbfile.read_int()
  214. dbfile.seek(length, 1)
  215. @staticmethod
  216. def to_bytes(v):
  217. return v
  218. class ArrayValues(SequenceValues):
  219. """Stores array.array objects in an FST.
  220. """
  221. def __init__(self, typecode):
  222. self.typecode = typecode
  223. self.itemsize = array(self.typecode).itemsize
  224. def is_valid(self, v):
  225. return isinstance(v, array) and v.typecode == self.typecode
  226. @staticmethod
  227. def write(dbfile, v):
  228. dbfile.write(b(v.typecode))
  229. dbfile.write_int(len(v))
  230. dbfile.write_array(v)
  231. def read(self, dbfile):
  232. typecode = u(dbfile.read(1))
  233. length = dbfile.read_int()
  234. return dbfile.read_array(self.typecode, length)
  235. def skip(self, dbfile):
  236. length = dbfile.read_int()
  237. dbfile.seek(length * self.itemsize, 1)
  238. @staticmethod
  239. def to_bytes(v):
  240. return array_tobytes(v)
  241. class IntListValues(SequenceValues):
  242. """Stores lists of positive, increasing integers (that is, lists of
  243. integers where each number is >= 0 and each number is greater than or equal
  244. to the number that precedes it) in an FST.
  245. """
  246. @staticmethod
  247. def is_valid(v):
  248. if isinstance(v, (list, tuple)):
  249. if len(v) < 2:
  250. return True
  251. for i in xrange(1, len(v)):
  252. if not isinstance(v[i], int) or v[i] < v[i - 1]:
  253. return False
  254. return True
  255. return False
  256. @staticmethod
  257. def write(dbfile, v):
  258. base = 0
  259. dbfile.write_varint(len(v))
  260. for x in v:
  261. delta = x - base
  262. assert delta >= 0
  263. dbfile.write_varint(delta)
  264. base = x
  265. @staticmethod
  266. def read(dbfile):
  267. length = dbfile.read_varint()
  268. result = []
  269. if length > 0:
  270. base = 0
  271. for _ in xrange(length):
  272. base += dbfile.read_varint()
  273. result.append(base)
  274. return result
  275. @staticmethod
  276. def to_bytes(v):
  277. return b(repr(v))
  278. # Node-like interface wrappers
  279. class Node(object):
  280. """A slow but easier-to-use wrapper for FSA/DAWGs. Translates the low-level
  281. arc-based interface of GraphReader into Node objects with methods to follow
  282. edges.
  283. """
  284. def __init__(self, owner, address, accept=False):
  285. self.owner = owner
  286. self.address = address
  287. self._edges = None
  288. self.accept = accept
  289. def __iter__(self):
  290. if not self._edges:
  291. self._load()
  292. return iterkeys(self._edges)
  293. def __contains__(self, key):
  294. if self._edges is None:
  295. self._load()
  296. return key in self._edges
  297. def _load(self):
  298. owner = self.owner
  299. if self.address is None:
  300. d = {}
  301. else:
  302. d = dict((arc.label, Node(owner, arc.target, arc.accept))
  303. for arc in self.owner.iter_arcs(self.address))
  304. self._edges = d
  305. def keys(self):
  306. if self._edges is None:
  307. self._load()
  308. return self._edges.keys()
  309. def all_edges(self):
  310. if self._edges is None:
  311. self._load()
  312. return self._edges
  313. def edge(self, key):
  314. if self._edges is None:
  315. self._load()
  316. return self._edges[key]
  317. def flatten(self, sofar=emptybytes):
  318. if self.accept:
  319. yield sofar
  320. for key in sorted(self):
  321. node = self.edge(key)
  322. for result in node.flatten(sofar + key):
  323. yield result
  324. def flatten_strings(self):
  325. return (utf8decode(k)[0] for k in self.flatten())
  326. class ComboNode(Node):
  327. """Base class for nodes that blend the nodes of two different graphs.
  328. Concrete subclasses need to implement the ``edge()`` method and possibly
  329. override the ``accept`` property.
  330. """
  331. def __init__(self, a, b):
  332. self.a = a
  333. self.b = b
  334. def __repr__(self):
  335. return "<%s %r %r>" % (self.__class__.__name__, self.a, self.b)
  336. def __contains__(self, key):
  337. return key in self.a or key in self.b
  338. def __iter__(self):
  339. return iter(set(self.a) | set(self.b))
  340. @property
  341. def accept(self):
  342. return self.a.accept or self.b.accept
  343. class UnionNode(ComboNode):
  344. """Makes two graphs appear to be the union of the two graphs.
  345. """
  346. def edge(self, key):
  347. a = self.a
  348. b = self.b
  349. if key in a and key in b:
  350. return UnionNode(a.edge(key), b.edge(key))
  351. elif key in a:
  352. return a.edge(key)
  353. else:
  354. return b.edge(key)
  355. class IntersectionNode(ComboNode):
  356. """Makes two graphs appear to be the intersection of the two graphs.
  357. """
  358. def edge(self, key):
  359. a = self.a
  360. b = self.b
  361. if key in a and key in b:
  362. return IntersectionNode(a.edge(key), b.edge(key))
  363. # Cursor
  364. class BaseCursor(object):
  365. """Base class for a cursor-type object for navigating an FST/word graph,
  366. represented by a :class:`GraphReader` object.
  367. >>> cur = GraphReader(dawgfile).cursor()
  368. >>> for key in cur.follow():
  369. ... print(repr(key))
  370. The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.
  371. """
  372. def is_active(self):
  373. """Returns True if this cursor is still active, that is it has not
  374. read past the last arc in the graph.
  375. """
  376. raise NotImplementedError
  377. def label(self):
  378. """Returns the label bytes of the current arc.
  379. """
  380. raise NotImplementedError
  381. def prefix(self):
  382. """Returns a sequence of the label bytes for the path from the root
  383. to the current arc.
  384. """
  385. raise NotImplementedError
  386. def prefix_bytes(self):
  387. """Returns the label bytes for the path from the root to the current
  388. arc as a single joined bytes object.
  389. """
  390. return emptybytes.join(self.prefix())
  391. def prefix_string(self):
  392. """Returns the labels of the path from the root to the current arc as
  393. a decoded unicode string.
  394. """
  395. return utf8decode(self.prefix_bytes())[0]
  396. def peek_key(self):
  397. """Returns a sequence of label bytes representing the next closest
  398. key in the graph.
  399. """
  400. for label in self.prefix():
  401. yield label
  402. c = self.copy()
  403. while not c.stopped():
  404. c.follow()
  405. yield c.label()
  406. def peek_key_bytes(self):
  407. """Returns the next closest key in the graph as a single bytes object.
  408. """
  409. return emptybytes.join(self.peek_key())
  410. def peek_key_string(self):
  411. """Returns the next closest key in the graph as a decoded unicode
  412. string.
  413. """
  414. return utf8decode(self.peek_key_bytes())[0]
  415. def stopped(self):
  416. """Returns True if the current arc leads to a stop state.
  417. """
  418. raise NotImplementedError
  419. def value(self):
  420. """Returns the value at the current arc, if reading an FST.
  421. """
  422. raise NotImplementedError
  423. def accept(self):
  424. """Returns True if the current arc leads to an accept state (the end
  425. of a valid key).
  426. """
  427. raise NotImplementedError
  428. def at_last_arc(self):
  429. """Returns True if the current arc is the last outgoing arc from the
  430. previous node.
  431. """
  432. raise NotImplementedError
  433. def next_arc(self):
  434. """Moves to the next outgoing arc from the previous node.
  435. """
  436. raise NotImplementedError
  437. def follow(self):
  438. """Follows the current arc.
  439. """
  440. raise NotImplementedError
  441. def switch_to(self, label):
  442. """Switch to the sibling arc with the given label bytes.
  443. """
  444. _label = self.label
  445. _at_last_arc = self.at_last_arc
  446. _next_arc = self.next_arc
  447. while True:
  448. thislabel = _label()
  449. if thislabel == label:
  450. return True
  451. if thislabel > label or _at_last_arc():
  452. return False
  453. _next_arc()
  454. def skip_to(self, key):
  455. """Moves the cursor to the path represented by the given key bytes.
  456. """
  457. _accept = self.accept
  458. _prefix = self.prefix
  459. _next_arc = self.next_arc
  460. keylist = list(key)
  461. while True:
  462. if _accept():
  463. thiskey = list(_prefix())
  464. if keylist == thiskey:
  465. return True
  466. elif keylist > thiskey:
  467. return False
  468. _next_arc()
  469. def flatten(self):
  470. """Yields the keys in the graph, starting at the current position.
  471. """
  472. _is_active = self.is_active
  473. _accept = self.accept
  474. _stopped = self.stopped
  475. _follow = self.follow
  476. _next_arc = self.next_arc
  477. _prefix_bytes = self.prefix_bytes
  478. if not _is_active():
  479. raise InactiveCursor
  480. while _is_active():
  481. if _accept():
  482. yield _prefix_bytes()
  483. if not _stopped():
  484. _follow()
  485. continue
  486. _next_arc()
  487. def flatten_v(self):
  488. """Yields (key, value) tuples in an FST, starting at the current
  489. position.
  490. """
  491. for key in self.flatten():
  492. yield key, self.value()
  493. def flatten_strings(self):
  494. return (utf8decode(k)[0] for k in self.flatten())
  495. def find_path(self, path):
  496. """Follows the labels in the given path, starting at the current
  497. position.
  498. """
  499. path = to_labels(path)
  500. _switch_to = self.switch_to
  501. _follow = self.follow
  502. _stopped = self.stopped
  503. first = True
  504. for i, label in enumerate(path):
  505. if not first:
  506. _follow()
  507. if not _switch_to(label):
  508. return False
  509. if _stopped():
  510. if i < len(path) - 1:
  511. return False
  512. first = False
  513. return True
  514. class Cursor(BaseCursor):
  515. def __init__(self, graph, root=None, stack=None):
  516. self.graph = graph
  517. self.vtype = graph.vtype
  518. self.root = root if root is not None else graph.default_root()
  519. if stack:
  520. self.stack = stack
  521. else:
  522. self.reset()
  523. def _current_attr(self, name):
  524. stack = self.stack
  525. if not stack:
  526. raise InactiveCursor
  527. return getattr(stack[-1], name)
  528. def is_active(self):
  529. return bool(self.stack)
  530. def stopped(self):
  531. return self._current_attr("target") is None
  532. def accept(self):
  533. return self._current_attr("accept")
  534. def at_last_arc(self):
  535. return self._current_attr("lastarc")
  536. def label(self):
  537. return self._current_attr("label")
  538. def reset(self):
  539. self.stack = []
  540. self.sums = [None]
  541. self._push(self.graph.arc_at(self.root))
  542. def copy(self):
  543. return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))
  544. def prefix(self):
  545. stack = self.stack
  546. if not stack:
  547. raise InactiveCursor
  548. return (arc.label for arc in stack)
  549. # Override: more efficient implementation using graph methods directly
  550. def peek_key(self):
  551. if not self.stack:
  552. raise InactiveCursor
  553. for label in self.prefix():
  554. yield label
  555. arc = copy.copy(self.stack[-1])
  556. graph = self.graph
  557. while not arc.accept and arc.target is not None:
  558. graph.arc_at(arc.target, arc)
  559. yield arc.label
  560. def value(self):
  561. stack = self.stack
  562. if not stack:
  563. raise InactiveCursor
  564. vtype = self.vtype
  565. if not vtype:
  566. raise Exception("No value type")
  567. v = self.sums[-1]
  568. current = stack[-1]
  569. if current.value:
  570. v = vtype.add(v, current.value)
  571. if current.accept and current.acceptval is not None:
  572. v = vtype.add(v, current.acceptval)
  573. return v
  574. def next_arc(self):
  575. stack = self.stack
  576. if not stack:
  577. raise InactiveCursor
  578. while stack and stack[-1].lastarc:
  579. self.pop()
  580. if stack:
  581. current = stack[-1]
  582. self.graph.arc_at(current.endpos, current)
  583. return current
  584. def follow(self):
  585. address = self._current_attr("target")
  586. if address is None:
  587. raise Exception("Can't follow a stop arc")
  588. self._push(self.graph.arc_at(address))
  589. return self
  590. # Override: more efficient implementation manipulating the stack
  591. def skip_to(self, key):
  592. key = to_labels(key)
  593. stack = self.stack
  594. if not stack:
  595. raise InactiveCursor
  596. _follow = self.follow
  597. _next_arc = self.next_arc
  598. i = self._pop_to_prefix(key)
  599. while stack and i < len(key):
  600. curlabel = stack[-1].label
  601. keylabel = key[i]
  602. if curlabel == keylabel:
  603. _follow()
  604. i += 1
  605. elif curlabel > keylabel:
  606. return
  607. else:
  608. _next_arc()
  609. # Override: more efficient implementation using find_arc
  610. def switch_to(self, label):
  611. stack = self.stack
  612. if not stack:
  613. raise InactiveCursor
  614. current = stack[-1]
  615. if label == current.label:
  616. return True
  617. else:
  618. arc = self.graph.find_arc(current.endpos, label, current)
  619. return arc
  620. def _push(self, arc):
  621. if self.vtype and self.stack:
  622. sums = self.sums
  623. sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
  624. self.stack.append(arc)
  625. def pop(self):
  626. self.stack.pop()
  627. if self.vtype:
  628. self.sums.pop()
  629. def _pop_to_prefix(self, key):
  630. stack = self.stack
  631. if not stack:
  632. raise InactiveCursor
  633. i = 0
  634. maxpre = min(len(stack), len(key))
  635. while i < maxpre and key[i] == stack[i].label:
  636. i += 1
  637. if stack[i].label > key[i]:
  638. self.current = None
  639. return
  640. while len(stack) > i + 1:
  641. self.pop()
  642. self.next_arc()
  643. return i
  644. class UncompiledNode(object):
  645. # Represents an "in-memory" node used by the GraphWriter before it is
  646. # written to disk.
  647. compiled = False
  648. def __init__(self, owner):
  649. self.owner = owner
  650. self._digest = None
  651. self.clear()
  652. def clear(self):
  653. self.arcs = []
  654. self.value = None
  655. self.accept = False
  656. self.inputcount = 0
  657. def __repr__(self):
  658. return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
  659. def digest(self):
  660. if self._digest is None:
  661. d = sha1()
  662. vtype = self.owner.vtype
  663. for arc in self.arcs:
  664. d.update(arc.label)
  665. if arc.target:
  666. d.update(pack_long(arc.target))
  667. else:
  668. d.update(b("z"))
  669. if arc.value:
  670. d.update(vtype.to_bytes(arc.value))
  671. if arc.accept:
  672. d.update(b("T"))
  673. self._digest = d.digest()
  674. return self._digest
  675. def edges(self):
  676. return self.arcs
  677. def last_value(self, label):
  678. assert self.arcs[-1].label == label
  679. return self.arcs[-1].value
  680. def add_arc(self, label, target):
  681. self.arcs.append(Arc(label, target))
  682. def replace_last(self, label, target, accept, acceptval=None):
  683. arc = self.arcs[-1]
  684. assert arc.label == label, "%r != %r" % (arc.label, label)
  685. arc.target = target
  686. arc.accept = accept
  687. arc.acceptval = acceptval
  688. def delete_last(self, label, target):
  689. arc = self.arcs.pop()
  690. assert arc.label == label
  691. assert arc.target == target
  692. def set_last_value(self, label, value):
  693. arc = self.arcs[-1]
  694. assert arc.label == label, "%r->%r" % (arc.label, label)
  695. arc.value = value
  696. def prepend_value(self, prefix):
  697. add = self.owner.vtype.add
  698. for arc in self.arcs:
  699. arc.value = add(prefix, arc.value)
  700. if self.accept:
  701. self.value = add(prefix, self.value)
  702. class Arc(object):
  703. """
  704. Represents a directed arc between two nodes in an FSA/FST graph.
  705. The ``lastarc`` attribute is True if this is the last outgoing arc from the
  706. previous node.
  707. """
  708. __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
  709. "endpos")
  710. def __init__(self, label=None, target=None, value=None, accept=False,
  711. acceptval=None):
  712. """
  713. :param label:The label bytes for this arc. For a word graph, this will
  714. be a character.
  715. :param target: The address of the node at the endpoint of this arc.
  716. :param value: The inner FST value at the endpoint of this arc.
  717. :param accept: Whether the endpoint of this arc is an accept state
  718. (eg the end of a valid word).
  719. :param acceptval: If the endpoint of this arc is an accept state, the
  720. final FST value for that accepted state.
  721. """
  722. self.label = label
  723. self.target = target
  724. self.value = value
  725. self.accept = accept
  726. self.lastarc = None
  727. self.acceptval = acceptval
  728. self.endpos = None
  729. def __repr__(self):
  730. return "<%r-%s %s%s>" % (self.label, self.target,
  731. "." if self.accept else "",
  732. (" %r" % self.value) if self.value else "")
  733. def __eq__(self, other):
  734. if (isinstance(other, self.__class__) and self.accept == other.accept
  735. and self.lastarc == other.lastarc and self.target == other.target
  736. and self.value == other.value and self.label == other.label):
  737. return True
  738. return False
  739. # Graph writer
  740. class GraphWriter(object):
  741. """Writes an FSA/FST graph to disk.
  742. Call ``insert(key)`` to insert keys into the graph. You must
  743. insert keys in sorted order. Call ``close()`` to finish the graph and close
  744. the file.
  745. >>> gw = GraphWriter(my_file)
  746. >>> gw.insert("alfa")
  747. >>> gw.insert("bravo")
  748. >>> gw.insert("charlie")
  749. >>> gw.close()
  750. The graph writer can write separate graphs for multiple fields. Use
  751. ``start_field(name)`` and ``finish_field()`` to separate fields.
  752. >>> gw = GraphWriter(my_file)
  753. >>> gw.start_field("content")
  754. >>> gw.insert("alfalfa")
  755. >>> gw.insert("apple")
  756. >>> gw.finish_field()
  757. >>> gw.start_field("title")
  758. >>> gw.insert("artichoke")
  759. >>> gw.finish_field()
  760. >>> gw.close()
  761. """
  762. version = 1
  763. def __init__(self, dbfile, vtype=None, merge=None):
  764. """
  765. :param dbfile: the file to write to.
  766. :param vtype: a :class:`Values` class to use for storing values. This
  767. is only necessary if you will be storing values for the keys.
  768. :param merge: a function that takes two values and returns a single
  769. value. This is called if you insert two identical keys with values.
  770. """
  771. self.dbfile = dbfile
  772. self.vtype = vtype
  773. self.merge = merge
  774. self.fieldroots = {}
  775. self.arc_count = 0
  776. self.node_count = 0
  777. self.fixed_count = 0
  778. dbfile.write(b("GRPH"))
  779. dbfile.write_int(self.version)
  780. dbfile.write_uint(0)
  781. self._infield = False
  782. def start_field(self, fieldname):
  783. """Starts a new graph for the given field.
  784. """
  785. if not fieldname:
  786. raise ValueError("Field name cannot be equivalent to False")
  787. if self._infield:
  788. self.finish_field()
  789. self.fieldname = fieldname
  790. self.seen = {}
  791. self.nodes = [UncompiledNode(self)]
  792. self.lastkey = ''
  793. self._inserted = False
  794. self._infield = True
  795. def finish_field(self):
  796. """Finishes the graph for the current field.
  797. """
  798. if not self._infield:
  799. raise Exception("Called finish_field before start_field")
  800. self._infield = False
  801. if self._inserted:
  802. self.fieldroots[self.fieldname] = self._finish()
  803. self.fieldname = None
  804. def close(self):
  805. """Finishes the current graph and closes the underlying file.
  806. """
  807. if self.fieldname is not None:
  808. self.finish_field()
  809. dbfile = self.dbfile
  810. here = dbfile.tell()
  811. dbfile.write_pickle(self.fieldroots)
  812. dbfile.flush()
  813. dbfile.seek(4 + _INT_SIZE) # Seek past magic and version number
  814. dbfile.write_uint(here)
  815. dbfile.close()
  816. def insert(self, key, value=None):
  817. """Inserts the given key into the graph.
  818. :param key: a sequence of bytes objects, a bytes object, or a string.
  819. :param value: an optional value to encode in the graph along with the
  820. key. If the writer was not instantiated with a value type, passing
  821. a value here will raise an error.
  822. """
  823. if not self._infield:
  824. raise Exception("Inserted %r before starting a field" % key)
  825. self._inserted = True
  826. key = to_labels(key) # Python 3 sucks
  827. vtype = self.vtype
  828. lastkey = self.lastkey
  829. nodes = self.nodes
  830. if len(key) < 1:
  831. raise KeyError("Can't store a null key %r" % (key,))
  832. if lastkey and lastkey > key:
  833. raise KeyError("Keys out of order %r..%r" % (lastkey, key))
  834. # Find the common prefix shared by this key and the previous one
  835. prefixlen = 0
  836. for i in xrange(min(len(lastkey), len(key))):
  837. if lastkey[i] != key[i]:
  838. break
  839. prefixlen += 1
  840. # Compile the nodes after the prefix, since they're not shared
  841. self._freeze_tail(prefixlen + 1)
  842. # Create new nodes for the parts of this key after the shared prefix
  843. for char in key[prefixlen:]:
  844. node = UncompiledNode(self)
  845. # Create an arc to this node on the previous node
  846. nodes[-1].add_arc(char, node)
  847. nodes.append(node)
  848. # Mark the last node as an accept state
  849. lastnode = nodes[-1]
  850. lastnode.accept = True
  851. if vtype:
  852. if value is not None and not vtype.is_valid(value):
  853. raise ValueError("%r is not valid for %s" % (value, vtype))
  854. # Push value commonalities through the tree
  855. common = None
  856. for i in xrange(1, prefixlen + 1):
  857. node = nodes[i]
  858. parent = nodes[i - 1]
  859. lastvalue = parent.last_value(key[i - 1])
  860. if lastvalue is not None:
  861. common = vtype.common(value, lastvalue)
  862. suffix = vtype.subtract(lastvalue, common)
  863. parent.set_last_value(key[i - 1], common)
  864. node.prepend_value(suffix)
  865. else:
  866. common = suffix = None
  867. value = vtype.subtract(value, common)
  868. if key == lastkey:
  869. # If this key is a duplicate, merge its value with the value of
  870. # the previous (same) key
  871. lastnode.value = self.merge(lastnode.value, value)
  872. else:
  873. nodes[prefixlen].set_last_value(key[prefixlen], value)
  874. elif value:
  875. raise Exception("Value %r but no value type" % value)
  876. self.lastkey = key
  877. def _freeze_tail(self, prefixlen):
  878. nodes = self.nodes
  879. lastkey = self.lastkey
  880. downto = max(1, prefixlen)
  881. while len(nodes) > downto:
  882. node = nodes.pop()
  883. parent = nodes[-1]
  884. inlabel = lastkey[len(nodes) - 1]
  885. self._compile_targets(node)
  886. accept = node.accept or len(node.arcs) == 0
  887. address = self._compile_node(node)
  888. parent.replace_last(inlabel, address, accept, node.value)
  889. def _finish(self):
  890. nodes = self.nodes
  891. root = nodes[0]
  892. # Minimize nodes in the last word's suffix
  893. self._freeze_tail(0)
  894. # Compile remaining targets
  895. self._compile_targets(root)
  896. return self._compile_node(root)
  897. def _compile_targets(self, node):
  898. for arc in node.arcs:
  899. if isinstance(arc.target, UncompiledNode):
  900. n = arc.target
  901. if len(n.arcs) == 0:
  902. arc.accept = n.accept = True
  903. arc.target = self._compile_node(n)
  904. def _compile_node(self, uncnode):
  905. seen = self.seen
  906. if len(uncnode.arcs) == 0:
  907. # Leaf node
  908. address = self._write_node(uncnode)
  909. else:
  910. d = uncnode.digest()
  911. address = seen.get(d)
  912. if address is None:
  913. address = self._write_node(uncnode)
  914. seen[d] = address
  915. return address
  916. def _write_node(self, uncnode):
  917. vtype = self.vtype
  918. dbfile = self.dbfile
  919. arcs = uncnode.arcs
  920. numarcs = len(arcs)
  921. if not numarcs:
  922. if uncnode.accept:
  923. return None
  924. else:
  925. # What does it mean for an arc to stop but not be accepted?
  926. raise Exception
  927. self.node_count += 1
  928. buf = StructFile(BytesIO())
  929. nodestart = dbfile.tell()
  930. #self.count += 1
  931. #self.arccount += numarcs
  932. fixedsize = -1
  933. arcstart = buf.tell()
  934. for i, arc in enumerate(arcs):
  935. self.arc_count += 1
  936. target = arc.target
  937. label = arc.label
  938. flags = 0
  939. if len(label) > 1:
  940. flags += MULTIBYTE_LABEL
  941. if i == numarcs - 1:
  942. flags += ARC_LAST
  943. if arc.accept:
  944. flags += ARC_ACCEPT
  945. if target is None:
  946. flags += ARC_STOP
  947. if arc.value is not None:
  948. flags += ARC_HAS_VAL
  949. if arc.acceptval is not None:
  950. flags += ARC_HAS_ACCEPT_VAL
  951. buf.write(pack_byte(flags))
  952. if len(label) > 1:
  953. buf.write(varint(len(label)))
  954. buf.write(label)
  955. if target is not None:
  956. buf.write(pack_uint(target))
  957. if arc.value is not None:
  958. vtype.write(buf, arc.value)
  959. if arc.acceptval is not None:
  960. vtype.write(buf, arc.acceptval)
  961. here = buf.tell()
  962. thissize = here - arcstart
  963. arcstart = here
  964. if fixedsize == -1:
  965. fixedsize = thissize
  966. elif fixedsize > 0 and thissize != fixedsize:
  967. fixedsize = 0
  968. if fixedsize > 0:
  969. # Write a fake arc containing the fixed size and number of arcs
  970. dbfile.write_byte(255) # FIXED_SIZE
  971. dbfile.write_int(fixedsize)
  972. dbfile.write_int(numarcs)
  973. self.fixed_count += 1
  974. dbfile.write(buf.file.getvalue())
  975. return nodestart
  976. # Graph reader
  977. class BaseGraphReader(object):
  978. def cursor(self, rootname=None):
  979. return Cursor(self, self.root(rootname))
  980. def has_root(self, rootname):
  981. raise NotImplementedError
  982. def root(self, rootname=None):
  983. raise NotImplementedError
  984. # Low level methods
  985. def arc_at(self, address, arc):
  986. raise NotImplementedError
  987. def iter_arcs(self, address, arc=None):
  988. raise NotImplementedError
  989. def find_arc(self, address, label, arc=None):
  990. arc = arc or Arc()
  991. for arc in self.iter_arcs(address, arc):
  992. thislabel = arc.label
  993. if thislabel == label:
  994. return arc
  995. elif thislabel > label:
  996. return None
  997. # Convenience methods
  998. def list_arcs(self, address):
  999. return list(copy.copy(arc) for arc in self.iter_arcs(address))
  1000. def arc_dict(self, address):
  1001. return dict((arc.label, copy.copy(arc))
  1002. for arc in self.iter_arcs(address))
  1003. def find_path(self, path, arc=None, address=None):
  1004. path = to_labels(path)
  1005. if arc:
  1006. address = arc.target
  1007. else:
  1008. arc = Arc()
  1009. if address is None:
  1010. address = self._root
  1011. for label in path:
  1012. if address is None:
  1013. return None
  1014. if not self.find_arc(address, label, arc):
  1015. return None
  1016. address = arc.target
  1017. return arc
  1018. class GraphReader(BaseGraphReader):
  1019. def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
  1020. self.dbfile = dbfile
  1021. self.vtype = vtype
  1022. self.filebase = filebase
  1023. dbfile.seek(filebase)
  1024. magic = dbfile.read(4)
  1025. if magic != b("GRPH"):
  1026. raise FileVersionError
  1027. self.version = dbfile.read_int()
  1028. dbfile.seek(dbfile.read_uint())
  1029. self.roots = dbfile.read_pickle()
  1030. self._root = None
  1031. if rootname is None and len(self.roots) == 1:
  1032. # If there's only one root, just use it. Have to wrap a list around
  1033. # the keys() method here because of Python 3.
  1034. rootname = list(self.roots.keys())[0]
  1035. if rootname is not None:
  1036. self._root = self.root(rootname)
  1037. def close(self):
  1038. self.dbfile.close()
  1039. # Overrides
  1040. def has_root(self, rootname):
  1041. return rootname in self.roots
  1042. def root(self, rootname=None):
  1043. if rootname is None:
  1044. return self._root
  1045. else:
  1046. return self.roots[rootname]
  1047. def default_root(self):
  1048. return self._root
  1049. def arc_at(self, address, arc=None):
  1050. arc = arc or Arc()
  1051. self.dbfile.seek(address)
  1052. return self._read_arc(arc)
  1053. def iter_arcs(self, address, arc=None):
  1054. arc = arc or Arc()
  1055. _read_arc = self._read_arc
  1056. self.dbfile.seek(address)
  1057. while True:
  1058. _read_arc(arc)
  1059. yield arc
  1060. if arc.lastarc:
  1061. break
  1062. def find_arc(self, address, label, arc=None):
  1063. arc = arc or Arc()
  1064. dbfile = self.dbfile
  1065. dbfile.seek(address)
  1066. # If records are fixed size, we can do a binary search
  1067. finfo = self._read_fixed_info()
  1068. if finfo:
  1069. size, count = finfo
  1070. address = dbfile.tell()
  1071. if count > 2:
  1072. return self._binary_search(address, size, count, label, arc)
  1073. # If records aren't fixed size, fall back to the parent's linear
  1074. # search method
  1075. return BaseGraphReader.find_arc(self, address, label, arc)
  1076. # Implementations
  1077. def _read_arc(self, toarc=None):
  1078. toarc = toarc or Arc()
  1079. dbfile = self.dbfile
  1080. flags = dbfile.read_byte()
  1081. if flags == 255:
  1082. # This is a fake arc containing fixed size information; skip it
  1083. # and read the next arc
  1084. dbfile.seek(_INT_SIZE * 2, 1)
  1085. flags = dbfile.read_byte()
  1086. toarc.label = self._read_label(flags)
  1087. return self._read_arc_data(flags, toarc)
  1088. def _read_label(self, flags):
  1089. dbfile = self.dbfile
  1090. if flags & MULTIBYTE_LABEL:
  1091. length = dbfile.read_varint()
  1092. else:
  1093. length = 1
  1094. label = dbfile.read(length)
  1095. return label
  1096. def _read_fixed_info(self):
  1097. dbfile = self.dbfile
  1098. flags = dbfile.read_byte()
  1099. if flags == 255:
  1100. size = dbfile.read_int()
  1101. count = dbfile.read_int()
  1102. return (size, count)
  1103. else:
  1104. return None
  1105. def _read_arc_data(self, flags, arc):
  1106. dbfile = self.dbfile
  1107. accept = arc.accept = bool(flags & ARC_ACCEPT)
  1108. arc.lastarc = flags & ARC_LAST
  1109. if flags & ARC_STOP:
  1110. arc.target = None
  1111. else:
  1112. arc.target = dbfile.read_uint()
  1113. if flags & ARC_HAS_VAL:
  1114. arc.value = self.vtype.read(dbfile)
  1115. else:
  1116. arc.value = None
  1117. if accept and flags & ARC_HAS_ACCEPT_VAL:
  1118. arc.acceptval = self.vtype.read(dbfile)
  1119. arc.endpos = dbfile.tell()
  1120. return arc
  1121. def _binary_search(self, address, size, count, label, arc):
  1122. dbfile = self.dbfile
  1123. _read_label = self._read_label
  1124. lo = 0
  1125. hi = count
  1126. while lo < hi:
  1127. mid = (lo + hi) // 2
  1128. midaddr = address + mid * size
  1129. dbfile.seek(midaddr)
  1130. flags = dbfile.read_byte()
  1131. midlabel = self._read_label(flags)
  1132. if midlabel == label:
  1133. arc.label = midlabel
  1134. return self._read_arc_data(flags, arc)
  1135. elif midlabel < label:
  1136. lo = mid + 1
  1137. else:
  1138. hi = mid
  1139. if lo == count:
  1140. return None
  1141. def to_labels(key):
  1142. """Takes a string and returns a list of bytestrings, suitable for use as
  1143. a key or path in an FSA/FST graph.
  1144. """
  1145. # Convert to tuples of bytestrings (must be tuples so they can be hashed)
  1146. keytype = type(key)
  1147. # I hate the Python 3 bytes object so friggin much
  1148. if keytype is tuple or keytype is list:
  1149. if not all(isinstance(e, bytes_type) for e in key):
  1150. raise TypeError("%r contains a non-bytestring")
  1151. if keytype is list:
  1152. key = tuple(key)
  1153. elif isinstance(key, bytes_type):
  1154. key = tuple(key[i:i + 1] for i in xrange(len(key)))
  1155. elif isinstance(key, text_type):
  1156. key = tuple(utf8encode(key[i:i + 1])[0] for i in xrange(len(key)))
  1157. else:
  1158. raise TypeError("Don't know how to convert %r" % key)
  1159. return key
  1160. # Within edit distance function
  1161. def within(graph, text, k=1, prefix=0, address=None):
  1162. """Yields a series of keys in the given graph within ``k`` edit distance of
  1163. ``text``. If ``prefix`` is greater than 0, all keys must match the first
  1164. ``prefix`` characters of ``text``.
  1165. """
  1166. text = to_labels(text)
  1167. if address is None:
  1168. address = graph._root
  1169. sofar = emptybytes
  1170. accept = False
  1171. if prefix:
  1172. prefixchars = text[:prefix]
  1173. arc = graph.find_path(prefixchars, address=address)
  1174. if arc is None:
  1175. return
  1176. sofar = emptybytes.join(prefixchars)
  1177. address = arc.target
  1178. accept = arc.accept
  1179. stack = [(address, k, prefix, sofar, accept)]
  1180. seen = set()
  1181. while stack:
  1182. state = stack.pop()
  1183. # Have we already tried this state?
  1184. if state in seen:
  1185. continue
  1186. seen.add(state)
  1187. address, k, i, sofar, accept = state
  1188. # If we're at the end of the text (or deleting enough chars would get
  1189. # us to the end and still within K), and we're in the accept state,
  1190. # yield the current result
  1191. if (len(text) - i <= k) and accept:
  1192. yield utf8decode(sofar)[0]
  1193. # If we're in the stop state, give up
  1194. if address is None:
  1195. continue
  1196. # Exact match
  1197. if i < len(text):
  1198. arc = graph.find_arc(address, text[i])
  1199. if arc:
  1200. stack.append((arc.target, k, i + 1, sofar + text[i],
  1201. arc.accept))
  1202. # If K is already 0, can't do any more edits
  1203. if k < 1:
  1204. continue
  1205. k -= 1
  1206. arcs = graph.arc_dict(address)
  1207. # Insertions
  1208. stack.extend((arc.target, k, i, sofar + char, arc.accept)
  1209. for char, arc in iteritems(arcs))
  1210. # Deletion, replacement, and transpo only work before the end
  1211. if i >= len(text):
  1212. continue
  1213. char = text[i]
  1214. # Deletion
  1215. stack.append((address, k, i + 1, sofar, False))
  1216. # Replacement
  1217. for char2, arc in iteritems(arcs):
  1218. if char2 != char:
  1219. stack.append((arc.target, k, i + 1, sofar + char2, arc.accept))
  1220. # Transposition
  1221. if i < len(text) - 1:
  1222. char2 = text[i + 1]
  1223. if char != char2 and char2 in arcs:
  1224. # Find arc from next char to this char
  1225. target = arcs[char2].target
  1226. if target:
  1227. arc = graph.find_arc(target, char)
  1228. if arc:
  1229. stack.append((arc.target, k, i + 2,
  1230. sofar + char2 + char, arc.accept))
  1231. # Utility functions
  1232. def dump_graph(graph, address=None, tab=0, out=None):
  1233. if address is None:
  1234. address = graph._root
  1235. if out is None:
  1236. out = sys.stdout
  1237. here = "%06d" % address
  1238. for i, arc in enumerate(graph.list_arcs(address)):
  1239. if i == 0:
  1240. out.write(here)
  1241. else:
  1242. out.write(" " * 6)
  1243. out.write(" " * tab)
  1244. out.write("%r %r %s %r\n"
  1245. % (arc.label, arc.target, arc.accept, arc.value))
  1246. if arc.target is not None:
  1247. dump_graph(graph, arc.target, tab + 1, out=out)