PageRenderTime 43ms CodeModel.GetById 15ms RepoModel.GetById 0ms app.codeStats 0ms

/schevo/query.py

https://bitbucket.org/11craft/schevo
Python | 637 lines | 492 code | 77 blank | 68 comment | 79 complexity | dd72dfd7d4fed81498a4b4899a6eae0d MD5 | raw file
  1. """Query classes."""
  2. # Copyright (c) 2001-2009 ElevenCraft Inc.
  3. # See LICENSE for details.
  4. import operator
  5. import sys
  6. from schevo.lib import optimize
  7. from schevo import base
  8. from schevo.constant import UNASSIGNED
  9. import schevo.error
  10. from schevo import field
  11. from schevo.fieldspec import FieldMap, FieldSpecMap
  12. from schevo.label import label, plural
  13. from schevo.lib.odict import odict
  14. from schevo.meta import schema_metaclass
  15. import schevo.namespace
  16. from schevo.namespace import namespaceproperty
  17. from schevo import queryns
  18. from schevo.trace import log
  19. # --------------------------------------------------------------------
  # Metaclass built by ``schema_metaclass('Q')``; `Query` below uses it.
  QueryMeta = schema_metaclass('Q')


  # --------------------------------------------------------------------


  class Query(base.Query):
      """Simplest query possible, returning no results.

      Subclasses override `_results` to produce something and
      `__unicode__` to describe themselves.
      """

      __metaclass__ = QueryMeta

      def __call__(self):
          """Run the query; shortcut for calling `_results`."""
          return self._results()

      def _results(self):
          """Return a `Results` instance based on the current state of
          this query.

          The base implementation wraps an empty tuple.
          """
          nothing = ()
          return results(nothing)

      def __str__(self):
          text = unicode(self)
          return str(text)

      def __unicode__(self):
          """Return a human language representation of the query.

          Falls back to ``repr`` unless a subclass overrides it.
          """
          return repr(self)
  class Simple(Query):
      """Query that delegates to a callable and carries its own label.

      - ``fn``: Callable taking no arguments; whatever it returns is
        passed to `results` each time the query is run.
      - ``label``: Value returned as the query's unicode
        representation.
      """

      def __init__(self, fn, label):
          self._fn, self._label = fn, label

      def _results(self):
          produce = self._fn
          return results(produce())

      def __unicode__(self):
          return self._label
  class Param(Query):
      """Parameterized query that has field definitions, and an optional
      object on which to operate.

      Field values are read and written as ordinary attributes.
      Assignment to a name that starts with an underscore, or that is a
      single character long, is handled as a normal attribute; every
      other name is looked up in the query's field map.
      """

      __slots__ = ['_on', '_field_map', '_label',
                   '_f', '_h', '_s']

      _field_spec = FieldSpecMap()

      # Namespaces.
      f = namespaceproperty('f', instance=schevo.namespace.Fields)
      h = namespaceproperty('h', instance=queryns.ParamChangeHandlers)
      s = namespaceproperty('s', instance=queryns.ParamSys)

      # Deprecated namespaces.
      sys = namespaceproperty('s', instance=queryns.ParamSys)

      def __init__(self, *args, **kw):
          """Create the query.

          - ``args``: If any positional arguments are given, the first
            is stored as the object to operate on; otherwise that
            object is ``None``.
          - ``kw``: Initial values, applied with ``setattr`` so they
            are routed exactly like later attribute assignments.
          """
          self._field_map = self._field_spec.field_map(instance=self)
          if args:
              self._on = args[0]
          else:
              self._on = None
          for name, value in kw.iteritems():
              setattr(self, name, value)

      def __getattr__(self, name):
          """Return the value of the field called ``name``.

          Python only calls this when normal attribute lookup has
          failed.  Raises `AttributeError` if no such field exists.
          """
          if name == '_field_map':
              # The field map itself is not set (e.g. the instance was
              # created without running __init__).  Reading it below
              # would re-enter this method and recurse without end.
              raise AttributeError(name)
          try:
              return self._field_map[name].get()
          except KeyError:
              msg = 'Field %r does not exist on %r.' % (name, self)
              raise AttributeError(msg)

      def __setattr__(self, name, value):
          """Store ``value`` as a plain attribute or as a field value.

          Raises `KeyError` if ``name`` is routed to the field map but
          is not present in it.
          """
          if name.startswith('_') or len(name) == 1:
              return Query.__setattr__(self, name, value)
          else:
              self._field_map[name].set(value)

      def __repr__(self):
          # ``None`` is the marker __init__ stores when there is no
          # object to operate on, so test for exactly that.  A plain
          # truth test would also hide any target that happens to
          # evaluate false.
          if self._on is not None:
              return '<%s query on %s>' % (self.__class__.__name__, self._on)
          else:
              return '<%s query>' % self.__class__.__name__

      def _getAttributeNames(self):
          """Return list of hidden attributes to extend introspection."""
          return sorted(self._field_map.keys())
  class Exact(Param):
      """Parameterized query for an extent that uses ``find``.

      Only fields that were given a value when the query was created
      are marked as assigned, and only assigned fields are used as
      criteria.
      """

      __slots__ = Param.__slots__

      _label = 'Exact Matches'

      def __init__(self, extent, **kw):
          """Create the query.

          - ``extent``: The extent to search; its ``field_spec``
            supplies any fields not already declared on the class.
          - ``kw``: Field values to match.
          """
          # NOTE: This deliberately does NOT call Param.__init__
          self._on = extent
          # Begin with whatever fields a subclass declared.
          spec = FieldSpecMap(self._field_spec)
          fields = self._field_map = spec.field_map(instance=self)
          # Then add each field of the extent that is still missing.
          for field_name, ExtentField in extent.field_spec.iteritems():
              if field_name in fields:
                  continue

              # Subclass all fields so they won't be constrained by
              # having __slots__ defined.  Convert fget fields to
              # non-fget, so we can query against them.
              class NoSlotsField(ExtentField):
                  fget = None
                  readonly = False
                  required = False

              NoSlotsField.__name__ = ExtentField.__name__
              spec[field_name] = NoSlotsField
              added = fields[field_name] = NoSlotsField(self)
              added._name = field_name
          # Clear the flag everywhere, then set it for supplied values.
          for each in fields.itervalues():
              each.assigned = False
          for field_name, value in kw.iteritems():
              setattr(self, field_name, value)
              fields[field_name].assigned = True

      def _results(self):
          criteria = self._criteria
          return results(self._on.find(**criteria))

      @property
      def _criteria(self):
          """An ``odict`` of field name to value for assigned fields."""
          chosen = odict()
          for field_name, each in self.s.field_map().iteritems():
              if not each.assigned:
                  continue
              chosen[field_name] = each.get()
          return chosen

      def __unicode__(self):
          extent = self._on
          criteria = self._criteria
          if not criteria:
              return u'all %s' % plural(extent)
          field_spec = extent.field_spec
          # Pairs of (field label, value label), one per criterion.
          pairs = [
              (label(field_spec[field_name]), unicode(self.f[field_name]))
              for field_name in criteria
              ]
          described = ', '.join('%s == %s' % pair for pair in pairs)
          return u'%s where (%s)' % (plural(extent), described)
  class Links(Query):
      """Query whose results are a call to `links` on an entity.

      - ``entity``: The entity whose ``s.links`` method is called.
      - ``other_extent``: First argument passed to ``links``.
      - ``other_field_name``: Second argument passed to ``links``.
      """

      def __init__(self, entity, other_extent, other_field_name):
          self._entity = entity
          self._other_extent = other_extent
          self._other_field_name = other_field_name

      def _results(self):
          find_links = self._entity.s.links
          found = find_links(self._other_extent, self._other_field_name)
          return results(found)
  class MatchOperator(object):
      """A named comparison that a field match can apply.

      - ``name``: Short identifier for the operator.
      - ``label``: Text shown for the operator; also used by ``repr``.
      - ``oper``: Optional two-argument callable implementing the
        comparison; stored as the ``operator`` attribute.
      """

      def __init__(self, name, label, oper=None):
          self.name, self.label, self.operator = name, label, oper

      def __repr__(self):
          template = '<MatchOperator: %s>'
          return template % self.label
  # Operators that carry no comparison callable.
  o_any = MatchOperator('any', u'is anything')
  o_assigned = MatchOperator('assigned', u'has a value')
  o_unassigned = MatchOperator('unassigned', u'has no value')

  # Operators backed by a function from the ``operator`` module.
  o_eq = MatchOperator('eq', u'==', oper=operator.eq)
  o_ne = MatchOperator('ne', u'!=', oper=operator.ne)
  o_lt = MatchOperator('lt', u'<', oper=operator.lt)
  o_le = MatchOperator('le', u'<=', oper=operator.le)
  o_gt = MatchOperator('gt', u'>', oper=operator.gt)
  o_ge = MatchOperator('ge', u'>=', oper=operator.ge)
  o_in = MatchOperator('in', u'in', oper=operator.contains)
  def _contains(a, b):
      """Return whether ``b`` is in ``a``.

      An ``a`` that is UNASSIGNED never matches.
      """
      return (a is not UNASSIGNED) and (b in a)


  o_contains = MatchOperator('contains', u'contains', oper=_contains)
  def _startswith(a, b):
      """Return the result of ``a.startswith(b)``.

      An ``a`` that is UNASSIGNED never matches.
      """
      return (a is not UNASSIGNED) and a.startswith(b)


  o_startswith = MatchOperator(
      'startswith', u'starts with', oper=_startswith)
  # String aliases accepted wherever an operator may be given by name.
  # Every operator is reachable through its own ``name`` ...
  o_aliases = dict(
      (op.name, op)
      for op in (
          o_any, o_assigned, o_unassigned,
          o_eq, o_ne, o_lt, o_le, o_gt, o_ge,
          o_in, o_contains, o_startswith,
          )
      )
  # ... and the comparison operators through their symbol as well.
  o_aliases.update({
      '==': o_eq,
      '!=': o_ne,
      '<': o_lt,
      '<=': o_le,
      '>': o_gt,
      '>=': o_ge,
      })
  class Match(Query):
      """Field match query."""

      def __init__(self, on, field_name, operator=o_eq, value=None,
                   FieldClass=None):
          """Create a new field match query.

          - ``on``: Extent or Results instance to match on.

          - ``field_name``: The field name to match on.

          - ``operator``: An object or string alias for the
            `MatchOperator` to use when matching.

          - ``value``: If not ``None``, the value to match for, or
            results to match in.

          - ``FieldClass``: If not ``None``, the field class to use to
            create the ``field`` attribute.  If ``None``, then ``on``
            must provide a Field class for ``field_name``.
          """
          self.on = on
          self.field_name = field_name
          BaseField = FieldClass or getattr(on.f, field_name)

          # Subclass all fields so they won't be constrained by having
          # __slots__ defined.  Convert fget fields to non-fget, so we
          # can query against them.
          class NoSlotsField(BaseField):
              fget = None
              readonly = False
              required = False

          NoSlotsField.__name__ = BaseField.__name__
          self.FieldClass = NoSlotsField
          self.operator = operator
          self.value = value

      def _results(self):
          """Return the objects from ``on`` that satisfy the match."""
          on = self.on
          if isinstance(on, base.Query):
              # Matching against another query means matching against
              # that query's results.
              on = on()
          op = self.operator
          field_name = self.field_name
          value = self.value

          if op is o_in:
              if isinstance(value, base.Query):
                  value = value()
              value = frozenset(value)
              return results(
                  obj for obj in on if getattr(obj, field_name) in value)
          if op is o_any:
              return results(on)
          if op is o_assigned:
              return results(
                  obj for obj in on
                  if getattr(obj, field_name) is not UNASSIGNED)
          if op is o_unassigned:
              if isinstance(on, base.Extent):
                  return results(on.find(**{field_name: UNASSIGNED}))
              return results(
                  obj for obj in on
                  if getattr(obj, field_name) is UNASSIGNED)

          if value is not None:
              # Pass the value through a field instance and compare
              # against what the field hands back.
              converter = self.FieldClass(self, field_name)
              converter.set(value)
              value = converter.get()
          if isinstance(on, base.Extent) and op is o_eq:
              return results(on.find(**{field_name: value}))
          compare = op.operator
          if compare:
              def matching():
                  for obj in on:
                      try:
                          matched = compare(getattr(obj, field_name), value)
                      except TypeError:
                          # Cannot compare e.g. UNASSIGNED with
                          # datetime; assume no match.
                          continue
                      if matched:
                          yield obj
              return results(matching())

      def _get_operator(self):
          return self._operator

      def _set_operator(self, operator):
          # A string is an alias and is replaced by the operator it
          # names; anything else is stored unchanged.
          if isinstance(operator, basestring):
              operator = o_aliases[operator]
          self._operator = operator

      operator = property(_get_operator, _set_operator)

      @property
      def valid_operators(self):
          """Return a sequence of valid operators based on the
          FieldClass."""
          FieldClass = self.FieldClass
          valid = []
          if issubclass(FieldClass, field.Field):
              valid.extend([o_any, o_assigned, o_unassigned, o_eq, o_ne])
          if issubclass(FieldClass, (field.String, field.Unicode)):
              valid.extend([o_contains, o_startswith])
          if not issubclass(FieldClass, field.Entity):
              valid.extend([o_le, o_lt, o_ge, o_gt])
          return tuple(valid)

      def __unicode__(self):
          described = self.FieldClass(self, self.field_name)
          op = self.operator
          target = self.on
          if isinstance(target, base.Extent):
              target_label = plural(target)
          else:
              target_label = unicode(target)
          text = u'%s where %s %s' % (
              target_label,
              label(described),
              label(self.operator),
              )
          if op is o_any:
              return text
          value = self.value
          if not isinstance(value, Query):
              # Show a plain value the way the field itself renders it.
              described.set(value)
              value = described
          return text + (u' %s' % value)
  class Intersection(Query):
      """The results common to all given queries.

      - ``queries``: A list of queries to intersect.
      """

      def __init__(self, *queries):
          self.queries = list(queries)

      def _results(self):
          """Return the results that every subquery produced.

          When there are no subqueries the result is empty.
          """
          assert log(1, 'called Intersection')
          resultset = None
          for query in self.queries:
              assert log(2, 'resultset is', resultset)
              assert log(2, 'intersecting with', query)
              s = set(query())
              if resultset is None:
                  resultset = s
              else:
                  resultset = resultset.intersection(s)
          assert log(2, 'resultset is finally', resultset)
          if resultset is None:
              # The loop never ran; frozenset(None) would raise
              # TypeError, so use an empty result instead.
              resultset = ()
          return results(frozenset(resultset))

      def __unicode__(self):
          if not self.queries:
              return u'the intersection of ()'
          last_on = None
          for query in self.queries:
              # Optimize length of string when results will be all
              # entities in an extent.
              if (isinstance(query, Match)
                  and isinstance(query.on, base.Extent)
                  # ``None`` marks "no extent seen yet"; test for it
                  # explicitly rather than by truth value.
                  and (last_on is None or query.on is last_on)
                  and (query.operator is o_any)
                  ):
                  last_on = query.on
                  continue
              # Not a default query.
              return u'the intersection of (%s)' % (
                  u', '.join(unicode(query) for query in self.queries)
                  )
          # Was a default query.
          return u'all %s' % plural(last_on)

      @property
      def match_names(self):
          """The field names of immediate Match subqueries."""
          return [
              query.field_name
              for query in self.queries
              if isinstance(query, Match)
              ]

      def remove_match(self, field_name):
          """Remove the first immediate Match subquery with the
          given field name.

          Raises `schevo.error.FieldDoesNotExist` if there is none.
          """
          for index, query in enumerate(self.queries):
              if isinstance(query, Match) and query.field_name == field_name:
                  # Delete by position so exactly the subquery that was
                  # found is removed; the loop ends right away, so the
                  # list is not iterated after being changed.
                  del self.queries[index]
                  return
          raise schevo.error.FieldDoesNotExist(self, field_name)
  class ByExample(Intersection):
      """Find by example query for a given extent.

      One `Match` subquery is created per field of the extent.  Each
      starts out with the ``'any'`` operator; fields named in the
      keyword arguments are switched to ``'=='`` with the given value.
      """

      _label = 'By Example'

      def __init__(self, extent, **kw):
          self.extent = extent
          matches = []
          for name, ExtentField in extent.field_spec.iteritems():

              # Make sure calculated fields are -not- calculated in the
              # match query.
              class NoSlotsField(ExtentField):
                  fget = None
                  readonly = False
                  required = False

              NoSlotsField.__name__ = ExtentField.__name__
              match = Match(extent, name, 'any', FieldClass=NoSlotsField)
              if name in kw:
                  match.value = kw[name]
                  match.operator = '=='
              matches.append(match)
          Intersection.__init__(self, *matches)
  class Union(Query):
      """One of each unique result in all given queries.

      - ``queries``: The list of queries to union.
      """

      def __init__(self, *queries):
          self.queries = list(queries)

      def _results(self):
          # Run the subqueries in order and keep each distinct result
          # once.
          combined = frozenset(
              item
              for query in self.queries
              for item in query()
              )
          return results(combined)

      def __unicode__(self):
          described = [unicode(query) for query in self.queries]
          return u'the union of (%s)' % (u', '.join(described))
  class Group(Query):
      """Group a query's results by the value of a field.

      Running the query produces one plain list per distinct value of
      ``field_name``; each list holds the results sharing that value.

      - ``query``: The query whose results are grouped.
      - ``field_name``: Name of the attribute to group by.
      - ``FieldClass``: Field class used to label the field in the
        text description.
      """

      def __init__(self, query, field_name, FieldClass):
          self.query = query
          self.field_name = field_name
          self.FieldClass = FieldClass

      def _results(self):
          field_name = self.field_name
          by_value = {}
          for result in self.query():
              key = getattr(result, field_name)
              by_value.setdefault(key, []).append(result)
          return results(members for members in by_value.itervalues())

      def __unicode__(self):
          grouped_field = self.FieldClass(self, self.field_name)
          return u'%s, grouped by %s' % (self.query, label(grouped_field))
  class Min(Query):
      """The result of each group in a Group query's results that has
      the minimum value for a field.

      - ``query``: Query whose results are an iterable of groups, each
        group being an iterable of results.
      - ``field_name``: Name of the attribute compared within a group.
      - ``FieldClass``: Optional field class, used only to label the
        field in the text description.
      """

      def __init__(self, query, field_name, FieldClass=None):
          self.query = query
          self.field_name = field_name
          self.FieldClass = FieldClass

      def _results(self):
          """Return one result per non-empty group.

          When several results in a group share the minimum value, the
          first one encountered is kept.
          """
          def generator():
              field_name = self.field_name
              groups = self.query()
              for group in groups:
                  min_result = None
                  min_value = None
                  for result in group:
                      value = getattr(result, field_name)
                      if min_result is None or value < min_value:
                          min_result = result
                          min_value = value
                  # An empty group contributes nothing.
                  if min_result is not None:
                      yield min_result
          return results(generator())

      def __unicode__(self):
          if self.FieldClass is None:
              # FieldClass is optional; without one there is no field
              # to ask for a label, so describe it by its name.
              field_label = self.field_name
          else:
              field_label = label(self.FieldClass(self, self.field_name))
          return u'results that have the minimum %s in each (%s)' % (
              field_label, self.query
              )
  class Max(Query):
      """The result of each group in a Group query's results that has
      the maximum value for a field.

      - ``query``: Query whose results are an iterable of groups, each
        group being an iterable of results.
      - ``field_name``: Name of the attribute compared within a group.
      - ``FieldClass``: Optional field class, used only to label the
        field in the text description.
      """

      def __init__(self, query, field_name, FieldClass=None):
          self.query = query
          self.field_name = field_name
          self.FieldClass = FieldClass

      def _results(self):
          """Return one result per non-empty group.

          When several results in a group share the maximum value, the
          first one encountered is kept.
          """
          def generator():
              field_name = self.field_name
              groups = self.query()
              for group in groups:
                  max_result = None
                  max_value = None
                  for result in group:
                      value = getattr(result, field_name)
                      if max_result is None or value > max_value:
                          max_result = result
                          max_value = value
                  # An empty group contributes nothing.
                  if max_result is not None:
                      yield max_result
          return results(generator())

      def __unicode__(self):
          if self.FieldClass is None:
              # FieldClass is optional; without one there is no field
              # to ask for a label, so describe it by its name.
              field_label = self.field_name
          else:
              field_label = label(self.FieldClass(self, self.field_name))
          return u'results that have the maximum %s in each (%s)' % (
              field_label, self.query
              )
  496. # --------------------------------------------------------------------
  def results(obj):
      """Return a decorated object based on ``obj`` that mixes in the
      `schevo.base.Results` type.

      A frozenset, list, set or tuple is copied into the matching
      ``Results*`` subclass; anything else is wrapped in a
      `ResultsIterator`.
      """
      wrappers = (
          (frozenset, ResultsFrozenset),
          (list, ResultsList),
          (set, ResultsSet),
          (tuple, ResultsTuple),
          )
      for builtin_type, wrapper in wrappers:
          if isinstance(obj, builtin_type):
              return wrapper(obj)
      return ResultsIterator(obj)
  class ResultsFrozenset(frozenset, base.Results):
      """A frozenset that is also a `base.Results`."""


  class ResultsList(list, base.Results):
      """A list that is also a `base.Results`."""


  class ResultsSet(set, base.Results):
      """A set that is also a `base.Results`."""


  class ResultsTuple(tuple, base.Results):
      """A tuple that is also a `base.Results`."""
  class ResultsIterator(base.Results):
      """A `base.Results` wrapper around an arbitrary iterable.

      - ``orig``: The wrapped object; iterating the wrapper iterates
        over it.
      """

      def __init__(self, orig):
          self._orig = orig

      def __iter__(self):
          wrapped = self._orig
          return iter(wrapped)
  # Append Param to the tuple of classes kept in
  # ``base.classes_using_fields``.
  base.classes_using_fields += (Param, )

  # NOTE(review): the original marks this as the last line of the
  # module; presumably it must run after every definition above.
  optimize.bind_all(sys.modules[__name__])  # Last line of module.