PageRenderTime 41ms CodeModel.GetById 15ms RepoModel.GetById 1ms app.codeStats 0ms

/corehq/apps/userreports/reports/specs.py

https://github.com/dimagi/commcare-hq
Python | 607 lines | 532 code | 47 blank | 28 comment | 17 complexity | 97fc818e0f06fd7e29f18c3046e9f265 MD5 | raw file
Possible License(s): BSD-3-Clause, LGPL-2.1
  1. import json
  2. from collections import namedtuple
  3. from datetime import date
  4. from django.utils.translation import gettext as _
  5. from jsonobject.base import DefaultProperty
  6. from jsonobject.exceptions import BadValueError
  7. from memoized import memoized
  8. from sqlagg import (
  9. CountColumn,
  10. CountUniqueColumn,
  11. MaxColumn,
  12. MeanColumn,
  13. MinColumn,
  14. SumColumn,
  15. )
  16. from sqlagg.columns import (
  17. ArrayAggColumn,
  18. ConditionalAggregation,
  19. MonthColumn,
  20. NonzeroSumColumn,
  21. SimpleColumn,
  22. SumWhen,
  23. YearColumn,
  24. )
  25. from sqlalchemy import bindparam
  26. from couchforms.geopoint import GeoPoint
  27. from dimagi.ext.jsonobject import (
  28. BooleanProperty,
  29. DictProperty,
  30. IntegerProperty,
  31. JsonArray,
  32. JsonObject,
  33. ListProperty,
  34. ObjectProperty,
  35. StringProperty,
  36. )
  37. from corehq.apps.reports.datatables import DataTablesColumn
  38. from corehq.apps.reports.sqlreport import AggregateColumn, DatabaseColumn
  39. from corehq.apps.userreports import const
  40. from corehq.apps.userreports.columns import (
  41. ColumnConfig,
  42. get_expanded_column_config,
  43. )
  44. from corehq.apps.userreports.const import DEFAULT_MAXIMUM_EXPANSION
  45. from corehq.apps.userreports.exceptions import BadSpecError, InvalidQueryColumn
  46. from corehq.apps.userreports.expressions.factory import ExpressionFactory
  47. from corehq.apps.userreports.reports.sorting import ASCENDING, DESCENDING
  48. from corehq.apps.userreports.specs import TypeProperty
  49. from corehq.apps.userreports.transforms.factory import TransformFactory
  50. from corehq.apps.userreports.util import localize
  51. from corehq.toggles import UCR_SUM_WHEN_TEMPLATES
