PageRenderTime 59ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 1ms

/src/whoosh/fst.py

https://bitbucket.org/rayleyva/whoosh
Python | 1550 lines | 1469 code | 25 blank | 56 comment | 22 complexity | ec47613e30de35ccb738d6548d55ebc7 MD5 | raw file
Possible License(s): Apache-2.0
  1. # Copyright 2009 Matt Chaput. All rights reserved.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions are met:
  5. #
  6. # 1. Redistributions of source code must retain the above copyright notice,
  7. # this list of conditions and the following disclaimer.
  8. #
  9. # 2. Redistributions in binary form must reproduce the above copyright
  10. # notice, this list of conditions and the following disclaimer in the
  11. # documentation and/or other materials provided with the distribution.
  12. #
  13. # THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
  14. # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  15. # MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
  16. # EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
  17. # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  18. # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
  19. # OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  20. # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  21. # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
  22. # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  23. #
  24. # The views and conclusions contained in the software and documentation are
  25. # those of the authors and should not be interpreted as representing official
  26. # policies, either expressed or implied, of Matt Chaput.
  27. """
  28. This module implements an FST/FSA writer and reader. An FST (Finite State
  29. Transducer) stores a directed acyclic graph with values associated with the
  30. leaves. Common elements of the values are pushed inside the tree. An FST that
  31. does not store values is a regular FSA.
  32. The format of the leaf values is pluggable using subclasses of the Values
  33. class.
  34. Whoosh uses these structures to store a directed acyclic word graph (DAWG) for
  35. use in (at least) spell checking.
  36. """
  37. import sys, copy
  38. from array import array
  39. from hashlib import sha1 # @UnresolvedImport
  40. from whoosh.compat import b, u, BytesIO
  41. from whoosh.compat import xrange, iteritems, iterkeys, izip, array_tobytes
  42. from whoosh.compat import bytes_type, text_type
  43. from whoosh.filedb.structfile import StructFile
  44. from whoosh.system import _INT_SIZE
  45. from whoosh.system import pack_byte, pack_int, pack_uint, pack_long
  46. from whoosh.system import emptybytes
  47. from whoosh.util.text import utf8encode, utf8decode
  48. from whoosh.util.varints import varint
  49. class FileVersionError(Exception):
  50. pass
  51. class InactiveCursor(Exception):
  52. pass
  53. ARC_LAST = 1
  54. ARC_ACCEPT = 2
  55. ARC_STOP = 4
  56. ARC_HAS_VAL = 8
  57. ARC_HAS_ACCEPT_VAL = 16
  58. MULTIBYTE_LABEL = 32
  59. # FST Value types
  60. class Values(object):
  61. """Base for classes the describe how to encode and decode FST values.
  62. """
  63. @staticmethod
  64. def is_valid(v):
  65. """Returns True if v is a valid object that can be stored by this
  66. class.
  67. """
  68. raise NotImplementedError
  69. @staticmethod
  70. def common(v1, v2):
  71. """Returns the "common" part of the two values, for whatever "common"
  72. means for this class. For example, a string implementation would return
  73. the common shared prefix, for an int implementation it would return
  74. the minimum of the two numbers.
  75. If there is no common part, this method should return None.
  76. """
  77. raise NotImplementedError
  78. @staticmethod
  79. def add(prefix, v):
  80. """Adds the given prefix (the result of a call to common()) to the
  81. given value.
  82. """
  83. raise NotImplementedError
  84. @staticmethod
  85. def subtract(v, prefix):
  86. """Subtracts the "common" part (the prefix) from the given value.
  87. """
  88. raise NotImplementedError
  89. @staticmethod
  90. def write(dbfile, v):
  91. """Writes value v to a file.
  92. """
  93. raise NotImplementedError
  94. @staticmethod
  95. def read(dbfile):
  96. """Reads a value from the given file.
  97. """
  98. raise NotImplementedError
  99. @classmethod
  100. def skip(cls, dbfile):
  101. """Skips over a value in the given file.
  102. """
  103. cls.read(dbfile)
  104. @staticmethod
  105. def to_bytes(v):
  106. """Returns a str (Python 2.x) or bytes (Python 3) representation of
  107. the given value. This is used for calculating node digests, so it
  108. should be unique but fast to calculate, and does not have to be
  109. parseable.
  110. """
  111. raise NotImplementedError
  112. @staticmethod
  113. def merge(v1, v2):
  114. raise NotImplementedError
  115. class IntValues(Values):
  116. """Stores integer values in an FST.
  117. """
  118. @staticmethod
  119. def is_valid(v):
  120. return isinstance(v, int) and v >= 0
  121. @staticmethod
  122. def common(v1, v2):
  123. if v1 is None or v2 is None:
  124. return None
  125. if v1 == v2:
  126. return v1
  127. return min(v1, v2)
  128. @staticmethod
  129. def add(base, v):
  130. if base is None:
  131. return v
  132. if v is None:
  133. return base
  134. return base + v
  135. @staticmethod
  136. def subtract(v, base):
  137. if v is None:
  138. return None
  139. if base is None:
  140. return v
  141. return v - base
  142. @staticmethod
  143. def write(dbfile, v):
  144. dbfile.write_uint(v)
  145. @staticmethod
  146. def read(dbfile):
  147. return dbfile.read_uint()
  148. @staticmethod
  149. def skip(dbfile):
  150. dbfile.seek(_INT_SIZE, 1)
  151. @staticmethod
  152. def to_bytes(v):
  153. return pack_int(v)
  154. class SequenceValues(Values):
  155. """Abstract base class for value types that store sequences.
  156. """
  157. @staticmethod
  158. def is_valid(v):
  159. return isinstance(self, (list, tuple))
  160. @staticmethod
  161. def common(v1, v2):
  162. if v1 is None or v2 is None:
  163. return None
  164. i = 0
  165. while i < len(v1) and i < len(v2):
  166. if v1[i] != v2[i]:
  167. break
  168. i += 1
  169. if i == 0:
  170. return None
  171. if i == len(v1):
  172. return v1
  173. if i == len(v2):
  174. return v2
  175. return v1[:i]
  176. @staticmethod
  177. def add(prefix, v):
  178. if prefix is None:
  179. return v
  180. if v is None:
  181. return prefix
  182. return prefix + v
  183. @staticmethod
  184. def subtract(v, prefix):
  185. if prefix is None:
  186. return v
  187. if v is None:
  188. return None
  189. if len(v) == len(prefix):
  190. return None
  191. if len(v) < len(prefix) or len(prefix) == 0:
  192. raise ValueError((v, prefix))
  193. return v[len(prefix):]
  194. @staticmethod
  195. def write(dbfile, v):
  196. dbfile.write_pickle(v)
  197. @staticmethod
  198. def read(dbfile):
  199. return dbfile.read_pickle()
  200. class BytesValues(SequenceValues):
  201. """Stores bytes objects (str in Python 2.x) in an FST.
  202. """
  203. @staticmethod
  204. def is_valid(v):
  205. return isinstance(v, bytes_type)
  206. @staticmethod
  207. def write(dbfile, v):
  208. dbfile.write_int(len(v))
  209. dbfile.write(v)
  210. @staticmethod
  211. def read(dbfile):
  212. length = dbfile.read_int()
  213. return dbfile.read(length)
  214. @staticmethod
  215. def skip(dbfile):
  216. length = dbfile.read_int()
  217. dbfile.seek(length, 1)
  218. @staticmethod
  219. def to_bytes(v):
  220. return v
  221. class ArrayValues(SequenceValues):
  222. """Stores array.array objects in an FST.
  223. """
  224. def __init__(self, typecode):
  225. self.typecode = typecode
  226. self.itemsize = array(self.typecode).itemsize
  227. def is_valid(self, v):
  228. return isinstance(v, array) and v.typecode == self.typecode
  229. @staticmethod
  230. def write(dbfile, v):
  231. dbfile.write(b(v.typecode))
  232. dbfile.write_int(len(v))
  233. dbfile.write_array(v)
  234. def read(self, dbfile):
  235. typecode = u(dbfile.read(1))
  236. length = dbfile.read_int()
  237. return dbfile.read_array(self.typecode, length)
  238. def skip(self, dbfile):
  239. length = dbfile.read_int()
  240. dbfile.seek(length * self.itemsize, 1)
  241. @staticmethod
  242. def to_bytes(v):
  243. return array_tobytes(v)
  244. class IntListValues(SequenceValues):
  245. """Stores lists of positive, increasing integers (that is, lists of
  246. integers where each number is >= 0 and each number is greater than or equal
  247. to the number that precedes it) in an FST.
  248. """
  249. @staticmethod
  250. def is_valid(v):
  251. if isinstance(v, (list, tuple)):
  252. if len(v) < 2:
  253. return True
  254. for i in xrange(1, len(v)):
  255. if not isinstance(v[i], int) or v[i] < v[i - 1]:
  256. return False
  257. return True
  258. return False
  259. @staticmethod
  260. def write(dbfile, v):
  261. base = 0
  262. dbfile.write_varint(len(v))
  263. for x in v:
  264. delta = x - base
  265. assert delta >= 0
  266. dbfile.write_varint(delta)
  267. base = x
  268. @staticmethod
  269. def read(dbfile):
  270. length = dbfile.read_varint()
  271. result = []
  272. if length > 0:
  273. base = 0
  274. for _ in xrange(length):
  275. base += dbfile.read_varint()
  276. result.append(base)
  277. return result
  278. @staticmethod
  279. def to_bytes(v):
  280. return b(repr(v))
  281. # Node-like interface wrappers
  282. class Node(object):
  283. """A slow but easier-to-use wrapper for FSA/DAWGs. Translates the low-level
  284. arc-based interface of GraphReader into Node objects with methods to follow
  285. edges.
  286. """
  287. def __init__(self, owner, address, accept=False):
  288. self.owner = owner
  289. self.address = address
  290. self._edges = None
  291. self.accept = accept
  292. def __iter__(self):
  293. if not self._edges:
  294. self._load()
  295. return iterkeys(self._edges)
  296. def __contains__(self, key):
  297. if self._edges is None:
  298. self._load()
  299. return key in self._edges
  300. def _load(self):
  301. owner = self.owner
  302. if self.address is None:
  303. d = {}
  304. else:
  305. d = dict((arc.label, Node(owner, arc.target, arc.accept))
  306. for arc in self.owner.iter_arcs(self.address))
  307. self._edges = d
  308. def keys(self):
  309. if self._edges is None:
  310. self._load()
  311. return self._edges.keys()
  312. def all_edges(self):
  313. if self._edges is None:
  314. self._load()
  315. return self._edges
  316. def edge(self, key):
  317. if self._edges is None:
  318. self._load()
  319. return self._edges[key]
  320. def flatten(self, sofar=emptybytes):
  321. if self.accept:
  322. yield sofar
  323. for key in sorted(self):
  324. node = self.edge(key)
  325. for result in node.flatten(sofar + key):
  326. yield result
  327. def flatten_strings(self):
  328. return (utf8decode(k)[0] for k in self.flatten())
  329. class ComboNode(Node):
  330. """Base class for nodes that blend the nodes of two different graphs.
  331. Concrete subclasses need to implement the ``edge()`` method and possibly
  332. override the ``accept`` property.
  333. """
  334. def __init__(self, a, b):
  335. self.a = a
  336. self.b = b
  337. def __repr__(self):
  338. return "<%s %r %r>" % (self.__class__.__name__, self.a, self.b)
  339. def __contains__(self, key):
  340. return key in self.a or key in self.b
  341. def __iter__(self):
  342. return iter(set(self.a) | set(self.b))
  343. @property
  344. def accept(self):
  345. return self.a.accept or self.b.accept
  346. class UnionNode(ComboNode):
  347. """Makes two graphs appear to be the union of the two graphs.
  348. """
  349. def edge(self, key):
  350. a = self.a
  351. b = self.b
  352. if key in a and key in b:
  353. return UnionNode(a.edge(key), b.edge(key))
  354. elif key in a:
  355. return a.edge(key)
  356. else:
  357. return b.edge(key)
  358. class IntersectionNode(ComboNode):
  359. """Makes two graphs appear to be the intersection of the two graphs.
  360. """
  361. def edge(self, key):
  362. a = self.a
  363. b = self.b
  364. if key in a and key in b:
  365. return IntersectionNode(a.edge(key), b.edge(key))
  366. # Cursor
  367. class BaseCursor(object):
  368. """Base class for a cursor-type object for navigating an FST/word graph,
  369. represented by a :class:`GraphReader` object.
  370. >>> cur = GraphReader(dawgfile).cursor()
  371. >>> for key in cur.follow():
  372. ... print(repr(key))
  373. The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.
  374. """
  375. def is_active(self):
  376. """Returns True if this cursor is still active, that is it has not
  377. read past the last arc in the graph.
  378. """
  379. raise NotImplementedError
  380. def label(self):
  381. """Returns the label bytes of the current arc.
  382. """
  383. raise NotImplementedError
  384. def prefix(self):
  385. """Returns a sequence of the label bytes for the path from the root
  386. to the current arc.
  387. """
  388. raise NotImplementedError
  389. def prefix_bytes(self):
  390. """Returns the label bytes for the path from the root to the current
  391. arc as a single joined bytes object.
  392. """
  393. return emptybytes.join(self.prefix())
  394. def prefix_string(self):
  395. """Returns the labels of the path from the root to the current arc as
  396. a decoded unicode string.
  397. """
  398. return utf8decode(self.prefix_bytes())[0]
  399. def peek_key(self):
  400. """Returns a sequence of label bytes representing the next closest
  401. key in the graph.
  402. """
  403. for label in self.prefix():
  404. yield label
  405. c = self.copy()
  406. while not c.stopped():
  407. c.follow()
  408. yield c.label()
  409. def peek_key_bytes(self):
  410. """Returns the next closest key in the graph as a single bytes object.
  411. """
  412. return emptybytes.join(self.peek_key())
  413. def peek_key_string(self):
  414. """Returns the next closest key in the graph as a decoded unicode
  415. string.
  416. """
  417. return utf8decode(self.peek_key_bytes())[0]
  418. def stopped(self):
  419. """Returns True if the current arc leads to a stop state.
  420. """
  421. raise NotImplementedError
  422. def value(self):
  423. """Returns the value at the current arc, if reading an FST.
  424. """
  425. raise NotImplementedError
  426. def accept(self):
  427. """Returns True if the current arc leads to an accept state (the end
  428. of a valid key).
  429. """
  430. raise NotImplementedError
  431. def at_last_arc(self):
  432. """Returns True if the current arc is the last outgoing arc from the
  433. previous node.
  434. """
  435. raise NotImplementedError
  436. def next_arc(self):
  437. """Moves to the next outgoing arc from the previous node.
  438. """
  439. raise NotImplementedError
  440. def follow(self):
  441. """Follows the current arc.
  442. """
  443. raise NotImplementedError
  444. def switch_to(self, label):
  445. """Switch to the sibling arc with the given label bytes.
  446. """
  447. _label = self.label
  448. _at_last_arc = self.at_last_arc
  449. _next_arc = self.next_arc
  450. while True:
  451. thislabel = _label()
  452. if thislabel == label:
  453. return True
  454. if thislabel > label or _at_last_arc():
  455. return False
  456. _next_arc()
  457. def skip_to(self, key):
  458. """Moves the cursor to the path represented by the given key bytes.
  459. """
  460. _accept = self.accept
  461. _prefix = self.prefix
  462. _next_arc = self.next_arc
  463. keylist = list(key)
  464. while True:
  465. if _accept():
  466. thiskey = list(_prefix())
  467. if keylist == thiskey:
  468. return True
  469. elif keylist > thiskey:
  470. return False
  471. _next_arc()
  472. def flatten(self):
  473. """Yields the keys in the graph, starting at the current position.
  474. """
  475. _is_active = self.is_active
  476. _accept = self.accept
  477. _stopped = self.stopped
  478. _follow = self.follow
  479. _next_arc = self.next_arc
  480. _prefix_bytes = self.prefix_bytes
  481. if not _is_active():
  482. raise InactiveCursor
  483. while _is_active():
  484. if _accept():
  485. yield _prefix_bytes()
  486. if not _stopped():
  487. _follow()
  488. continue
  489. _next_arc()
  490. def flatten_v(self):
  491. """Yields (key, value) tuples in an FST, starting at the current
  492. position.
  493. """
  494. for key in self.flatten():
  495. yield key, self.value()
  496. def flatten_strings(self):
  497. return (utf8decode(k)[0] for k in self.flatten())
  498. def find_path(self, path):
  499. """Follows the labels in the given path, starting at the current
  500. position.
  501. """
  502. path = to_labels(path)
  503. _switch_to = self.switch_to
  504. _follow = self.follow
  505. _stopped = self.stopped
  506. first = True
  507. for i, label in enumerate(path):
  508. if not first:
  509. _follow()
  510. if not _switch_to(label):
  511. return False
  512. if _stopped():
  513. if i < len(path) - 1:
  514. return False
  515. first = False
  516. return True
  517. class Cursor(BaseCursor):
  518. def __init__(self, graph, root=None, stack=None):
  519. self.graph = graph
  520. self.vtype = graph.vtype
  521. self.root = root if root is not None else graph.default_root()
  522. if stack:
  523. self.stack = stack
  524. else:
  525. self.reset()
  526. def _current_attr(self, name):
  527. stack = self.stack
  528. if not stack:
  529. raise InactiveCursor
  530. return getattr(stack[-1], name)
  531. def is_active(self):
  532. return bool(self.stack)
  533. def stopped(self):
  534. return self._current_attr("target") is None
  535. def accept(self):
  536. return self._current_attr("accept")
  537. def at_last_arc(self):
  538. return self._current_attr("lastarc")
  539. def label(self):
  540. return self._current_attr("label")
  541. def reset(self):
  542. self.stack = []
  543. self.sums = [None]
  544. self._push(self.graph.arc_at(self.root))
  545. def copy(self):
  546. return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))
  547. def prefix(self):
  548. stack = self.stack
  549. if not stack:
  550. raise InactiveCursor
  551. return (arc.label for arc in stack)
  552. # Override: more efficient implementation using graph methods directly
  553. def peek_key(self):
  554. if not self.stack:
  555. raise InactiveCursor
  556. for label in self.prefix():
  557. yield label
  558. arc = copy.copy(self.stack[-1])
  559. graph = self.graph
  560. while not arc.accept and arc.target is not None:
  561. graph.arc_at(arc.target, arc)
  562. yield arc.label
  563. def value(self):
  564. stack = self.stack
  565. if not stack:
  566. raise InactiveCursor
  567. vtype = self.vtype
  568. if not vtype:
  569. raise Exception("No value type")
  570. v = self.sums[-1]
  571. current = stack[-1]
  572. if current.value:
  573. v = vtype.add(v, current.value)
  574. if current.accept and current.acceptval is not None:
  575. v = vtype.add(v, current.acceptval)
  576. return v
  577. def next_arc(self):
  578. stack = self.stack
  579. if not stack:
  580. raise InactiveCursor
  581. while stack and stack[-1].lastarc:
  582. self.pop()
  583. if stack:
  584. current = stack[-1]
  585. self.graph.arc_at(current.endpos, current)
  586. return current
  587. def follow(self):
  588. address = self._current_attr("target")
  589. if address is None:
  590. raise Exception("Can't follow a stop arc")
  591. self._push(self.graph.arc_at(address))
  592. return self
  593. # Override: more efficient implementation manipulating the stack
  594. def skip_to(self, key):
  595. key = to_labels(key)
  596. stack = self.stack
  597. if not stack:
  598. raise InactiveCursor
  599. _follow = self.follow
  600. _next_arc = self.next_arc
  601. i = self._pop_to_prefix(key)
  602. while stack and i < len(key):
  603. curlabel = stack[-1].label
  604. keylabel = key[i]
  605. if curlabel == keylabel:
  606. _follow()
  607. i += 1
  608. elif curlabel > keylabel:
  609. return
  610. else:
  611. _next_arc()
  612. # Override: more efficient implementation using find_arc
  613. def switch_to(self, label):
  614. stack = self.stack
  615. if not stack:
  616. raise InactiveCursor
  617. current = stack[-1]
  618. if label == current.label:
  619. return True
  620. else:
  621. arc = self.graph.find_arc(current.endpos, label, current)
  622. return arc
  623. def _push(self, arc):
  624. if self.vtype and self.stack:
  625. sums = self.sums
  626. sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
  627. self.stack.append(arc)
  628. def pop(self):
  629. self.stack.pop()
  630. if self.vtype:
  631. self.sums.pop()
  632. def _pop_to_prefix(self, key):
  633. stack = self.stack
  634. if not stack:
  635. raise InactiveCursor
  636. i = 0
  637. maxpre = min(len(stack), len(key))
  638. while i < maxpre and key[i] == stack[i].label:
  639. i += 1
  640. if stack[i].label > key[i]:
  641. self.current = None
  642. return
  643. while len(stack) > i + 1:
  644. self.pop()
  645. self.next_arc()
  646. return i
  647. class UncompiledNode(object):
  648. # Represents an "in-memory" node used by the GraphWriter before it is
  649. # written to disk.
  650. compiled = False
  651. def __init__(self, owner):
  652. self.owner = owner
  653. self._digest = None
  654. self.clear()
  655. def clear(self):
  656. self.arcs = []
  657. self.value = None
  658. self.accept = False
  659. self.inputcount = 0
  660. def __repr__(self):
  661. return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
  662. def digest(self):
  663. if self._digest is None:
  664. d = sha1()
  665. vtype = self.owner.vtype
  666. for arc in self.arcs:
  667. d.update(arc.label)
  668. if arc.target:
  669. d.update(pack_long(arc.target))
  670. else:
  671. d.update(b("z"))
  672. if arc.value:
  673. d.update(vtype.to_bytes(arc.value))
  674. if arc.accept:
  675. d.update(b("T"))
  676. self._digest = d.digest()
  677. return self._digest
  678. def edges(self):
  679. return self.arcs
  680. def last_value(self, label):
  681. assert self.arcs[-1].label == label
  682. return self.arcs[-1].value
  683. def add_arc(self, label, target):
  684. self.arcs.append(Arc(label, target))
  685. def replace_last(self, label, target, accept, acceptval=None):
  686. arc = self.arcs[-1]
  687. assert arc.label == label, "%r != %r" % (arc.label, label)
  688. arc.target = target
  689. arc.accept = accept
  690. arc.acceptval = acceptval
  691. def delete_last(self, label, target):
  692. arc = self.arcs.pop()
  693. assert arc.label == label
  694. assert arc.target == target
  695. def set_last_value(self, label, value):
  696. arc = self.arcs[-1]
  697. assert arc.label == label, "%r->%r" % (arc.label, label)
  698. arc.value = value
  699. def prepend_value(self, prefix):
  700. add = self.owner.vtype.add
  701. for arc in self.arcs:
  702. arc.value = add(prefix, arc.value)
  703. if self.accept:
  704. self.value = add(prefix, self.value)
  705. class Arc(object):
  706. """
  707. Represents a directed arc between two nodes in an FSA/FST graph.
  708. The ``lastarc`` attribute is True if this is the last outgoing arc from the
  709. previous node.
  710. """
  711. __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
  712. "endpos")
  713. def __init__(self, label=None, target=None, value=None, accept=False,
  714. acceptval=None):
  715. """
  716. :param label: The label bytes for this arc. For a word graph, this will
  717. be a character.
  718. :param target: The address of the node at the endpoint of this arc.
  719. :param value: The inner FST value at the endpoint of this arc.
  720. :param accept: Whether the endpoint of this arc is an accept state
  721. (e.g. the end of a valid word).
  722. :param acceptval: If the endpoint of this arc is an accept state, the
  723. final FST value for that accepted state.
  724. """
  725. self.label = label
  726. self.target = target
  727. self.value = value
  728. self.accept = accept
  729. self.lastarc = None
  730. self.acceptval = acceptval
  731. self.endpos = None
  732. def __repr__(self):
  733. return "<%r-%s %s%s>" % (self.label, self.target,
  734. "." if self.accept else "",
  735. (" %r" % self.value) if self.value else "")
  736. def __eq__(self, other):
  737. if (isinstance(other, self.__class__) and self.accept == other.accept
  738. and self.lastarc == other.lastarc and self.target == other.target
  739. and self.value == other.value and self.label == other.label):
  740. return True
  741. return False
  742. # Graph writer
  743. class GraphWriter(object):
  744. """Writes an FSA/FST graph to disk.
  745. Call ``insert(key)`` to insert keys into the graph. You must
  746. insert keys in sorted order. Call ``close()`` to finish the graph and close
  747. the file.
  748. >>> gw = GraphWriter(my_file)
  749. >>> gw.insert("alfa")
  750. >>> gw.insert("bravo")
  751. >>> gw.insert("charlie")
  752. >>> gw.close()
  753. The graph writer can write separate graphs for multiple fields. Use
  754. ``start_field(name)`` and ``finish_field()`` to separate fields.
  755. >>> gw = GraphWriter(my_file)
  756. >>> gw.start_field("content")
  757. >>> gw.insert("alfalfa")
  758. >>> gw.insert("apple")
  759. >>> gw.finish_field()
  760. >>> gw.start_field("title")
  761. >>> gw.insert("artichoke")
  762. >>> gw.finish_field()
  763. >>> gw.close()
  764. """
  765. version = 1
  766. def __init__(self, dbfile, vtype=None, merge=None):
  767. """
  768. :param dbfile: the file to write to.
  769. :param vtype: a :class:`Values` class to use for storing values. This
  770. is only necessary if you will be storing values for the keys.
  771. :param merge: a function that takes two values and returns a single
  772. value. This is called if you insert two identical keys with values.
  773. """
  774. self.dbfile = dbfile
  775. self.vtype = vtype
  776. self.merge = merge
  777. self.fieldroots = {}
  778. self.arc_count = 0
  779. self.node_count = 0
  780. self.fixed_count = 0
  781. dbfile.write(b("GRPH"))
  782. dbfile.write_int(self.version)
  783. dbfile.write_uint(0)
  784. self._infield = False
  785. def start_field(self, fieldname):
  786. """Starts a new graph for the given field.
  787. """
  788. if not fieldname:
  789. raise ValueError("Field name cannot be equivalent to False")
  790. if self._infield:
  791. self.finish_field()
  792. self.fieldname = fieldname
  793. self.seen = {}
  794. self.nodes = [UncompiledNode(self)]
  795. self.lastkey = ''
  796. self._inserted = False
  797. self._infield = True
  798. def finish_field(self):
  799. """Finishes the graph for the current field.
  800. """
  801. if not self._infield:
  802. raise Exception("Called finish_field before start_field")
  803. self._infield = False
  804. if self._inserted:
  805. self.fieldroots[self.fieldname] = self._finish()
  806. self.fieldname = None
  807. def close(self):
  808. """Finishes the current graph and closes the underlying file.
  809. """
  810. if self.fieldname is not None:
  811. self.finish_field()
  812. dbfile = self.dbfile
  813. here = dbfile.tell()
  814. dbfile.write_pickle(self.fieldroots)
  815. dbfile.flush()
  816. dbfile.seek(4 + _INT_SIZE) # Seek past magic and version number
  817. dbfile.write_uint(here)
  818. dbfile.close()
  819. def insert(self, key, value=None):
  820. """Inserts the given key into the graph.
  821. :param key: a sequence of bytes objects, a bytes object, or a string.
  822. :param value: an optional value to encode in the graph along with the
  823. key. If the writer was not instantiated with a value type, passing
  824. a value here will raise an error.
  825. """
  826. if not self._infield:
  827. raise Exception("Inserted %r before starting a field" % key)
  828. self._inserted = True
  829. key = to_labels(key) # Python 3 sucks
  830. vtype = self.vtype
  831. lastkey = self.lastkey
  832. nodes = self.nodes
  833. if len(key) < 1:
  834. raise KeyError("Can't store a null key %r" % (key,))
  835. if lastkey and lastkey > key:
  836. raise KeyError("Keys out of order %r..%r" % (lastkey, key))
  837. # Find the common prefix shared by this key and the previous one
  838. prefixlen = 0
  839. for i in xrange(min(len(lastkey), len(key))):
  840. if lastkey[i] != key[i]:
  841. break
  842. prefixlen += 1
  843. # Compile the nodes after the prefix, since they're not shared
  844. self._freeze_tail(prefixlen + 1)
  845. # Create new nodes for the parts of this key after the shared prefix
  846. for char in key[prefixlen:]:
  847. node = UncompiledNode(self)
  848. # Create an arc to this node on the previous node
  849. nodes[-1].add_arc(char, node)
  850. nodes.append(node)
  851. # Mark the last node as an accept state
  852. lastnode = nodes[-1]
  853. lastnode.accept = True
  854. if vtype:
  855. if value is not None and not vtype.is_valid(value):
  856. raise ValueError("%r is not valid for %s" % (value, vtype))
  857. # Push value commonalities through the tree
  858. common = None
  859. for i in xrange(1, prefixlen + 1):
  860. node = nodes[i]
  861. parent = nodes[i - 1]
  862. lastvalue = parent.last_value(key[i - 1])
  863. if lastvalue is not None:
  864. common = vtype.common(value, lastvalue)
  865. suffix = vtype.subtract(lastvalue, common)
  866. parent.set_last_value(key[i - 1], common)
  867. node.prepend_value(suffix)
  868. else:
  869. common = suffix = None
  870. value = vtype.subtract(value, common)
  871. if key == lastkey:
  872. # If this key is a duplicate, merge its value with the value of
  873. # the previous (same) key
  874. lastnode.value = self.merge(lastnode.value, value)
  875. else:
  876. nodes[prefixlen].set_last_value(key[prefixlen], value)
  877. elif value:
  878. raise Exception("Value %r but no value type" % value)
  879. self.lastkey = key
  880. def _freeze_tail(self, prefixlen):
  881. nodes = self.nodes
  882. lastkey = self.lastkey
  883. downto = max(1, prefixlen)
  884. while len(nodes) > downto:
  885. node = nodes.pop()
  886. parent = nodes[-1]
  887. inlabel = lastkey[len(nodes) - 1]
  888. self._compile_targets(node)
  889. accept = node.accept or len(node.arcs) == 0
  890. address = self._compile_node(node)
  891. parent.replace_last(inlabel, address, accept, node.value)
  892. def _finish(self):
  893. nodes = self.nodes
  894. root = nodes[0]
  895. # Minimize nodes in the last word's suffix
  896. self._freeze_tail(0)
  897. # Compile remaining targets
  898. self._compile_targets(root)
  899. return self._compile_node(root)
  900. def _compile_targets(self, node):
  901. for arc in node.arcs:
  902. if isinstance(arc.target, UncompiledNode):
  903. n = arc.target
  904. if len(n.arcs) == 0:
  905. arc.accept = n.accept = True
  906. arc.target = self._compile_node(n)
  907. def _compile_node(self, uncnode):
  908. seen = self.seen
  909. if len(uncnode.arcs) == 0:
  910. # Leaf node
  911. address = self._write_node(uncnode)
  912. else:
  913. d = uncnode.digest()
  914. address = seen.get(d)
  915. if address is None:
  916. address = self._write_node(uncnode)
  917. seen[d] = address
  918. return address
  919. def _write_node(self, uncnode):
  920. vtype = self.vtype
  921. dbfile = self.dbfile
  922. arcs = uncnode.arcs
  923. numarcs = len(arcs)
  924. if not numarcs:
  925. if uncnode.accept:
  926. return None
  927. else:
  928. # What does it mean for an arc to stop but not be accepted?
  929. raise Exception
  930. self.node_count += 1
  931. buf = StructFile(BytesIO())
  932. nodestart = dbfile.tell()
  933. #self.count += 1
  934. #self.arccount += numarcs
  935. fixedsize = -1
  936. arcstart = buf.tell()
  937. for i, arc in enumerate(arcs):
  938. self.arc_count += 1
  939. target = arc.target
  940. label = arc.label
  941. flags = 0
  942. if len(label) > 1:
  943. flags += MULTIBYTE_LABEL
  944. if i == numarcs - 1:
  945. flags += ARC_LAST
  946. if arc.accept:
  947. flags += ARC_ACCEPT
  948. if target is None:
  949. flags += ARC_STOP
  950. if arc.value is not None:
  951. flags += ARC_HAS_VAL
  952. if arc.acceptval is not None:
  953. flags += ARC_HAS_ACCEPT_VAL
  954. buf.write(pack_byte(flags))
  955. if len(label) > 1:
  956. buf.write(varint(len(label)))
  957. buf.write(label)
  958. if target is not None:
  959. buf.write(pack_uint(target))
  960. if arc.value is not None:
  961. vtype.write(buf, arc.value)
  962. if arc.acceptval is not None:
  963. vtype.write(buf, arc.acceptval)
  964. here = buf.tell()
  965. thissize = here - arcstart
  966. arcstart = here
  967. if fixedsize == -1:
  968. fixedsize = thissize
  969. elif fixedsize > 0 and thissize != fixedsize:
  970. fixedsize = 0
  971. if fixedsize > 0:
  972. # Write a fake arc containing the fixed size and number of arcs
  973. dbfile.write_byte(255) # FIXED_SIZE
  974. dbfile.write_int(fixedsize)
  975. dbfile.write_int(numarcs)
  976. self.fixed_count += 1
  977. dbfile.write(buf.file.getvalue())
  978. return nodestart
  979. # Graph reader
  980. class BaseGraphReader(object):
  981. def cursor(self, rootname=None):
  982. return Cursor(self, self.root(rootname))
  983. def has_root(self, rootname):
  984. raise NotImplementedError
  985. def root(self, rootname=None):
  986. raise NotImplementedError
  987. # Low level methods
  988. def arc_at(self, address, arc):
  989. raise NotImplementedError
  990. def iter_arcs(self, address, arc=None):
  991. raise NotImplementedError
  992. def find_arc(self, address, label, arc=None):
  993. arc = arc or Arc()
  994. for arc in self.iter_arcs(address, arc):
  995. thislabel = arc.label
  996. if thislabel == label:
  997. return arc
  998. elif thislabel > label:
  999. return None
  1000. # Convenience methods
  1001. def list_arcs(self, address):
  1002. return list(copy.copy(arc) for arc in self.iter_arcs(address))
  1003. def arc_dict(self, address):
  1004. return dict((arc.label, copy.copy(arc))
  1005. for arc in self.iter_arcs(address))
  1006. def find_path(self, path, arc=None, address=None):
  1007. path = to_labels(path)
  1008. if arc:
  1009. address = arc.target
  1010. else:
  1011. arc = Arc()
  1012. if address is None:
  1013. address = self._root
  1014. for label in path:
  1015. if address is None:
  1016. return None
  1017. if not self.find_arc(address, label, arc):
  1018. return None
  1019. address = arc.target
  1020. return arc
  1021. class GraphReader(BaseGraphReader):
  1022. def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
  1023. self.dbfile = dbfile
  1024. self.vtype = vtype
  1025. self.filebase = filebase
  1026. dbfile.seek(filebase)
  1027. magic = dbfile.read(4)
  1028. if magic != b("GRPH"):
  1029. raise FileVersionError
  1030. self.version = dbfile.read_int()
  1031. dbfile.seek(dbfile.read_uint())
  1032. self.roots = dbfile.read_pickle()
  1033. self._root = None
  1034. if rootname is None and len(self.roots) == 1:
  1035. # If there's only one root, just use it. Have to wrap a list around
  1036. # the keys() method here because of Python 3.
  1037. rootname = list(self.roots.keys())[0]
  1038. if rootname is not None:
  1039. self._root = self.root(rootname)
  1040. def close(self):
  1041. self.dbfile.close()
  1042. # Overrides
  1043. def has_root(self, rootname):
  1044. return rootname in self.roots
  1045. def root(self, rootname=None):
  1046. if rootname is None:
  1047. return self._root
  1048. else:
  1049. return self.roots[rootname]
  1050. def default_root(self):
  1051. return self._root
  1052. def arc_at(self, address, arc=None):
  1053. arc = arc or Arc()
  1054. self.dbfile.seek(address)
  1055. return self._read_arc(arc)
  1056. def iter_arcs(self, address, arc=None):
  1057. arc = arc or Arc()
  1058. _read_arc = self._read_arc
  1059. self.dbfile.seek(address)
  1060. while True:
  1061. _read_arc(arc)
  1062. yield arc
  1063. if arc.lastarc:
  1064. break
  1065. def find_arc(self, address, label, arc=None):
  1066. arc = arc or Arc()
  1067. dbfile = self.dbfile
  1068. dbfile.seek(address)
  1069. # If records are fixed size, we can do a binary search
  1070. finfo = self._read_fixed_info()
  1071. if finfo:
  1072. size, count = finfo
  1073. address = dbfile.tell()
  1074. if count > 2:
  1075. return self._binary_search(address, size, count, label, arc)
  1076. # If records aren't fixed size, fall back to the parent's linear
  1077. # search method
  1078. return BaseGraphReader.find_arc(self, address, label, arc)
  1079. # Implementations
  1080. def _read_arc(self, toarc=None):
  1081. toarc = toarc or Arc()
  1082. dbfile = self.dbfile
  1083. flags = dbfile.read_byte()
  1084. if flags == 255:
  1085. # This is a fake arc containing fixed size information; skip it
  1086. # and read the next arc
  1087. dbfile.seek(_INT_SIZE * 2, 1)
  1088. flags = dbfile.read_byte()
  1089. toarc.label = self._read_label(flags)
  1090. return self._read_arc_data(flags, toarc)
  1091. def _read_label(self, flags):
  1092. dbfile = self.dbfile
  1093. if flags & MULTIBYTE_LABEL:
  1094. length = dbfile.read_varint()
  1095. else:
  1096. length = 1
  1097. label = dbfile.read(length)
  1098. return label
  1099. def _read_fixed_info(self):
  1100. dbfile = self.dbfile
  1101. flags = dbfile.read_byte()
  1102. if flags == 255:
  1103. size = dbfile.read_int()
  1104. count = dbfile.read_int()
  1105. return (size, count)
  1106. else:
  1107. return None
  1108. def _read_arc_data(self, flags, arc):
  1109. dbfile = self.dbfile
  1110. accept = arc.accept = bool(flags & ARC_ACCEPT)
  1111. arc.lastarc = flags & ARC_LAST
  1112. if flags & ARC_STOP:
  1113. arc.target = None
  1114. else:
  1115. arc.target = dbfile.read_uint()
  1116. if flags & ARC_HAS_VAL:
  1117. arc.value = self.vtype.read(dbfile)
  1118. else:
  1119. arc.value = None
  1120. if accept and flags & ARC_HAS_ACCEPT_VAL:
  1121. arc.acceptval = self.vtype.read(dbfile)
  1122. arc.endpos = dbfile.tell()
  1123. return arc
  1124. def _binary_search(self, address, size, count, label, arc):
  1125. dbfile = self.dbfile
  1126. _read_label = self._read_label
  1127. lo = 0
  1128. hi = count
  1129. while lo < hi:
  1130. mid = (lo + hi) // 2
  1131. midaddr = address + mid * size
  1132. dbfile.seek(midaddr)
  1133. flags = dbfile.read_byte()
  1134. midlabel = self._read_label(flags)
  1135. if midlabel == label:
  1136. arc.label = midlabel
  1137. return self._read_arc_data(flags, arc)
  1138. elif midlabel < label:
  1139. lo = mid + 1
  1140. else:
  1141. hi = mid
  1142. if lo == count:
  1143. return None
  1144. def to_labels(key):
  1145. """Takes a string and returns a list of bytestrings, suitable for use as
  1146. a key or path in an FSA/FST graph.
  1147. """
  1148. # Convert to tuples of bytestrings (must be tuples so they can be hashed)
  1149. keytype = type(key)
  1150. # I hate the Python 3 bytes object so friggin much
  1151. if keytype is tuple or keytype is list:
  1152. if not all(isinstance(e, bytes_type) for e in key):
  1153. raise TypeError("%r contains a non-bytestring" % key)
  1154. if keytype is list:
  1155. key = tuple(key)
  1156. elif isinstance(key, bytes_type):
  1157. key = tuple(key[i:i + 1] for i in xrange(len(key)))
  1158. elif isinstance(key, text_type):
  1159. key = tuple(utf8encode(key[i:i + 1])[0] for i in xrange(len(key)))
  1160. else:
  1161. raise TypeError("Don't know how to convert %r" % key)
  1162. return key
  1163. # Within edit distance function
  1164. def within(graph, text, k=1, prefix=0, address=None):
  1165. """Yields a series of keys in the given graph within ``k`` edit distance of
  1166. ``text``. If ``prefix`` is greater than 0, all keys must match the first
  1167. ``prefix`` characters of ``text``.
  1168. """
  1169. text = to_labels(text)
  1170. if address is None:
  1171. address = graph._root
  1172. sofar = emptybytes
  1173. accept = False
  1174. if prefix:
  1175. prefixchars = text[:prefix]
  1176. arc = graph.find_path(prefixchars, address=address)
  1177. if arc is None:
  1178. return
  1179. sofar = emptybytes.join(prefixchars)
  1180. address = arc.target
  1181. accept = arc.accept
  1182. stack = [(address, k, prefix, sofar, accept)]
  1183. seen = set()
  1184. while stack:
  1185. state = stack.pop()
  1186. # Have we already tried this state?
  1187. if state in seen:
  1188. continue
  1189. seen.add(state)
  1190. address, k, i, sofar, accept = state
  1191. # If we're at the end of the text (or deleting enough chars would get
  1192. # us to the end and still within K), and we're in the accept state,
  1193. # yield the current result
  1194. if (len(text) - i <= k) and accept:
  1195. yield utf8decode(sofar)[0]
  1196. # If we're in the stop state, give up
  1197. if address is None:
  1198. continue
  1199. # Exact match
  1200. if i < len(text):
  1201. arc = graph.find_arc(address, text[i])
  1202. if arc:
  1203. stack.append((arc.target, k, i + 1, sofar + text[i],
  1204. arc.accept))
  1205. # If K is already 0, can't do any more edits
  1206. if k < 1:
  1207. continue
  1208. k -= 1
  1209. arcs = graph.arc_dict(address)
  1210. # Insertions
  1211. stack.extend((arc.target, k, i, sofar + char, arc.accept)
  1212. for char, arc in iteritems(arcs))
  1213. # Deletion, replacement, and transpo only work before the end
  1214. if i >= len(text):
  1215. continue
  1216. char = text[i]
  1217. # Deletion
  1218. stack.append((address, k, i + 1, sofar, False))
  1219. # Replacement
  1220. for char2, arc in iteritems(arcs):
  1221. if char2 != char:
  1222. stack.append((arc.target, k, i + 1, sofar + char2, arc.accept))
  1223. # Transposition
  1224. if i < len(text) - 1:
  1225. char2 = text[i + 1]
  1226. if char != char2 and char2 in arcs:
  1227. # Find arc from next char to this char
  1228. target = arcs[char2].target
  1229. if target:
  1230. arc = graph.find_arc(target, char)
  1231. if arc:
  1232. stack.append((arc.target, k, i + 2,
  1233. sofar + char2 + char, arc.accept))
  1234. # Utility functions
  1235. def dump_graph(graph, address=None, tab=0, out=None):
  1236. if address is None:
  1237. address = graph._root
  1238. if out is None:
  1239. out = sys.stdout
  1240. here = "%06d" % address
  1241. for i, arc in enumerate(graph.list_arcs(address)):
  1242. if i == 0:
  1243. out.write(here)
  1244. else:
  1245. out.write(" " * 6)
  1246. out.write(" " * tab)
  1247. out.write("%r %r %s %r\n"
  1248. % (arc.label, arc.target, arc.accept, arc.value))
  1249. if arc.target is not None:
  1250. dump_graph(graph, arc.target, tab + 1, out=out)