
/nltk_contrib/nltk_contrib/tiger/query/result.py

http://nltk.googlecode.com/

# -*- coding: utf-8 -*-
# Copyright © 2007-2008 Stockholm TreeAligner Project
# Author: Torsten Marek <shlomme@gmx.net>
# Licensed under the GNU GPLv2
"""This module contains classes for building result sets for TIGERSearch queries.

The result builder classes evaluate a TIGERSearch query based on the internal
representation of the query.

The algorithm and the interfaces in this module are still subject to heavy change.
For more information, see the inline comments.
"""
import operator
from functools import partial
from itertools import count, izip
from collections import defaultdict

from nltk_contrib.tiger.index import IndexNodeId
from nltk_contrib.tiger.query.exceptions import MissingFeatureError
from nltk_contrib.tiger.utils.parallel import multiprocessing
from nltk_contrib.tiger.query.constraints import Direction
from nltk_contrib.tiger.query.nodesearcher import NodeSearcher, EqualPartitionsGraphFilter

__all__ = ["ResultBuilder", "ParallelResultBuilder"]

product = partial(reduce, operator.mul)


def named_cross_product(items):
    def _outer_product(depth, combination):
        varname, nodes = items[-depth]
        for node in nodes:
            combination[varname] = node
            if depth == 1:
                yield combination.copy()
            else:
                for res in _outer_product(depth - 1, combination):
                    yield res
    if items:
        return _outer_product(len(items), {})
    else:
        return iter([{}])
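
# A hedged usage sketch: `named_cross_product` takes `(name, nodes)` pairs and
# yields one fresh dict per combination of node assignments, e.g.
#   >>> list(named_cross_product([("a", [1, 2]), ("b", [3])]))
#   [{'a': 1, 'b': 3}, {'a': 2, 'b': 3}]
# An empty input yields a single empty combination.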


def partition_variables(variables, constraints):
    var_connections = dict(izip(variables, count()))
    for l, r in constraints:
        new_id = var_connections[l]
        old_id = var_connections[r]
        for name, value in var_connections.iteritems():
            if value == old_id:
                var_connections[name] = new_id
    sets = defaultdict(set)
    for name, value in var_connections.iteritems():
        sets[value].add(name)
    return sets.values()
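
# A hedged sketch: variables joined by a constraint land in the same group, so
# partition_variables("abc", [("a", "b")]) returns the two groups {"a", "b"}
# and {"c"} (as a list of sets, in no guaranteed order).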


class ConstraintChecker(object):
    @classmethod
    def _nodevar_idx_combinations(cls, ordered_node_vars):
        return [(upper_key, lower_key)
                for lower_key in xrange(1, len(ordered_node_vars))
                for upper_key in xrange(lower_key)]

    @classmethod
    def _get_node_variables(cls, constraints):
        return set(var for var_pair in constraints for var in var_pair)

    @classmethod
    def prepare(cls, constraints, sizes={}):
        constraints = dict(constraints)
        set_weight = sum(sizes.values()) + 1
        ordered_node_vars = sorted(
            cls._get_node_variables(constraints),
            key=lambda k: set_weight if k.is_set else sizes.get(k, 0))
        ordered_constraints = []
        for (upper_idx, lower_idx) in cls._nodevar_idx_combinations(ordered_node_vars):
            var_pair = (ordered_node_vars[upper_idx], ordered_node_vars[lower_idx])
            fail_after_success = False
            if var_pair in constraints:
                constraint = constraints[var_pair].check
                direction = constraints[var_pair].get_singlematch_direction()
                if direction is Direction.BOTH or direction is Direction.LEFT_TO_RIGHT:
                    fail_after_success = True
                ordered_constraints.append(
                    (var_pair[0], var_pair[1], constraint, False, fail_after_success))
            elif var_pair[::-1] in constraints:
                constraint = constraints[var_pair[::-1]].check
                direction = constraints[var_pair[::-1]].get_singlematch_direction()
                if direction is Direction.BOTH or direction is Direction.RIGHT_TO_LEFT:
                    fail_after_success = True
                ordered_constraints.append(
                    (var_pair[0], var_pair[1], constraint, True, fail_after_success))
        return partial(ConstraintChecker, ordered_constraints)
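
    # Each entry of `ordered_constraints` is a 5-tuple
    # `(left_var, right_var, check, exchange, fail_after_success)`: `exchange`
    # marks constraints that were stored for the reversed variable pair, and
    # `fail_after_success` lets `prefilter` stop scanning right nodes once a
    # left node has matched, when the constraint direction guarantees at most
    # one match.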

    def __init__(self, constraints, nodes, query_context):
        self.ordered_constraints = constraints
        self.nodes = nodes
        self.ok = set()
        self.has_results = self.prefilter(query_context)

    def prefilter(self, query_context):
        for (left_var, right_var, constraint, exchange, fail_after_success) in self.ordered_constraints:
            l_success = set()
            r_success = set()
            for left in self.nodes[left_var]:
                ldata = query_context.get_node(left)
                for right in self.nodes[right_var]:
                    rdata = query_context.get_node(right)
                    if exchange:
                        larg, rarg = rdata, ldata
                    else:
                        larg, rarg = ldata, rdata
                    query_context.constraint_checks += 1
                    if constraint(larg, rarg, query_context):
                        self.ok.add((left, right))
                        if not right_var.is_set:
                            l_success.add(left)
                            r_success.add(right)
                            if fail_after_success:
                                break
                    elif right_var.is_set:
                        # A set variable must match all right nodes; one
                        # failure rules out this left node.
                        break
                else:
                    # The inner loop was not broken: for a set variable, this
                    # means the constraint held against every right node.
                    if right_var.is_set:
                        l_success.add(left)
                        r_success.update(self.nodes[right_var])
            if not l_success:
                return False
            if not r_success and not right_var.is_set:
                return False
            self.nodes[left_var] = list(l_success)
            self.nodes[right_var] = list(r_success)
        return True

    def _nodeids(self, query_result):
        for node_var in query_result:
            query_result[node_var] = IndexNodeId.from_int(query_result[node_var])
        return query_result

    def extract(self):
        """Creates the result set.

        The function currently uses a brute-force approach. Please see the TODOs
        at the top of the module.
        """
        if self.has_results:
            g = [item for item in self.nodes.items() if not item[0].is_set]
            return [self._nodeids(query_result)
                    for query_result in named_cross_product(g)
                    if self._check(query_result)]
        else:
            return []

    def _check(self, result):
        for (left_variable, right_variable, __, __, __) in self.ordered_constraints:
            if right_variable.is_set:
                for right_node in self.nodes[right_variable]:
                    if (result[left_variable], right_node) not in self.ok:
                        return False
            else:
                if (result[left_variable], result[right_variable]) not in self.ok:
                    return False
        return True
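
# A hedged usage sketch (node variables, constraints, and the query context
# are assumed to come from the query compiler):
#   factory = ConstraintChecker.prepare(constraints)
#   checker = factory({var_a: [1, 2], var_b: [3]}, query_context)
#   matches = checker.extract()  # list of {variable: IndexNodeId} dicts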

PREPARE_NEW_AFTER = 100


def cct_search(graph_results, query_context):
    query_context._ncache.clear()
    query_context.checked_graphs += 1
    if query_context.checked_graphs == PREPARE_NEW_AFTER:
        query_context.checker_factory = ConstraintChecker.prepare(
            query_context.constraints, query_context.node_counts)
    elif query_context.checked_graphs < PREPARE_NEW_AFTER:
        for node_var, node_ids in graph_results.iteritems():
            query_context.node_counts[node_var] += len(node_ids)
    c = query_context.checker_factory(graph_results, query_context)
    return c.extract()
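
# `cct_search` samples per-variable node counts for the first
# PREPARE_NEW_AFTER graphs and then rebuilds the checker factory once, so that
# subsequent graphs order their constraint checks by the observed counts.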


class LazyResultSet(object):
    def __init__(self, nodes, query_context):
        query_context.checked_graphs += 1
        self._nodes = [(node_var.name, [IndexNodeId.from_int(nid) for nid in node_ids])
                       for node_var, node_ids in nodes.iteritems()
                       if not node_var.is_set]
        self._size = product((len(ids) for var, ids in self._nodes), 1)
        self._items = None

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        if self._items is None:
            self._items = list(iter(self))
        return self._items[idx]

    def __iter__(self):
        return named_cross_product(self._nodes)
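
# LazyResultSet is chosen when no two variables share a constraint: the result
# is the plain cross product of the per-variable node lists, so `len()` can be
# computed without materializing any combination. A hedged sketch:
#   rs = LazyResultSet({var_a: [1, 2], var_b: [7]}, query_context)
#   len(rs)  # 2; iteration yields dicts keyed by variable name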


class QueryContext(object):
    def __init__(self, db, constraints, nodevars):
        self.cursor = db.cursor()
        self._ncache = {}
        self.constraints = constraints
        self.node_counts = defaultdict(int)
        variable_partitions = partition_variables(nodevars, (c[0] for c in constraints))
        if len(variable_partitions) == len(nodevars):
            self.constraint_checker = LazyResultSet
        elif len(variable_partitions) == 1:
            self.checker_factory = ConstraintChecker.prepare(constraints)
            self.constraint_checker = cct_search
        else:
            raise MissingFeatureError, "Missing feature: disjoint constraint sets. Please file a bug report."
        self._reset_stats()

    def _reset_stats(self):
        self.node_cache_hits = 0
        self.node_cache_misses = 0
        self.checked_graphs = 0
        self.constraint_checks = 0

    def get_node(self, node_id):
        try:
            node_data = self._ncache[node_id]
        except KeyError:
            self.node_cache_misses += 1
            self.cursor.execute(
                """SELECT id, edge_label, continuity, left_corner,
                          right_corner, token_order, gorn_address
                   FROM node_data WHERE id = ?""", (node_id,))
            node_data = self._ncache[node_id] = self.cursor.fetchone()
        else:
            # Count a hit only when the node was already cached.
            self.node_cache_hits += 1
        return node_data
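
# `get_node` memoizes rows of the `node_data` table in `_ncache`; the cache is
# cleared per graph by `cct_search`, so the hit/miss counters measure lookup
# locality within a single graph.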


class ResultBuilderBase(object):
    def __init__(self, node_descriptions, predicates):
        self._nodes = node_descriptions
        self._predicates = predicates

    def node_variable_names(self):
        """Returns the set of node variable names defined in the query."""
        return frozenset(nv.name for nv in self._nodes)


class ResultBuilder(QueryContext, ResultBuilderBase):
    def __init__(self, ev_context, node_descriptions, predicates, constraints):
        QueryContext.__init__(self, ev_context.db, constraints, node_descriptions.keys())
        ResultBuilderBase.__init__(self, node_descriptions, predicates)
        self._nodesearcher = ev_context.nodesearcher

    def evaluate(self):
        """Evaluates the query.

        Returns a list of `(graph_number, results)` tuples, where `results` is
        a list of dictionaries that contain `variable: node_id` pairs for all
        defined variable names.
        """
        self._reset_stats()
        matching_graphs = self._nodesearcher.search_nodes(self._nodes, self._predicates)
        return filter(operator.itemgetter(1),
                      ((graph_id, self.constraint_checker(nodes, self))
                       for graph_id, nodes in matching_graphs))
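
# A hedged usage sketch (the evaluator context and the compiled query parts
# are assumed to come from the surrounding query machinery):
#   builder = ResultBuilder(ev_context, node_descriptions, predicates, constraints)
#   for graph_id, results in builder.evaluate():
#       pass  # each entry in `results` maps a node variable to an IndexNodeId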


class ParallelEvaluatorContext(object):
    def __init__(self, db_provider, graph_filter):
        self.db = db_provider.connect()
        self.nodesearcher = NodeSearcher(self.db, graph_filter)
        self.db_provider = db_provider


def evaluate_parallel(db_provider, nodes, predicates, constraints, result_queue, graph_filter):
    ev_ctx = ParallelEvaluatorContext(db_provider, graph_filter)
    query = ResultBuilder(ev_ctx, nodes, predicates, constraints)
    result_set = query.evaluate()
    result_queue.put((result_set, (query.checked_graphs, query.constraint_checks,
                                   query.node_cache_hits, query.node_cache_misses)))
    result_queue.close()


class ParallelResultBuilder(ResultBuilderBase):
    def __init__(self, ev_context, node_descriptions, predicates, constraints):
        super(ParallelResultBuilder, self).__init__(node_descriptions, predicates)
        self._constraints = constraints
        self._db_provider = ev_context.db_provider
        self._reset_stats()

    def _reset_stats(self):
        self.node_cache_hits = 0
        self.node_cache_misses = 0
        self.checked_graphs = 0
        self.constraint_checks = 0

    def evaluate(self):
        self._reset_stats()
        result_queue = multiprocessing.Queue()
        # stdlib multiprocessing provides cpu_count(), not cpuCount()
        num_workers = multiprocessing.cpu_count()
        workers = []
        for i in range(num_workers):
            worker = multiprocessing.Process(
                target=evaluate_parallel,
                args=(self._db_provider, self._nodes, self._predicates,
                      self._constraints, result_queue,
                      EqualPartitionsGraphFilter(i, num_workers)))
            worker.start()
            workers.append(worker)
        results = []
        running_workers = num_workers
        while running_workers > 0:
            partial_result, stats = result_queue.get()
            results.extend(partial_result)
            self.checked_graphs += stats[0]
            self.constraint_checks += stats[1]
            self.node_cache_hits += stats[2]
            self.node_cache_misses += stats[3]
            running_workers -= 1
        for worker in workers:
            worker.join()
        return results
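
# ParallelResultBuilder fans the same query out to one worker process per CPU:
# each worker evaluates a disjoint slice of the corpus (selected by
# EqualPartitionsGraphFilter) and reports its partial result set plus counter
# statistics back over a multiprocessing.Queue.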