# Lookup table from the aggregation type constants in ``const`` to the sqlagg
# column class used for that aggregation. Its keys are also the permitted
# values of ``FieldColumn.aggregation``.
SQLAGG_COLUMN_MAP = {
    const.AGGGREGATION_TYPE_AVG: MeanColumn,
    const.AGGGREGATION_TYPE_COUNT_UNIQUE: CountUniqueColumn,
    const.AGGGREGATION_TYPE_COUNT: CountColumn,
    const.AGGGREGATION_TYPE_MIN: MinColumn,
    const.AGGGREGATION_TYPE_MAX: MaxColumn,
    const.AGGGREGATION_TYPE_MONTH: MonthColumn,
    const.AGGGREGATION_TYPE_SUM: SumColumn,
    const.AGGGREGATION_TYPE_SIMPLE: SimpleColumn,
    const.AGGGREGATION_TYPE_YEAR: YearColumn,
    const.AGGGREGATION_TYPE_NONZERO_SUM: NonzeroSumColumn,
}
  64. class BaseReportColumn(JsonObject):
  65. type = StringProperty(required=True)
  66. column_id = StringProperty(required=True)
  67. display = DefaultProperty()
  68. description = StringProperty()
  69. visible = BooleanProperty(default=True)
  70. @classmethod
  71. def restricted_to_static(cls, domain):
  72. return False
  73. @classmethod
  74. def wrap(cls, obj):
  75. if 'display' not in obj and 'column_id' in obj:
  76. obj['display'] = obj['column_id']
  77. return super(BaseReportColumn, cls).wrap(obj)
  78. def get_header(self, lang):
  79. return localize(self.display, lang)
  80. def get_column_ids(self):
  81. """
  82. Used as an abstraction layer for columns that can contain more than one data column
  83. (for example, PercentageColumns).
  84. """
  85. return [self.column_id]
  86. def get_column_config(self, data_source_config, lang):
  87. raise NotImplementedError('subclasses must override this')
  88. def get_fields(self, data_source_config=None, lang=None):
  89. """
  90. Get database fields associated with this column. Could be one, or more
  91. if a column is a function of two values in the DB (e.g. PercentageColumn)
  92. """
  93. raise NotImplementedError('subclasses must override this')
  94. class ReportColumn(BaseReportColumn):
  95. transform = DictProperty()
  96. calculate_total = BooleanProperty(default=False)
  97. def format_data(self, data):
  98. """
  99. Subclasses can apply formatting to the entire dataset.
  100. """
  101. pass
  102. def get_format_fn(self):
  103. """
  104. A function that gets applied to the data just in time before the report is rendered.
  105. """
  106. if self.transform:
  107. return TransformFactory.get_transform(self.transform).get_transform_function()
  108. return None
  109. def get_query_column_ids(self):
  110. """
  111. Gets column IDs associated with a query. These could be different from
  112. the normal column_ids if the same column ends up in multiple columns in
  113. the query (e.g. an aggregate date splitting into year and month)
  114. """
  115. raise InvalidQueryColumn(_("You can't query on columns of type {}".format(self.type)))
  116. class FieldColumn(ReportColumn):
  117. type = TypeProperty('field')
  118. field = StringProperty(required=True)
  119. aggregation = StringProperty(
  120. choices=list(SQLAGG_COLUMN_MAP),
  121. required=True,
  122. )
  123. format = StringProperty(default='default', choices=[
  124. 'default',
  125. 'percent_of_total',
  126. ])
  127. sortable = BooleanProperty(default=False)
  128. width = StringProperty(default=None, required=False)
  129. css_class = StringProperty(default=None, required=False)
  130. @classmethod
  131. def wrap(cls, obj):
  132. # lazy migrations for legacy data.
  133. # todo: remove once all reports are on new format
  134. # 1. set column_id to alias, or field if no alias found
  135. _add_column_id_if_missing(obj)
  136. # 2. if aggregation='expand' convert to ExpandedColumn
  137. if obj.get('aggregation') == 'expand':
  138. del obj['aggregation']
  139. obj['type'] = 'expanded'
  140. return ExpandedColumn.wrap(obj)
  141. return super(FieldColumn, cls).wrap(obj)
  142. def format_data(self, data):
  143. if self.format == 'percent_of_total':
  144. column_name = self.column_id
  145. total = sum(row[column_name] for row in data)
  146. for row in data:
  147. row[column_name] = '{:.0%}'.format(
  148. row[column_name] / total
  149. )
  150. def get_column_config(self, data_source_config, lang):
  151. return ColumnConfig(columns=[
  152. DatabaseColumn(
  153. header=self.get_header(lang),
  154. agg_column=SQLAGG_COLUMN_MAP[self.aggregation](self.field, alias=self.column_id),
  155. sortable=self.sortable,
  156. data_slug=self.column_id,
  157. format_fn=self.get_format_fn(),
  158. help_text=self.description,
  159. visible=self.visible,
  160. width=self.width,
  161. css_class=self.css_class,
  162. )
  163. ])
  164. def get_fields(self, data_source_config=None, lang=None):
  165. return [self.field]
  166. def _data_source_col_config(self, data_source_config):
  167. return filter(
  168. lambda c: c['column_id'] == self.field, data_source_config.configured_indicators
  169. )[0]
  170. def _column_data_type(self, data_source_config):
  171. return self._data_source_col_config(data_source_config).get('datatype')
  172. def _use_terms_aggregation_for_max_min(self, data_source_config):
  173. return (
  174. self.aggregation in ['max', 'min'] and
  175. self._column_data_type(data_source_config) and
  176. self._column_data_type(data_source_config) not in ['integer', 'decimal']
  177. )
  178. def get_query_column_ids(self):
  179. return [self.column_id]
  180. class LocationColumn(ReportColumn):
  181. type = TypeProperty('location')
  182. field = StringProperty(required=True)
  183. sortable = BooleanProperty(default=False)
  184. def format_data(self, data):
  185. column_name = self.column_id
  186. for row in data:
  187. try:
  188. row[column_name] = '{g.latitude} {g.longitude} {g.altitude} {g.accuracy}'.format(
  189. g=GeoPoint.from_string(row[column_name])
  190. )
  191. except BadValueError:
  192. row[column_name] = '{} ({})'.format(row[column_name], _('Invalid Location'))
  193. def get_column_config(self, data_source_config, lang):
  194. return ColumnConfig(columns=[
  195. DatabaseColumn(
  196. header=self.get_header(lang),
  197. agg_column=SimpleColumn(self.field, alias=self.column_id),
  198. sortable=self.sortable,
  199. data_slug=self.column_id,
  200. format_fn=self.get_format_fn(),
  201. help_text=self.description
  202. )
  203. ])
  204. class ExpandedColumn(ReportColumn):
  205. type = TypeProperty('expanded')
  206. field = StringProperty(required=True)
  207. max_expansion = IntegerProperty(default=DEFAULT_MAXIMUM_EXPANSION)
  208. @classmethod
  209. def wrap(cls, obj):
  210. # lazy migrations for legacy data.
  211. # todo: remove once all reports are on new format
  212. _add_column_id_if_missing(obj)
  213. return super(ExpandedColumn, cls).wrap(obj)
  214. def get_column_config(self, data_source_config, lang):
  215. return get_expanded_column_config(data_source_config, self, lang)
  216. def get_fields(self, data_source_config, lang):
  217. return [self.field] + [
  218. c.aggregation.name for c in self.get_column_config(data_source_config, lang).columns
  219. ]
  220. class AggregateDateColumn(ReportColumn):
  221. """
  222. Used for grouping months and years together.
  223. """
  224. type = TypeProperty('aggregate_date')
  225. field = StringProperty(required=True)
  226. format = StringProperty(required=False)
  227. def get_column_config(self, data_source_config, lang):
  228. return ColumnConfig(columns=[
  229. AggregateColumn(
  230. header=self.get_header(lang),
  231. aggregate_fn=lambda year, month: {'year': year, 'month': month},
  232. format_fn=self.get_format_fn(),
  233. columns=[
  234. YearColumn(self.field, alias=self._year_column_alias()),
  235. MonthColumn(self.field, alias=self._month_column_alias()),
  236. ],
  237. slug=self.column_id,
  238. data_slug=self.column_id,
  239. )],
  240. )
  241. def _year_column_alias(self):
  242. return '{}_year'.format(self.column_id)
  243. def _month_column_alias(self):
  244. return '{}_month'.format(self.column_id)
  245. def get_format_fn(self):
  246. def _format(data):
  247. if not data.get('year', None) or not data.get('month', None):
  248. return _('Unknown Date')
  249. format_ = self.format or '%Y-%m'
  250. return date(year=int(data['year']), month=int(data['month']), day=1).strftime(format_)
  251. return _format
  252. def get_query_column_ids(self):
  253. return [self._year_column_alias(), self._month_column_alias()]
  254. class _CaseExpressionColumn(ReportColumn):
  255. """ Wraps a SQLAlchemy "case" expression:
  256. http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.case
  257. """
  258. type = None
  259. whens = ListProperty(ListProperty) # List of (expression, bind1, bind2, ... value) tuples
  260. else_ = StringProperty()
  261. sortable = BooleanProperty(default=False)
  262. _agg_column_type = None
  263. def get_column_config(self, data_source_config, lang):
  264. if not self.type and self._agg_column_type:
  265. raise NotImplementedError("subclasses must define a type and column_type")
  266. return ColumnConfig(columns=[
  267. DatabaseColumn(
  268. header=self.get_header(lang),
  269. agg_column=self._agg_column_type(
  270. whens=self.get_whens(),
  271. else_=self.else_,
  272. alias=self.column_id,
  273. ),
  274. sortable=self.sortable,
  275. data_slug=self.column_id,
  276. format_fn=self.get_format_fn(),
  277. help_text=self.description,
  278. visible=self.visible,
  279. )],
  280. )
  281. def get_whens(self):
  282. return self.whens
  283. def get_query_column_ids(self):
  284. return [self.column_id]
  285. class IntegerBucketsColumn(_CaseExpressionColumn):
  286. """Used for grouping by SQL conditionals"""
  287. type = TypeProperty('integer_buckets')
  288. _agg_column_type = ConditionalAggregation
  289. field = StringProperty(required=True)
  290. ranges = DictProperty()
  291. def get_whens(self):
  292. whens = []
  293. for value, bounds in self.ranges.items():
  294. if len(bounds) != 2:
  295. raise BadSpecError('Range must contain 2 items, contains {}'.format(len(bounds)))
  296. try:
  297. bounds = [int(b) for b in bounds]
  298. except ValueError:
  299. raise BadSpecError('Invalid range: [{}, {}]'.format(bounds[0], bounds[1]))
  300. whens.append([self._base_expression(bounds), bindparam(None, value)])
  301. return whens
  302. def _base_expression(self, bounds):
  303. return "{} between {} and {}".format(self.field, bounds[0], bounds[1])
  304. class SumWhenColumn(_CaseExpressionColumn):
  305. type = TypeProperty("sum_when")
  306. else_ = IntegerProperty(default=0)
  307. _agg_column_type = SumWhen
  308. @classmethod
  309. def restricted_to_static(cls, domain):
  310. # The conditional expressions used here don't have sufficient safety checks,
  311. # so this column type is only available for static reports. To release this,
  312. # we should require that conditions be expressed using a PreFilterValue type
  313. # syntax, as attempted in commit 02833e28b7aaf5e0a71741244841ad9910ffb1e5
  314. return True
  315. class SumWhenTemplateColumn(SumWhenColumn):
  316. type = TypeProperty("sum_when_template")
  317. whens = ListProperty(DictProperty) # List of SumWhenTemplateSpec dicts
  318. @classmethod
  319. def restricted_to_static(cls, domain):
  320. return not UCR_SUM_WHEN_TEMPLATES.enabled(domain)
  321. def get_whens(self):
  322. from corehq.apps.userreports.reports.factory import SumWhenTemplateFactory
  323. whens = []
  324. for spec in self.whens:
  325. template = SumWhenTemplateFactory.make_template(spec)
  326. whens.append([template.expression] + template.binds + [template.then])
  327. return whens
  328. class PercentageColumn(ReportColumn):
  329. type = TypeProperty('percent')
  330. numerator = ObjectProperty(FieldColumn, required=True)
  331. denominator = ObjectProperty(FieldColumn, required=True)
  332. format = StringProperty(
  333. choices=['percent', 'fraction', 'both', 'numeric_percent', 'decimal'],
  334. default='percent'
  335. )
  336. def get_column_config(self, data_source_config, lang):
  337. # todo: better checks that fields are not expand
  338. num_config = self.numerator.get_column_config(data_source_config, lang)
  339. denom_config = self.denominator.get_column_config(data_source_config, lang)
  340. return ColumnConfig(columns=[
  341. AggregateColumn(
  342. header=self.get_header(lang),
  343. aggregate_fn=lambda n, d: {'num': n, 'denom': d},
  344. format_fn=self.get_format_fn(),
  345. columns=[c.view for c in num_config.columns + denom_config.columns],
  346. slug=self.column_id,
  347. data_slug=self.column_id,
  348. )],
  349. warnings=num_config.warnings + denom_config.warnings,
  350. )
  351. def get_format_fn(self):
  352. NO_DATA_TEXT = '--'
  353. CANT_CALCULATE_TEXT = '?'
  354. class NoData(Exception):
  355. pass
  356. class BadData(Exception):
  357. pass
  358. def trap_errors(fn):
  359. def inner(*args, **kwargs):
  360. try:
  361. return fn(*args, **kwargs)
  362. except BadData:
  363. return CANT_CALCULATE_TEXT
  364. except NoData:
  365. return NO_DATA_TEXT
  366. return inner
  367. def _raw(data):
  368. if data['denom']:
  369. try:
  370. return float(round(data['num'] / data['denom'], 3))
  371. except (ValueError, TypeError):
  372. raise BadData()
  373. else:
  374. raise NoData()
  375. def _raw_pct(data, round_type=float):
  376. return round_type(_raw(data) * 100)
  377. @trap_errors
  378. def _clean_raw(data):
  379. return _raw(data)
  380. @trap_errors
  381. def _numeric_pct(data):
  382. return _raw_pct(data, round_type=int)
  383. @trap_errors
  384. def _pct(data):
  385. return '{0:.0f}%'.format(_raw_pct(data))
  386. _fraction = lambda data: '{num}/{denom}'.format(**data)
  387. return {
  388. 'percent': _pct,
  389. 'fraction': _fraction,
  390. 'both': lambda data: '{} ({})'.format(_pct(data), _fraction(data)),
  391. 'numeric_percent': _numeric_pct,
  392. 'decimal': _clean_raw,
  393. }[self.format]
  394. def get_column_ids(self):
  395. # override this to include the columns for the numerator and denominator as well
  396. return [self.column_id, self.numerator.column_id, self.denominator.column_id]
  397. def get_fields(self, data_source_config=None, lang=None):
  398. return self.numerator.get_fields() + self.denominator.get_fields()
  399. class ArrayAggLastValueReportColumn(ReportColumn):
  400. type = TypeProperty('array_agg_last_value')
  401. field = StringProperty(required=True)
  402. order_by_col = StringProperty(required=False)
  403. _agg_column_type = ArrayAggColumn
  404. def get_column_config(self, data_source_config, lang):
  405. def _last_value(array):
  406. return array[-1] if array else None
  407. return ColumnConfig(columns=[
  408. DatabaseColumn(
  409. header=self.get_header(lang),
  410. agg_column=self._agg_column_type(
  411. key=self.field,
  412. order_by_col=self.order_by_col,
  413. alias=self.column_id,
  414. ),
  415. format_fn=_last_value,
  416. data_slug=self.column_id,
  417. help_text=self.description,
  418. visible=self.visible,
  419. sortable=False,
  420. )
  421. ])
  422. def _add_column_id_if_missing(obj):
  423. if obj.get('column_id') is None:
  424. obj['column_id'] = obj.get('alias') or obj['field']
  425. class CalculatedColumn(namedtuple('CalculatedColumn', ['header', 'slug', 'visible', 'help_text'])):
  426. @property
  427. def data_tables_column(self):
  428. return DataTablesColumn(self.header, sortable=False, data_slug=self.slug,
  429. visible=self.visible, help_text=self.help_text)
  430. class ExpressionColumn(BaseReportColumn):
  431. expression = DefaultProperty(required=True)
  432. @property
  433. def calculate_total(self):
  434. """Calculating total not supported"""
  435. # Using a function property so that it can't be overridden during wrapping
  436. return False
  437. @property
  438. @memoized
  439. def wrapped_expression(self):
  440. return ExpressionFactory.from_spec(self.expression)
  441. def get_column_config(self, data_source_config, lang):
  442. return ColumnConfig(columns=[
  443. CalculatedColumn(
  444. header=self.get_header(lang),
  445. slug=self.column_id,
  446. visible=self.visible,
  447. # todo: are these needed?
  448. # format_fn=self.get_format_fn(),
  449. help_text=self.description
  450. )
  451. ])
  452. def get_query_column_ids(self):
  453. raise InvalidQueryColumn(_("Expression Columns do not support group by, sorting, or querying."))
  454. class ChartSpec(JsonObject):
  455. type = StringProperty(required=True)
  456. title = StringProperty()
  457. chart_id = StringProperty()
  458. @classmethod
  459. def wrap(cls, obj):
  460. if obj.get('chart_id') is None:
  461. obj['chart_id'] = (obj.get('title') or '') + str(hash(json.dumps(sorted(obj.items()))))
  462. return super(ChartSpec, cls).wrap(obj)
