PageRenderTime 50ms CodeModel.GetById 27ms RepoModel.GetById 0ms app.codeStats 0ms

/nltk_contrib/tiger/query/factory.py

https://github.com/nltk/nltk_contrib
Python | 256 lines | 141 code | 38 blank | 77 comment | 24 complexity | 0c35fc93962c4a724b2b8d0888ed6095 MD5 | raw file
  1. # -*- coding: utf-8 -*-
  2. # Copyright © 2007-2008 Stockholm TreeAligner Project
  3. # Author: Torsten Marek <shlomme@gmx.net>
  4. # Licensed under the GNU GPLv2
  5. """This module contains classes to create a result builder from a query AST.
  6. """
  7. from collections import defaultdict
  8. from itertools import count
  9. from nltk_contrib.tiger.graph import NodeType
  10. from nltk_contrib.tiger.query import ast_visitor
  11. from nltk_contrib.tiger.query.ast_utils import create_varref, NodeDescriptionNormalizer
  12. from nltk_contrib.tiger.query.node_variable import NodeVariable
  13. from nltk_contrib.tiger.query import ast
  14. from nltk_contrib.tiger.query.predicates import PredicateFactory, NodeTypePredicate
  15. from nltk_contrib.tiger.query.constraints import ConstraintFactory
  16. from nltk_contrib.tiger.query.exceptions import TigerTypeError, UndefinedNameError
  17. __all__ = ["QueryFactory"]
  18. class NodeTypeInferencer(ast_visitor.AstVisitor):
  19. """An AST visitor that processes a node description and infers the type
  20. of the node variable using the feature constraints.
  21. *Parameters*:
  22. * `terminal_features`: the set of features on T nodes
  23. * `nonterminal_features`: the set of features on NT nodes
  24. """
  25. def __init__(self, feature_types):
  26. super(self.__class__, self).__init__()
  27. self._type_assoc = feature_types
  28. def setup(self, *args):
  29. """Prepares the typer for a new variable."""
  30. self._types = set()
  31. self._disjoints = [self._types]
  32. self._has_frec = False
  33. @ast_visitor.post_child_handler(ast.Disjunction)
  34. def after_disjunction_subexpr(self, node, child_idx):
  35. """Handles disjunctions."""
  36. if child_idx < len(node.children) - 1:
  37. self._types = set()
  38. self._disjoints.append(self._types)
  39. @ast_visitor.node_handler(ast.FeatureConstraint)
  40. def feature_constraints(self, node):
  41. """Adds type information based on the feature name."""
  42. try:
  43. self._types.add(self._type_assoc[node.feature])
  44. except KeyError:
  45. raise UndefinedNameError, (UndefinedNameError.FEATURE, node.feature)
  46. return self.STOP
  47. @ast_visitor.node_handler(ast.FeatureRecord)
  48. def feature_record(self, node):
  49. """Adds the type specified in a feature record."""
  50. self._types.add(node.type)
  51. self._has_frec = True
  52. return self.STOP
  53. def result(self, query_ast, node_variable):
  54. """Returns the type inferred type of the node variable.
  55. The return value is a member of the enum `nltk_contrib.tiger.graph.NodeType`.
  56. If the feature refer to conflicting node types, a `TigerTypeError` is raised.
  57. """
  58. node_var_type = set()
  59. for disj in self._disjoints:
  60. if len(disj) == 2:
  61. raise TigerTypeError, node_variable.name
  62. else:
  63. node_var_type.update(disj)
  64. if len(node_var_type) == 1:
  65. return (list(node_var_type)[0], self._has_frec)
  66. else:
  67. return (NodeType.UNKNOWN, False)
  68. class QueryFactory(ast_visitor.AstVisitor):
  69. """Creates the internal representation from a query AST.
  70. A query AST is split into three parts:
  71. * node descriptions: a dictionary of `varname: AST`
  72. * predicates: a dictionary with `varname: nltk_contrib.tiger.query.predicates.Predicate` entries
  73. * constraints: a list of ((left, right), nltk_contrib.tiger.query.constraints.Constraint)` tuples
  74. These collections will be used to instantiate a `Query` object.
  75. Anonymous node descriptions will be wrapped into a variable definition with an
  76. automatically generated, globally unique variable name.
  77. """
  78. get_anon_nodevar = (":anon:%i" % (c, ) for c in count()).next
  79. constraint_factory = ConstraintFactory()
  80. predicate_factory = PredicateFactory()
  81. def __init__(self, ev_context):
  82. super(self.__class__, self).__init__()
  83. self.nodedesc_normalizer = NodeDescriptionNormalizer(ev_context.corpus_info.feature_types)
  84. self._ev_context = ev_context
  85. for cls in self.constraint_factory:
  86. cls.setup_context(self._ev_context)
  87. self._ntyper = NodeTypeInferencer(ev_context.corpus_info.feature_types)
  88. @ast_visitor.node_handler(ast.NodeDescription)
  89. def handle_node_description(self, child_node):
  90. """Replaces an anonymous node description with a reference to a fresh node variable.
  91. The node description is stored for later reference.
  92. """
  93. variable = NodeVariable(self.get_anon_nodevar(), False)
  94. self.node_defs[variable] = child_node
  95. self.node_vars[variable.name] = variable
  96. return self.REPLACE(create_varref(variable.name))
  97. @ast_visitor.node_handler(ast.VariableDefinition)
  98. def handle_node_variable_def(self, child_node):
  99. """Replaces a node variable definition with a reference, and stores it.
  100. If the variable has already been defined, the node descriptions are merged.
  101. """
  102. assert child_node.variable.type == ast.VariableTypes.NodeIdentifier
  103. node_variable = NodeVariable.from_node(child_node.variable)
  104. self.node_vars[child_node.variable.name] = node_variable
  105. if node_variable in self.node_defs:
  106. self.node_defs[node_variable] = ast.NodeDescription(
  107. ast.Conjunction([self.node_defs[node_variable].expression,
  108. child_node.expression.expression]))
  109. else:
  110. self.node_defs[node_variable] = child_node.expression
  111. return self.REPLACE(create_varref(child_node.variable.name,
  112. container_type = child_node.variable.container))
  113. @ast_visitor.node_handler(ast.Predicate)
  114. def handle_predicate(self, child_node):
  115. """Stores the predicate in the list of predicates."""
  116. self.predicates.append(child_node)
  117. return self.CONTINUE(child_node)
  118. @ast_visitor.node_handler(ast.SiblingOperator,
  119. ast.CornerOperator,
  120. ast.DominanceOperator,
  121. ast.PrecedenceOperator,
  122. ast.SecEdgeOperator)
  123. def constraint_op(self, child_node):
  124. """Stores the constraint in the list of constraints."""
  125. self.constraints.append(child_node)
  126. return self.CONTINUE(child_node)
  127. def setup(self, query_ast):
  128. """Creates the collections for the internal representation of the query."""
  129. self.predicates = []
  130. self.node_defs = {}
  131. self.node_vars = {}
  132. self.constraints = []
  133. def _get_variable(self, variable):
  134. """Returns a node variable object associated with the AST fragment `variable`.
  135. If `variable` is seen the first time, a new node variable is created using
  136. `NodeVariable.from_node`.
  137. """
  138. try:
  139. return self.node_vars[variable.name]
  140. except KeyError:
  141. node_variable = self.node_vars[variable.name] = NodeVariable.from_node(variable)
  142. self.node_defs[node_variable] = ast.NodeDescription(ast.Nop())
  143. return node_variable
  144. def _process_predicates(self, predicates):
  145. """Creates the predicate objects.
  146. The predicate objects are created from the AST nodes using the `predicate_factory`.
  147. """
  148. for pred_ast_node in self.predicates:
  149. ast_var, predicate = self.predicate_factory.create(pred_ast_node)
  150. predicates[self._get_variable(ast_var)].append(predicate)
  151. def _process_constraints(self, predicates):
  152. """Creates the constraint objects.
  153. The constraints are created from the AST representations using the `constraint_factory`.
  154. """
  155. result = []
  156. for constraint_ast_node in self.constraints:
  157. left_var = self._get_variable(constraint_ast_node.left_operand.variable)
  158. right_var = self._get_variable(constraint_ast_node.right_operand.variable)
  159. constraint = self.constraint_factory.create(
  160. constraint_ast_node, (left_var.var_type, right_var.var_type), self._ev_context)
  161. result.append(((left_var, right_var), constraint))
  162. for node_var, var_type in zip((left_var, right_var),
  163. constraint.get_node_variable_types()):
  164. node_var.refine_type(var_type)
  165. for (left_var, right_var), constraint in result:
  166. left_p, right_p = constraint.get_predicates(left_var, right_var)
  167. predicates[left_var].extend(left_p)
  168. predicates[right_var].extend(right_p)
  169. return result
  170. def _add_type_predicates(self, predicates):
  171. """Adds type predicates to the predicate lists if necessary.
  172. A type predicate is only added for a node variable if all of the following conditions
  173. are true:
  174. * the node description is empty
  175. * no predicates are defined for the variable
  176. * the variable type is not `NodeType.UNKNOWN`
  177. This mechanism is different from handling of feature records. The type predicate
  178. is added to each disjunct, while the feature record can differ between each disjunct.
  179. """
  180. for node_variable, description in self.node_defs.iteritems():
  181. if description.expression.TYPE is ast.Nop and len(predicates[node_variable]) == 0 \
  182. and node_variable.var_type is not NodeType.UNKNOWN:
  183. predicates[node_variable].append(NodeTypePredicate(node_variable.var_type))
  184. def from_ast(self, query_ast):
  185. """Convert a query AST into a result builder object.
  186. Query ASTs are in the same state as returned by the parser.
  187. The result builder class is injected using the `get_result_builder_class`
  188. on the evaluator context.
  189. """
  190. return self.run(query_ast)
  191. def result(self, query_ast):
  192. """Processes the collected items and returns the query object."""
  193. predicates = defaultdict(list)
  194. for node_variable, node_desc in self.node_defs.iteritems():
  195. self.nodedesc_normalizer.run(node_desc)
  196. node_var_type, has_frec = self._ntyper.run(node_desc, node_variable)
  197. node_variable.refine_type(node_var_type)
  198. if has_frec:
  199. predicates[node_variable].append(NodeTypePredicate(node_var_type))
  200. self._process_predicates(predicates)
  201. constraints = self._process_constraints(predicates)
  202. self._add_type_predicates(predicates)
  203. return self._ev_context.get_result_builder_class(len(constraints) > 0)(
  204. self._ev_context, self.node_defs, predicates, constraints)