PageRenderTime 41ms CodeModel.GetById 13ms RepoModel.GetById 0ms app.codeStats 0ms

/api/base/filters.py

https://gitlab.com/doublebits/osf.io
Python | 389 lines | 343 code | 21 blank | 25 comment | 20 complexity | a90469eb2281a64c29ad4b3fb2bd859c MD5 | raw file
  1. import re
  2. import functools
  3. import operator
  4. from dateutil import parser as date_parser
  5. import datetime
  6. from django.core.exceptions import ValidationError
  7. from modularodm import Q
  8. from modularodm.query import queryset as modularodm_queryset
  9. from rest_framework.filters import OrderingFilter
  10. from rest_framework import serializers as ser
  11. from api.base.exceptions import (
  12. InvalidFilterError,
  13. InvalidFilterOperator,
  14. InvalidFilterComparisonType,
  15. InvalidFilterMatchType,
  16. InvalidFilterValue,
  17. InvalidFilterFieldError
  18. )
  19. from api.base import utils
  20. from api.base.serializers import RelationshipField, TargetField
  21. def sort_multiple(fields):
  22. fields = list(fields)
  23. def sort_fn(a, b):
  24. while fields:
  25. field = fields.pop(0)
  26. a_field = getattr(a, field)
  27. b_field = getattr(b, field)
  28. if a_field > b_field:
  29. return 1
  30. elif a_field < b_field:
  31. return -1
  32. return 0
  33. return sort_fn
  34. class ODMOrderingFilter(OrderingFilter):
  35. """Adaptation of rest_framework.filters.OrderingFilter to work with modular-odm."""
  36. # override
  37. def filter_queryset(self, request, queryset, view):
  38. ordering = self.get_ordering(request, queryset, view)
  39. if ordering:
  40. if not isinstance(queryset, modularodm_queryset.BaseQuerySet) and isinstance(ordering, (list, tuple)):
  41. sorted_list = sorted(queryset, cmp=sort_multiple(ordering))
  42. return sorted_list
  43. return queryset.sort(*ordering)
  44. return queryset
  45. class FilterMixin(object):
  46. """ View mixin with helper functions for filtering. """
  47. QUERY_PATTERN = re.compile(r'^filter\[(?P<field>\w+)\](\[(?P<op>\w+)\])?$')
  48. MATCH_OPERATORS = ('contains', 'icontains')
  49. MATCHABLE_FIELDS = (ser.CharField, ser.ListField)
  50. DEFAULT_OPERATORS = ('eq', 'ne')
  51. DEFAULT_OPERATOR_OVERRIDES = {
  52. ser.CharField: 'icontains',
  53. ser.ListField: 'contains',
  54. }
  55. NUMERIC_FIELDS = (ser.IntegerField, ser.DecimalField, ser.FloatField)
  56. DATE_FIELDS = (ser.DateTimeField, ser.DateField)
  57. DATETIME_PATTERN = re.compile(r'^\d{4}\-\d{2}\-\d{2}(?P<time>T\d{2}:\d{2}(:\d{2}(\.\d{1,6})?)?)$')
  58. COMPARISON_OPERATORS = ('gt', 'gte', 'lt', 'lte')
  59. COMPARABLE_FIELDS = NUMERIC_FIELDS + DATE_FIELDS
  60. LIST_FIELDS = (ser.ListField, )
  61. RELATIONSHIP_FIELDS = (RelationshipField, TargetField)
  62. def __init__(self, *args, **kwargs):
  63. super(FilterMixin, self).__init__(*args, **kwargs)
  64. if not self.serializer_class:
  65. raise NotImplementedError()
  66. def _get_default_operator(self, field):
  67. return self.DEFAULT_OPERATOR_OVERRIDES.get(type(field), 'eq')
  68. def _get_valid_operators(self, field):
  69. if isinstance(field, self.COMPARABLE_FIELDS):
  70. return self.COMPARISON_OPERATORS + self.DEFAULT_OPERATORS
  71. elif isinstance(field, self.MATCHABLE_FIELDS):
  72. return self.MATCH_OPERATORS + self.DEFAULT_OPERATORS
  73. else:
  74. return None
  75. def _get_field_or_error(self, field_name):
  76. """
  77. Check that the attempted filter field is valid
  78. :raises InvalidFilterError: If the filter field is not valid
  79. """
  80. if field_name not in self.serializer_class._declared_fields:
  81. raise InvalidFilterError(detail="'{0}' is not a valid field for this endpoint.".format(field_name))
  82. if field_name not in getattr(self.serializer_class, 'filterable_fields', set()):
  83. raise InvalidFilterFieldError(parameter='filter', value=field_name)
  84. return self.serializer_class._declared_fields[field_name]
  85. def _validate_operator(self, field, field_name, op):
  86. """
  87. Check that the operator and field combination is valid
  88. :raises InvalidFilterComparisonType: If the query contains comparisons against non-date or non-numeric fields
  89. :raises InvalidFilterMatchType: If the query contains comparisons against non-string or non-list fields
  90. :raises InvalidFilterOperator: If the filter operator is not a member of self.COMPARISON_OPERATORS
  91. """
  92. if op not in set(self.MATCH_OPERATORS + self.COMPARISON_OPERATORS + self.DEFAULT_OPERATORS):
  93. valid_operators = self._get_valid_operators(field)
  94. raise InvalidFilterOperator(value=op, valid_operators=valid_operators)
  95. if op in self.COMPARISON_OPERATORS:
  96. if not isinstance(field, self.COMPARABLE_FIELDS):
  97. raise InvalidFilterComparisonType(
  98. parameter="filter",
  99. detail="Field '{0}' does not support comparison operators in a filter.".format(field_name)
  100. )
  101. if op in self.MATCH_OPERATORS:
  102. if not isinstance(field, self.MATCHABLE_FIELDS):
  103. raise InvalidFilterMatchType(
  104. parameter="filter",
  105. detail="Field '{0}' does not support match operators in a filter.".format(field_name)
  106. )
  107. def _parse_date_param(self, field, field_name, op, value):
  108. """
  109. Allow for ambiguous date filters. This supports operations like finding Nodes created on a given day
  110. even though Node.date_created is a specific datetime.
  111. :return list<dict>: list of one (specific datetime) or more (date range) parsed query params
  112. """
  113. time_match = self.DATETIME_PATTERN.match(value)
  114. if op != 'eq' or time_match:
  115. return [{
  116. 'op': op,
  117. 'value': self.convert_value(value, field)
  118. }]
  119. else: # TODO: let times be as generic as possible (i.e. whole month, whole year)
  120. start = self.convert_value(value, field)
  121. stop = start + datetime.timedelta(days=1)
  122. return [{
  123. 'op': 'gte',
  124. 'value': start
  125. }, {
  126. 'op': 'lt',
  127. 'value': stop
  128. }]
  129. def bulk_get_values(self, value, field):
  130. """
  131. Returns list of values from query_param for IN query
  132. If url contained `/nodes/?filter[id]=12345, abcde`, the returned values would be:
  133. [u'12345', u'abcde']
  134. """
  135. value = value.lstrip('[').rstrip(']')
  136. separated_values = value.split(',')
  137. values = [self.convert_value(val.strip(), field) for val in separated_values]
  138. return values
  139. def parse_query_params(self, query_params):
  140. """Maps query params to a dict useable for filtering
  141. :param dict query_params:
  142. :return dict: of the format {
  143. <resolved_field_name>: {
  144. 'op': <comparison_operator>,
  145. 'value': <resolved_value>
  146. }
  147. }
  148. """
  149. query = {}
  150. for key, value in query_params.iteritems():
  151. match = self.QUERY_PATTERN.match(key)
  152. if match:
  153. match_dict = match.groupdict()
  154. field_name = match_dict['field'].strip()
  155. field = self._get_field_or_error(field_name)
  156. op = match_dict.get('op') or self._get_default_operator(field)
  157. self._validate_operator(field, field_name, op)
  158. if not isinstance(field, ser.SerializerMethodField):
  159. field_name = self.convert_key(field_name, field)
  160. if field_name not in query:
  161. query[field_name] = []
  162. # Special case date(time)s to allow for ambiguous date matches
  163. if isinstance(field, self.DATE_FIELDS):
  164. query[field_name].extend(self._parse_date_param(field, field_name, op, value))
  165. elif not isinstance(value, int) and field_name == '_id':
  166. query[field_name].append({
  167. 'op': 'in',
  168. 'value': self.bulk_get_values(value, field)
  169. })
  170. else:
  171. query[field_name].append({
  172. 'op': op,
  173. 'value': self.convert_value(value, field)
  174. })
  175. return query
  176. def convert_key(self, field_name, field):
  177. """Used so that that queries on fields with the source attribute set will work
  178. :param basestring field_name: text representation of the field name
  179. :param rest_framework.fields.Field field: Field instance
  180. """
  181. source = field.source
  182. if source == '*':
  183. source = getattr(field, 'filter_key', None)
  184. return source or field_name
  185. def convert_value(self, value, field):
  186. """Used to convert incoming values from query params to the appropriate types for filter comparisons
  187. :param basestring value: value to be resolved
  188. :param rest_framework.fields.Field field: Field instance
  189. """
  190. if isinstance(field, ser.BooleanField):
  191. if utils.is_truthy(value):
  192. return True
  193. elif utils.is_falsy(value):
  194. return False
  195. else:
  196. raise InvalidFilterValue(
  197. value=value,
  198. field_type='bool'
  199. )
  200. elif isinstance(field, self.DATE_FIELDS):
  201. try:
  202. return date_parser.parse(value)
  203. except ValueError:
  204. raise InvalidFilterValue(
  205. value=value,
  206. field_type='date'
  207. )
  208. elif isinstance(field, (self.LIST_FIELDS, self.RELATIONSHIP_FIELDS, ser.SerializerMethodField)) \
  209. or isinstance((getattr(field, 'field', None)), self.LIST_FIELDS):
  210. if value == 'null':
  211. value = None
  212. return value
  213. else:
  214. try:
  215. return field.to_internal_value(value)
  216. except ValidationError:
  217. raise InvalidFilterValue(
  218. value=value,
  219. )
  220. class ODMFilterMixin(FilterMixin):
  221. """View mixin that adds a get_query_from_request method which converts query params
  222. of the form `filter[field_name]=value` into an ODM Query object.
  223. Subclasses must define `get_default_odm_query()`.
  224. Serializers that want to restrict which fields are used for filtering need to have a variable called
  225. filterable_fields which is a frozenset of strings representing the field names as they appear in the serialization.
  226. """
  227. # TODO Handle simple and complex non-standard fields
  228. field_comparison_operators = {
  229. ser.CharField: 'icontains',
  230. ser.ListField: 'contains',
  231. }
  232. def __init__(self, *args, **kwargs):
  233. super(FilterMixin, self).__init__(*args, **kwargs)
  234. if not self.serializer_class:
  235. raise NotImplementedError()
  236. def get_default_odm_query(self):
  237. """Return the default MODM query for the result set.
  238. NOTE: If the client provides additional filters in query params, the filters
  239. will intersected with this query.
  240. """
  241. raise NotImplementedError('Must define get_default_odm_query')
  242. def get_query_from_request(self):
  243. if self.request.parser_context['kwargs'].get('is_embedded'):
  244. param_query = None
  245. else:
  246. param_query = self.query_params_to_odm_query(self.request.QUERY_PARAMS)
  247. default_query = self.get_default_odm_query()
  248. if param_query:
  249. query = param_query & default_query
  250. else:
  251. query = default_query
  252. return query
  253. def query_params_to_odm_query(self, query_params):
  254. """Convert query params to a modularodm Query object."""
  255. filters = self.parse_query_params(query_params)
  256. if filters:
  257. query_parts = []
  258. for field_name, params in filters.iteritems():
  259. for group in params:
  260. query = Q(field_name, group['op'], group['value'])
  261. query_parts.append(query)
  262. try:
  263. query = functools.reduce(operator.and_, query_parts)
  264. except TypeError:
  265. query = None
  266. else:
  267. query = None
  268. return query
  269. class ListFilterMixin(FilterMixin):
  270. """View mixin that adds a get_queryset_from_request method which uses query params
  271. of the form `filter[field_name]=value` to filter a list of objects.
  272. Subclasses must define `get_default_queryset()`.
  273. Serializers that want to restrict which fields are used for filtering need to have a variable called
  274. filterable_fields which is a frozenset of strings representing the field names as they appear in the serialization.
  275. """
  276. FILTERS = {
  277. 'eq': operator.eq,
  278. 'lt': operator.lt,
  279. 'lte': operator.le,
  280. 'gt': operator.gt,
  281. 'gte': operator.ge
  282. }
  283. def __init__(self, *args, **kwargs):
  284. super(FilterMixin, self).__init__(*args, **kwargs)
  285. if not self.serializer_class:
  286. raise NotImplementedError()
  287. def get_default_queryset(self):
  288. raise NotImplementedError('Must define get_default_queryset')
  289. def get_queryset_from_request(self):
  290. default_queryset = self.get_default_queryset()
  291. if not self.kwargs.get('is_embedded') and self.request.QUERY_PARAMS:
  292. param_queryset = self.param_queryset(self.request.QUERY_PARAMS, default_queryset)
  293. return param_queryset
  294. else:
  295. return default_queryset
  296. def param_queryset(self, query_params, default_queryset):
  297. """filters default queryset based on query parameters"""
  298. filters = self.parse_query_params(query_params)
  299. queryset = set(default_queryset)
  300. if filters:
  301. for field_name, params in filters.iteritems():
  302. for group in params:
  303. queryset = queryset.intersection(set(self.get_filtered_queryset(field_name, group, default_queryset)))
  304. return list(queryset)
  305. def get_filtered_queryset(self, field_name, params, default_queryset):
  306. """filters default queryset based on the serializer field type"""
  307. field = self.serializer_class._declared_fields[field_name]
  308. field_name = self.convert_key(field_name, field)
  309. if isinstance(field, ser.SerializerMethodField):
  310. return_val = [
  311. item for item in default_queryset
  312. if self.FILTERS[params['op']](self.get_serializer_method(field_name)(item), params['value'])
  313. ]
  314. elif isinstance(field, ser.CharField):
  315. return_val = [
  316. item for item in default_queryset
  317. if params['value'].lower() in getattr(item, field_name, {}).lower()
  318. ]
  319. else:
  320. return_val = [
  321. item for item in default_queryset
  322. if self.FILTERS[params['op']](getattr(item, field_name, None), params['value'])
  323. ]
  324. return return_val
  325. def get_serializer_method(self, field_name):
  326. """
  327. :param field_name: The name of a SerializerMethodField
  328. :return: The function attached to the SerializerMethodField to get its value
  329. """
  330. serializer = self.get_serializer()
  331. serializer_method_name = 'get_' + field_name
  332. return getattr(serializer, serializer_method_name)