class PieChartSpec(ChartSpec):
    """Chart spec of type 'pie'."""
    type = TypeProperty('pie')
    # Optional column name; presumably the column slices are grouped by -- TODO confirm
    aggregation_column = StringProperty()
    # Required column name; presumably supplies each slice's value -- TODO confirm
    value_column = StringProperty(required=True)
  467. class GraphDisplayColumn(JsonObject):
  468. column_id = StringProperty(required=True)
  469. display = StringProperty(required=True)
  470. @classmethod
  471. def wrap(cls, obj):
  472. # automap column_id to display if display isn't set
  473. if isinstance(obj, dict) and 'column_id' in obj and 'display' not in obj:
  474. obj['display'] = obj['column_id']
  475. return super(GraphDisplayColumn, cls).wrap(obj)
  476. class MultibarChartSpec(ChartSpec):
  477. type = TypeProperty('multibar')
  478. aggregation_column = StringProperty()
  479. x_axis_column = StringProperty(required=True)
  480. y_axis_columns = ListProperty(GraphDisplayColumn)
  481. is_stacked = BooleanProperty(default=False)
  482. @classmethod
  483. def wrap(cls, obj):
  484. def _convert_columns_to_properly_dicts(cols):
  485. for column in cols:
  486. if isinstance(column, str):
  487. yield {'column_id': column, 'display': column}
  488. else:
  489. yield column
  490. obj['y_axis_columns'] = list(_convert_columns_to_properly_dicts(obj.get('y_axis_columns', [])))
  491. return super(MultibarChartSpec, cls).wrap(obj)
class MultibarAggregateChartSpec(ChartSpec):
    """Chart spec of type 'multibar-aggregate'; all three column settings are required."""
    type = TypeProperty('multibar-aggregate')
    # Presumably the two columns the chart is grouped by -- TODO confirm
    primary_aggregation = StringProperty(required=True)
    secondary_aggregation = StringProperty(required=True)
    # Presumably the column supplying the bar values -- TODO confirm
    value_column = StringProperty(required=True)
class OrderBySpec(JsonObject):
    """A field name together with a sort direction."""
    field = StringProperty()
    # Must be ASCENDING or DESCENDING; ASCENDING when not given.
    order = StringProperty(choices=[ASCENDING, DESCENDING], default=ASCENDING)