PageRenderTime 54ms CodeModel.GetById 24ms RepoModel.GetById 0ms app.codeStats 0ms

/py/vtdb/sql_builder.py

https://gitlab.com/18runt88/vitess
Python | 686 lines | 496 code | 142 blank | 48 comment | 48 complexity | 708e30801bdc2571d8c7936d5db0ba8a MD5 | raw file
"""Helper classes and functions for building SQL queries."""
  4. import itertools
  5. import pprint
# TODO: Add unit tests for the methods and classes.
# TODO: Consider integration with SQLAlchemy.
  8. class DBRow(object):
  9. def __init__(self, column_names, row_tuple, **overrides):
  10. self.__dict__ = dict(zip(column_names, row_tuple), **overrides)
  11. def __repr__(self):
  12. return pprint.pformat(self.__dict__, 4)
  13. def select_clause(select_columns, table_name, alias=None, cols=None, order_by_cols=None):
  14. """Build the select clause for a query."""
  15. if alias:
  16. return 'SELECT %s FROM %s %s' % (
  17. colstr(select_columns, alias, cols, order_by_cols=order_by_cols),
  18. table_name, alias)
  19. return 'SELECT %s FROM %s' % (
  20. colstr(select_columns, alias, cols, order_by_cols=order_by_cols),
  21. table_name)
  22. def colstr(select_columns, alias=None, cols=None, bind=None, order_by_cols=None):
  23. if not cols:
  24. cols = select_columns
  25. # in the case of a scatter/gather, prepend these columns to facilitate an in-code
  26. # sort - after that, we can just strip these off and process normally
  27. if order_by_cols:
  28. # avoid altering a class variable
  29. cols = cols[:]
  30. for order_col in reversed(order_by_cols):
  31. if type(order_col) in (tuple, list):
  32. cols.insert(0, order_col[0])
  33. else:
  34. cols.insert(0, order_col)
  35. if not bind:
  36. bind = cols
  37. def prefix(col):
  38. if isinstance(col, SQLAggregate):
  39. return col.sql()
  40. if alias and '.' not in col:
  41. col = '%s.%s' % (alias, col)
  42. return col
  43. return ', '.join([prefix(c) for c in cols if c in bind])
  44. def build_values_clause(columns, bind_values):
  45. """Builds values clause for an insert query."""
  46. clause_parts = []
  47. bind_list = []
  48. for column in columns:
  49. if (column in ('time_created', 'time_updated') and
  50. column not in bind_values):
  51. bind_list.append(column)
  52. clause_parts.append('%%(%s)s' % column)
  53. bind_values[column] = int(time.time())
  54. elif column in bind_values:
  55. bind_list.append(column)
  56. if type(bind_values[column]) == MySQLFunction:
  57. clause_parts.append(bind_values[column])
  58. bind_values.update(column.bind_vals)
  59. else:
  60. clause_parts.append('%%(%s)s' % column)
  61. return ', '.join(clause_parts), bind_list
  62. def build_in(column, items, alt_name=None, counter=None):
  63. """Build SQL IN statement and bind hash for use with pyformat."""
  64. if not items:
  65. raise ValueError('Called with empty "items"')
  66. base = alt_name if alt_name else column
  67. bind_list = make_bind_list(base, items, counter=counter)
  68. return ('%s IN (%s)' % (column,
  69. str.join(',', ['%(' + pair[0] + ')s'
  70. for pair in bind_list])),
  71. dict(bind_list))
  72. def build_order_clause(order_by):
  73. """order_by could be a list, tuple or string."""
  74. if not order_by:
  75. return ''
  76. if type(order_by) not in (tuple, list):
  77. order_by = (order_by,)
  78. subclause_list = []
  79. for subclause in order_by:
  80. if type(subclause) in (tuple, list):
  81. subclause = ' '.join(subclause)
  82. subclause_list.append(subclause)
  83. return 'ORDER BY %s' % ', '.join(subclause_list)
  84. def build_group_clause(group_by):
  85. """Build group_by clause for a query."""
  86. if not group_by:
  87. return ''
  88. if type(group_by) not in (tuple, list):
  89. group_by = (group_by,)
  90. return 'GROUP BY %s' % ', '.join(group_by)
  91. def build_limit_clause(limit):
  92. """Build limit clause for a query."""
  93. if not limit:
  94. return '', {}
  95. if not isinstance(limit, tuple):
  96. limit = (limit,)
  97. bind_vars = {'limit_row_count': limit[0]}
  98. if len(limit) == 1:
  99. return 'LIMIT %(limit_row_count)s', bind_vars
  100. bind_vars = {'limit_offset': limit[0],
  101. 'limit_row_count': limit[1]}
  102. return 'LIMIT %(limit_offset)s,%(limit_row_count)s', bind_vars
  103. def build_where_clause(column_value_pairs):
  104. """Build the where clause for a query."""
  105. condition_list = []
  106. bind_vars = {}
  107. counter = itertools.count(1)
  108. def update_bindvars(newvars):
  109. for k, v in newvars.iteritems():
  110. if k in bind_vars:
  111. raise ValueError('Duplicate bind vars: cannot add %r to %r',
  112. newvars, bind_vars)
  113. bind_vars[k] = v
  114. for column, value in column_value_pairs:
  115. if isinstance(value, (Flag, SQLOperator, NullSafeNotValue)):
  116. clause, clause_bind_vars = value.build_sql(column, counter=counter)
  117. update_bindvars(clause_bind_vars)
  118. condition_list.append(clause)
  119. elif isinstance(value, (tuple, list, set)):
  120. if value:
  121. in_clause, in_bind_variables = build_in(column, value,
  122. counter=counter)
  123. update_bindvars(in_bind_variables)
  124. condition_list.append(in_clause)
  125. else:
  126. condition_list.append('1 = 0')
  127. else:
  128. bind_name = choose_bind_name(column, counter=counter)
  129. update_bindvars({bind_name: value})
  130. condition_list.append('%s = %%(%s)s' % (column, bind_name))
  131. if not bind_vars:
  132. bind_vars = dict(column_value_pairs)
  133. where_clause = ' AND '.join(condition_list)
  134. return where_clause, bind_vars
  135. def select_by_columns_query(select_column_list, table_name, column_value_pairs=None,
  136. order_by=None, group_by=None, limit=None,
  137. for_update=False,client_aggregate=False,
  138. vt_routing_info=None):
  139. if client_aggregate:
  140. clause_list = [select_clause(select_column_list, table_name,
  141. order_by_cols=order_by)]
  142. else:
  143. clause_list = [select_clause(select_column_list, table_name)]
  144. # generate WHERE clause and bind variables
  145. if column_value_pairs:
  146. where_clause, bind_vars = build_where_clause(column_value_pairs)
  147. # add vt routing info
  148. if vt_routing_info:
  149. where_clause, bind_vars = vt_routing_info.update_where_clause(
  150. where_clause, bind_vars)
  151. clause_list += ['WHERE', where_clause]
  152. else:
  153. bind_vars = {}
  154. if group_by:
  155. clause_list.append(build_group_clause(group_by))
  156. if order_by:
  157. clause_list.append(build_order_clause(order_by))
  158. if limit:
  159. clause, limit_bind_vars = build_limit_clause(limit)
  160. clause_list.append(clause)
  161. bind_vars.update(limit_bind_vars)
  162. if for_update:
  163. clause_list.append('FOR UPDATE')
  164. query = ' '.join(clause_list)
  165. return query, bind_vars
  166. def update_columns_query(table_name, where_column_value_pairs=None,
  167. update_column_value_pairs=None, limit=None,
  168. order_by=None):
  169. if not update_column_value_pairs:
  170. raise dbexceptions.ProgrammingError("No update values specified.")
  171. clause_list = []
  172. bind_vals = {}
  173. for i, (column, value) in enumerate(update_column_value_pairs):
  174. if isinstance(value, (Flag, Increment, MySQLFunction)):
  175. clause, clause_bind_vals = value.build_update_sql(column)
  176. clause_list.append(clause)
  177. bind_vals.update(clause_bind_vals)
  178. else:
  179. clause_list.append('%s = %%(update_set_%s)s' % (column, i))
  180. bind_vals['update_set_%s' % i] = value
  181. if not clause_list:
  182. # this would be invalid syntax anyway, let's raise a nicer exception
  183. raise ValueError(
  184. 'Expected nonempty update_column_value_pairs. Got: %r'
  185. % update_column_value_pairs)
  186. set_clause = ', '.join(clause_list)
  187. if not where_column_value_pairs:
  188. # same as above. We could allow for no where clause,
  189. # but this is a notoriously error-prone construct, so, no.
  190. raise ValueError(
  191. 'Expected nonempty where_column_value_pairs. Got: %r'
  192. % where_column_value_pairs)
  193. where_clause, where_bind_vals = build_where_clause(where_column_value_pairs)
  194. bind_vals.update(where_bind_vals)
  195. query = ('UPDATE %(table)s SET %(set_clause)s WHERE %(where_clause)s'
  196. % {'table': table_name, 'set_clause': set_clause,
  197. 'where_clause': where_clause})
  198. additional_clause = []
  199. if order_by:
  200. additional_clause.append(build_order_clause(order_by))
  201. if limit:
  202. limit_clause, limit_bind_vars = build_limit_clause(limit)
  203. additional_clause.append(limit_clause)
  204. bind_vals.update(limit_bind_vars)
  205. query += ' ' + ' '.join(additional_clause)
  206. return query, bind_vals
  207. def delete_by_columns_query(table_name, where_column_value_pairs=None,
  208. limit=None):
  209. where_clause, bind_vars = build_where_clause(where_column_value_pairs)
  210. limit_clause, limit_bind_vars = build_limit_clause(limit)
  211. bind_vars.update(limit_bind_vars)
  212. query = (
  213. 'DELETE FROM %(table_name)s WHERE %(where_clause)s %(limit_clause)s' %
  214. {'table_name': table_name, 'where_clause': where_clause,
  215. 'limit_clause': limit_clause})
  216. return query, bind_vars
  217. def insert_query(table_name, columns_list, **bind_variables):
  218. values_clause, bind_list = build_values_clause(columns_list,
  219. bind_variables)
  220. query = 'INSERT INTO %s (%s) VALUES (%s)' % (table_name,
  221. colstr(columns_list,
  222. bind=bind_list),
  223. values_clause)
  224. return query, bind_variables
  225. def build_aggregate_query(table_name, id_column_name, sort_func='min'):
  226. query_clause = 'SELECT %(id_col)s FROM %(table_name)s ORDER BY %(id_col)s'
  227. if sort_func == 'max':
  228. query_clause += ' DESC'
  229. query_clause += ' LIMIT 1'
  230. query = query_clause % {'id_col': id_column_name, 'table_name': table_name}
  231. return query
  232. def build_count_query(table_name, column_value_pairs):
  233. where_clause, bind_vars = build_where_clause(column_value_pairs)
  234. query = 'SELECT count(1) FROM %s WHERE %s' % (table_name, where_clause)
  235. return query, bind_vars
  236. def choose_bind_name(base, counter=None):
  237. if counter:
  238. base += '_%d' % counter.next()
  239. return base
  240. def make_bind_list(column, values, counter=None):
  241. result = []
  242. bind_names = []
  243. if counter is None:
  244. counter = itertools.count(1)
  245. for value in values:
  246. bind_name = choose_bind_name(column, counter=counter)
  247. bind_names.append(bind_name)
  248. result.append((bind_name, value))
  249. return result
  250. class MySQLFunction(object):
  251. def __init__(self, func, bind_vals=()):
  252. self.bind_vals = bind_vals
  253. self.func = func
  254. def __str__(self):
  255. return self.func
  256. def build_update_sql(self, column):
  257. clause = '%s = %s' % (column, self.func)
  258. return clause, self.bind_vals
  259. class SQLAggregate(object):
  260. def __init__(self, function_name, column_name):
  261. self.function_name = function_name
  262. self.column_name = column_name
  263. def sql(self):
  264. clause = '%(function_name)s(%(column_name)s)' % vars(self)
  265. return clause
  266. def Sum(column_name):
  267. return SQLAggregate('SUM', column_name)
  268. def Max(column_name):
  269. return SQLAggregate('MAX', column_name)
  270. def Min(column_name):
  271. return SQLAggregate('MIN', column_name)
  272. # A null-safe inequality operator. For any [column] and [value] we do
  273. # "NOT [column] <=> [value]".
  274. #
  275. # This is a bit of a hack because our framework assumes all operators are
  276. # binary in nature (whereas we need a combination of unary and binary
  277. # operators).
  278. #
  279. # This is only enabled for use in the where clause. For use in select or
  280. # update you'll need to do some additional work.
  281. class NullSafeNotValue(object):
  282. def __init__(self, value):
  283. self.value = value
  284. def build_sql(self, column_name, counter=None):
  285. bind_name = choose_bind_name(column_name, counter=counter)
  286. clause = 'NOT %(column_name)s <=> %%(%(bind_name)s)s' % vars()
  287. bind_vars = {bind_name: self.value}
  288. return clause, bind_vars
  289. class SQLOperator(object):
  290. """Base class for a column expression in a SQL WHERE clause."""
  291. def __init__(self, value, op):
  292. """Constructor.
  293. Args:
  294. value: The value against which to compare the column, or an iterable of
  295. values if appropriate for the operator.
  296. op: The operator to use for comparison.
  297. """
  298. self.value = value
  299. self.op = op
  300. def build_sql(self, column_name, counter=None):
  301. """Render this expression as a SQL string.
  302. Args:
  303. column_name: Name of the column being tested in this expression.
  304. counter: Instance of itertools.count supplying numeric suffixes for
  305. disambiguating bind_names, or None. (See choose_bind_name
  306. for a discussion.)
  307. Returns:
  308. clause: The SQL expression, including a placeholder for the value.
  309. bind_vars: Dict mapping placeholder names to actual values.
  310. """
  311. op = self.op
  312. bind_name = choose_bind_name(column_name, counter=counter)
  313. clause = '%(column_name)s %(op)s %%(%(bind_name)s)s' % vars()
  314. bind_vars = {bind_name: self.value}
  315. return clause, bind_vars
  316. class NotValue(SQLOperator):
  317. def __init__(self, value):
  318. super(NotValue, self).__init__(value, '!=')
  319. def build_sql(self, column_name, counter=None):
  320. if self.value is None:
  321. return '%s IS NOT NULL' % column_name, {}
  322. return super(NotValue, self).build_sql(column_name, counter=counter)
  323. class InValuesOperatorBase(SQLOperator):
  324. def __init__(self, op, *values):
  325. super(InValuesOperatorBase, self).__init__(values, op)
  326. def build_sql(self, column_name, counter=None):
  327. op = self.op
  328. bind_list = make_bind_list(column_name, self.value, counter=counter)
  329. in_clause = ', '.join(('%(' + key + ')s') for key, val in bind_list)
  330. clause = '%(column_name)s %(op)s (%(in_clause)s)' % vars()
  331. return clause, dict(bind_list)
  332. # You rarely need to use InValues directly in your database classes.
  333. # List and tuples are handled automatically by most database helper methods.
  334. class InValues(InValuesOperatorBase):
  335. def __init__(self, *values):
  336. super(InValues, self).__init__('IN', *values)
  337. class NotInValues(InValuesOperatorBase):
  338. def __init__(self, *values):
  339. super(NotInValues, self).__init__('NOT IN', *values)
  340. class InValuesOrNull(InValues):
  341. def build_sql(self, column_name, counter=None):
  342. clause, bind_vars = super(InValuesOrNull, self).build_sql(column_name,
  343. counter=counter)
  344. clause = '(%s OR %s IS NULL)' % (clause, column_name)
  345. return clause, bind_vars
  346. class BetweenValues(SQLOperator):
  347. def __init__(self, value0, value1):
  348. if value0 < value1:
  349. super(BetweenValues, self).__init__((value0, value1), 'BETWEEN')
  350. else:
  351. super(BetweenValues, self).__init__((value1, value0), 'BETWEEN')
  352. def build_sql(self, column_name, counter=None):
  353. op = self.op
  354. bind_list = make_bind_list(column_name, self.value, counter=counter)
  355. between_clause = ' AND '.join(('%(' + key + ')s') for key, val in bind_list)
  356. clause = '%(column_name)s %(op)s %(between_clause)s' % vars()
  357. return clause, dict(bind_list)
  358. class OrValues(SQLOperator):
  359. def __init__(self, *values):
  360. if not values or len(values) == 1:
  361. raise errors.IllegalArgumentException
  362. super(OrValues, self).__init__(values, 'OR')
  363. def build_sql(self, column_name, counter=None):
  364. condition_list = []
  365. bind_vars = {}
  366. if counter is None:
  367. counter = itertools.count(1)
  368. for v in self.value:
  369. if isinstance(v, (SQLOperator, Flag, NullSafeNotValue)):
  370. clause, clause_bind_vars = v.build_sql(column_name, counter=counter)
  371. bind_vars.update(clause_bind_vars)
  372. condition_list.append(clause)
  373. else:
  374. bind_name = choose_bind_name(column_name, counter=counter)
  375. bind_vars[bind_name] = v
  376. condition_list.append('%s = %%(%s)s' % (column_name, bind_name))
  377. or_clause = '((' + ') OR ('.join(condition_list) + '))'
  378. return or_clause, bind_vars
  379. class LikeValue(SQLOperator):
  380. def __init__(self, value):
  381. super(LikeValue, self).__init__(value, 'LIKE')
  382. class GreaterThanValue(SQLOperator):
  383. def __init__(self, value):
  384. super(GreaterThanValue, self).__init__(value, '>')
  385. class GreaterThanOrEqualToValue(SQLOperator):
  386. def __init__(self, value):
  387. super(GreaterThanOrEqualToValue, self).__init__(value, '>=')
  388. class LessThanValue(SQLOperator):
  389. def __init__(self, value):
  390. super(LessThanValue, self).__init__(value, '<')
  391. class LessThanOrEqualToValue(SQLOperator):
  392. def __init__(self, value):
  393. super(LessThanOrEqualToValue, self).__init__(value, '<=')
  394. class ModuloEquals(SQLOperator):
  395. """column % modulus = value."""
  396. def __init__(self, modulus, value):
  397. super(ModuloEquals, self).__init__(value, '%')
  398. self.modulus = modulus
  399. def build_sql(self, column, counter=None):
  400. mod_bind_name = choose_bind_name('modulus', counter=counter)
  401. val_bind_name = choose_bind_name(column, counter=counter)
  402. sql = '(%(column)s %%%% %%(%(mod_bind_name)s)s) = %%(%(val_bind_name)s)s'
  403. return (sql % {'column': column,
  404. 'mod_bind_name': mod_bind_name,
  405. 'val_bind_name': val_bind_name},
  406. {mod_bind_name: self.modulus,
  407. val_bind_name: self.value})
  408. class Expression(SQLOperator):
  409. def build_sql(self, column_name, counter=None):
  410. op = self.op
  411. value = str(self.value)
  412. clause = '%(column_name)s %(op)s %(value)s' % vars()
  413. return clause, {}
  414. class IsNullOrEmptyString(SQLOperator):
  415. def __init__(self):
  416. super(IsNullOrEmptyString, self).__init__('', '')
  417. def build_sql(self, column_name, counter=None):
  418. # mysql treats '' the same as ' '
  419. return "(%s IS NULL OR %s = '')" % (column_name, column_name), {}
  420. class IsNullValue(SQLOperator):
  421. def __init__(self):
  422. super(IsNullValue, self).__init__('NULL', 'IS')
  423. def build_sql(self, column_name, counter=None):
  424. return '%s IS NULL' % column_name, {}
  425. class IsNotNullValue(SQLOperator):
  426. def __init__(self):
  427. super(IsNotNullValue, self).__init__('NULL', 'IS NOT')
  428. def build_sql(self, column_name, counter=None):
  429. return '%s IS NOT NULL' % column_name, {}
  430. class Flag(object):
  431. def __init__(self, flags_present=0x0, flags_absent=0x0):
  432. if flags_present & flags_absent:
  433. raise errors.InternalError(
  434. 'flags_present (0x%016x) and flags_absent (0x%016x)'
  435. ' overlap: 0x%016x' % (
  436. flags_present, flags_absent, flags_present & flags_absent))
  437. self.mask = flags_present | flags_absent
  438. self.value = flags_present
  439. self.flags_to_remove = flags_absent
  440. self.flags_to_add = flags_present
  441. def __repr__(self):
  442. return '%s(flags_present=0x%X, flags_absent=0x%X)' % (
  443. self.__class__.__name__, self.flags_to_add, self.flags_to_remove)
  444. def __or__(self, other):
  445. return Flag(flags_present=self.flags_to_add | other.flags_to_add,
  446. flags_absent=self.flags_to_remove | other.flags_to_remove)
  447. # Beware: this doesn't switch the present and absent flags, it makes
  448. # an object that *clears all the flags* that the operand would touch.
  449. def __invert__(self):
  450. return Flag(flags_absent=self.mask)
  451. def __eq__(self, other):
  452. if not isinstance(other, Flag):
  453. return False
  454. return (self.mask == other.mask
  455. and self.value == other.value
  456. and self.flags_to_add == other.flags_to_add
  457. and self.flags_to_remove == other.flags_to_remove)
  458. def sql(self, column_name='flags'):
  459. return '%s & %s = %s' % (column_name, self.mask, self.value)
  460. def build_sql(self, column_name='flags', counter=None):
  461. bind_name_mask = choose_bind_name(column_name + '_mask', counter=counter)
  462. bind_name_value = choose_bind_name(column_name + '_value', counter=counter)
  463. clause = '{column_name} & %({bind_name_mask})s = %({bind_name_value})s'.format(
  464. bind_name_mask=bind_name_mask, bind_name_value=bind_name_value,
  465. column_name=column_name)
  466. bind_vars = {
  467. bind_name_mask: self.mask,
  468. bind_name_value: self.value
  469. }
  470. return clause, bind_vars
  471. def update_sql(self, column_name='flags'):
  472. return '%s = (%s | %s) & ~%s' % (
  473. column_name, column_name, self.flags_to_add, self.flags_to_remove)
  474. def build_update_sql(self, column_name='flags'):
  475. clause = ('%(column_name)s = (%(column_name)s | '
  476. '%%(update_%(column_name)s_add)s) & '
  477. '~%%(update_%(column_name)s_remove)s') % vars( )
  478. bind_vars = {
  479. 'update_%s_add' % column_name: self.flags_to_add, 'update_%s_remove' %
  480. column_name: self.flags_to_remove}
  481. return clause, bind_vars
  482. def make_flag(flag_mask, value):
  483. if value:
  484. return Flag(flags_present=flag_mask)
  485. else:
  486. return Flag(flags_absent=flag_mask)
  487. class Increment(object):
  488. def __init__(self, amount):
  489. self.amount = amount
  490. def build_update_sql(self, column_name):
  491. clause = ('%(column_name)s = (%(column_name)s + '
  492. '%%(update_%(column_name)s_amount)s)') % vars()
  493. bind_vars = {'update_%s_amount' % column_name: self.amount}
  494. return clause, bind_vars