# orm/persistence.py
# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""private module containing functions used to emit INSERT, UPDATE
and DELETE statements on behalf of a :class:`.Mapper` and its descendant
mappers.

The functions here are called only by the unit of work functions
in unitofwork.py.

"""

import operator
from itertools import groupby

from sqlalchemy import sql, util, exc as sa_exc
from sqlalchemy.orm import attributes, sync, \
                        exc as orm_exc

from sqlalchemy.orm.util import _state_mapper, state_str
def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    """

    # if batch=False, call save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_insert, states_to_update = _organize_states_for_save(
                                                base_mapper,
                                                states,
                                                uowtransaction)

    cached_connections = _cached_connection_dict(base_mapper)

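    # walk the mapped tables in dependency order; for each table, the
    # INSERT and UPDATE parameter sets are collected across all states
    # and then emitted together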
    for table, mapper in base_mapper._sorted_tables.iteritems():
        insert = _collect_insert_commands(base_mapper, uowtransaction,
                                table, states_to_insert)

        update = _collect_update_commands(base_mapper, uowtransaction,
                                table, states_to_update)

        if update:
            _emit_update_statements(base_mapper, uowtransaction,
                                    cached_connections,
                                    mapper, table, update)

        if insert:
            _emit_insert_statements(base_mapper, uowtransaction,
                                    cached_connections,
                                    table, insert)

    _finalize_insert_update_commands(base_mapper, uowtransaction,
                                    states_to_insert, states_to_update)

def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    """
    cached_connections = _cached_connection_dict(base_mapper)

    states_to_update = _organize_states_for_post_update(
                                            base_mapper,
                                            states, uowtransaction)

    for table, mapper in base_mapper._sorted_tables.iteritems():
        update = _collect_post_update_commands(base_mapper, uowtransaction,
                                            table, states_to_update,
                                            post_update_cols)

        if update:
            _emit_post_update_statements(base_mapper, uowtransaction,
                                            cached_connections,
                                            mapper, table, update)

def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    """

    cached_connections = _cached_connection_dict(base_mapper)

    states_to_delete = _organize_states_for_delete(
                                            base_mapper,
                                            states,
                                            uowtransaction)

    table_to_mapper = base_mapper._sorted_tables

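    # tables are visited in reverse dependency order so that rows in
    # dependent (child) tables are deleted before the rows they reference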
    for table in reversed(table_to_mapper.keys()):
        delete = _collect_delete_commands(base_mapper, uowtransaction,
                                table, states_to_delete)

        mapper = table_to_mapper[table]

        _emit_delete_statements(base_mapper, uowtransaction,
                    cached_connections, mapper, table, delete)

    for state, state_dict, mapper, has_identity, connection \
                        in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)

def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    """

    states_to_insert = []
    states_to_update = []

    for state, dict_, mapper, connection in _connections_for_states(
                                            base_mapper, uowtransaction,
                                            states):

        has_identity = bool(state.key)
        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if not has_identity and \
            instance_key in uowtransaction.session.identity_map:
            instance = \
                uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)
            if not uowtransaction.is_deleted(existing):
                raise orm_exc.FlushError(
                    "New instance %s with identity key %s conflicts "
                    "with persistent instance %s" %
                    (state_str(state), instance_key,
                     state_str(existing)))

            base_mapper._log_debug(
                "detected row switch for identity %s. "
                "will update %s, remove %s from "
                "transaction", instance_key,
                state_str(state), state_str(existing))

            # remove the "delete" flag from the existing element
            uowtransaction.remove_state_actions(existing)
            row_switch = existing

        if not has_identity and not row_switch:
            states_to_insert.append(
                (state, dict_, mapper, connection,
                has_identity, instance_key, row_switch)
            )
        else:
            states_to_update.append(
                (state, dict_, mapper, connection,
                has_identity, instance_key, row_switch)
            )

    return states_to_insert, states_to_update

def _organize_states_for_post_update(base_mapper, states,
                                                uowtransaction):
    """Make an initial pass across a set of states for UPDATE
    corresponding to post_update.

    This includes obtaining key information for each state
    including its dictionary, mapper, the connection to use for
    the execution per state.

    """
    return list(_connections_for_states(base_mapper, uowtransaction,
                                        states))

def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for DELETE.

    This includes calling out before_delete and obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state.

    """
    states_to_delete = []

    for state, dict_, mapper, connection in _connections_for_states(
                                            base_mapper, uowtransaction,
                                            states):

        mapper.dispatch.before_delete(mapper, connection, state)

        states_to_delete.append((state, dict_, mapper,
                bool(state.key), connection))
    return states_to_delete

def _collect_insert_commands(base_mapper, uowtransaction, table,
                                states_to_insert):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    """
    insert = []
    for state, state_dict, mapper, connection, has_identity, \
                        instance_key, row_switch in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        params = {}
        value_params = {}

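        # has_all_pks tracks whether every primary key column already has a
        # value for this state; _emit_insert_statements() uses the flag to
        # decide whether rows can be batched into a single executemany()
        # call or must be inserted individually so that server-generated
        # primary keys can be fetched back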
        has_all_pks = True
        for col in mapper._cols_by_table[table]:
            if col is mapper.version_id_col:
                params[col.key] = mapper.version_id_generator(None)
            else:
                # pull straight from the dict for
                # pending objects
                prop = mapper._columntoproperty[col]
                value = state_dict.get(prop.key, None)

                if value is None:
                    if col in pks:
                        has_all_pks = False
                    elif col.default is None and \
                         col.server_default is None:
                        params[col.key] = value

                elif isinstance(value, sql.ClauseElement):
                    value_params[col] = value
                else:
                    params[col.key] = value

        insert.append((state, state_dict, params, mapper,
                        connection, value_params, has_all_pks))
    return insert

def _collect_update_commands(base_mapper, uowtransaction,
                                table, states_to_update):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.

    """

    update = []
    for state, state_dict, mapper, connection, has_identity, \
                        instance_key, row_switch in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        params = {}
        value_params = {}

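        # hasdata - at least one column value requires an UPDATE for this
        # row; hasnull - a primary key value resolved to None, which is an
        # error if an UPDATE does turn out to be needed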
        hasdata = hasnull = False
        for col in mapper._cols_by_table[table]:
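            # for the version id column, two values are collected: the
            # committed (pre-flush) value keyed on col._label, which the
            # UPDATE's WHERE clause uses to locate the row, and the new
            # value keyed on col.key for the SET clause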
            if col is mapper.version_id_col:
                params[col._label] = \
                    mapper._get_committed_state_attr_by_column(
                                    row_switch or state,
                                    row_switch and row_switch.dict
                                                or state_dict,
                                    col)

                prop = mapper._columntoproperty[col]
                history = attributes.get_state_history(
                    state, prop.key,
                    attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    params[col.key] = history.added[0]
                    hasdata = True
                else:
                    params[col.key] = mapper.version_id_generator(
                                            params[col._label])

                    # HACK: check for history, in case the
                    # history is only
                    # in a different table than the one
                    # where the version_id_col is.
                    for prop in mapper._columntoproperty.itervalues():
                        history = attributes.get_state_history(
                                state, prop.key,
                                attributes.PASSIVE_NO_INITIALIZE)
                        if history.added:
                            hasdata = True
            else:
                prop = mapper._columntoproperty[col]
                history = attributes.get_state_history(
                                state, prop.key,
                                attributes.PASSIVE_NO_INITIALIZE)
                if history.added:
                    if isinstance(history.added[0],
                                    sql.ClauseElement):
                        value_params[col] = history.added[0]
                    else:
                        value = history.added[0]
                        params[col.key] = value

                    if col in pks:
                        if history.deleted and \
                                not row_switch:
                            # if passive_updates and sync detected
                            # this was a pk->pk sync, use the new
                            # value to locate the row, since the
                            # DB would already have set this
                            if ("pk_cascaded", state, col) in \
                                            uowtransaction.attributes:
                                value = history.added[0]
                                params[col._label] = value
                            else:
                                # use the old value to
                                # locate the row
                                value = history.deleted[0]
                                params[col._label] = value
                            hasdata = True
                        else:
                            # row switch logic can reach us here
                            # remove the pk from the update params
                            # so the update doesn't
                            # attempt to include the pk in the
                            # update statement
                            del params[col.key]
                            value = history.added[0]
                            params[col._label] = value
                    if value is None:
                        hasnull = True
                    else:
                        hasdata = True
                elif col in pks:
                    value = state.manager[prop.key].impl.get(
                            state, state_dict)
                    if value is None:
                        hasnull = True
                    params[col._label] = value
        if hasdata:
            if hasnull:
                raise orm_exc.FlushError(
                            "Can't update table "
                            "using NULL for primary "
                            "key value")
            update.append((state, state_dict, params, mapper,
                            connection, value_params))
    return update

def _collect_post_update_commands(base_mapper, uowtransaction, table,
                        states_to_update, post_update_cols):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    """

    update = []
    for state, state_dict, mapper, connection in states_to_update:
        if table not in mapper._pks_by_table:
            continue
        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                params[col._label] = \
                        mapper._get_state_attr_by_column(
                                        state,
                                        state_dict, col)
            elif col in post_update_cols:
                prop = mapper._columntoproperty[col]
                history = attributes.get_state_history(
                                state, prop.key,
                                attributes.PASSIVE_NO_INITIALIZE)
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            update.append((state, state_dict, params, mapper,
                            connection))
    return update

def _collect_delete_commands(base_mapper, uowtransaction, table,
                                states_to_delete):
    """Identify values to use in DELETE statements for a list of
    states to be deleted."""

    delete = util.defaultdict(list)

    for state, state_dict, mapper, has_identity, connection \
                                        in states_to_delete:
        if not has_identity or table not in mapper._pks_by_table:
            continue

        params = {}
        delete[connection].append(params)
        for col in mapper._pks_by_table[table]:
            params[col.key] = \
                    value = \
                    mapper._get_state_attr_by_column(
                                    state, state_dict, col)
            if value is None:
                raise orm_exc.FlushError(
                            "Can't delete from table "
                            "using NULL for primary "
                            "key value")

        if mapper.version_id_col is not None and \
                    table.c.contains_column(mapper.version_id_col):
            params[mapper.version_id_col.key] = \
                    mapper._get_committed_state_attr_by_column(
                            state, state_dict,
                            mapper.version_id_col)
    return delete

def _emit_update_statements(base_mapper, uowtransaction,
                        cached_connections, mapper, table, update):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands()."""

    needs_version_id = mapper.version_id_col is not None and \
        table.c.contains_column(mapper.version_id_col)

    def update_stmt():
        clause = sql.and_()

        for col in mapper._pks_by_table[table]:
            clause.clauses.append(col == sql.bindparam(col._label,
                                            type_=col.type))

        if needs_version_id:
            clause.clauses.append(mapper.version_id_col == \
                    sql.bindparam(mapper.version_id_col._label,
                            type_=mapper.version_id_col.type))

        return table.update(clause)

    statement = base_mapper._memo(('update', table), update_stmt)

    rows = 0
    for state, state_dict, params, mapper, \
            connection, value_params in update:

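        # rows containing SQL expression values can't go through the
        # compiled_cache-enabled connection, since each statement.values()
        # call produces a distinct statement; plain parameter sets reuse
        # the cached compiled UPDATE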
        if value_params:
            c = connection.execute(
                                statement.values(value_params),
                                params)
        else:
            c = cached_connections[connection].\
                                execute(statement, params)

        _postfetch(
                mapper,
                uowtransaction,
                table,
                state,
                state_dict,
                c.context.prefetch_cols,
                c.context.postfetch_cols,
                c.context.compiled_parameters[0],
                value_params)
        rows += c.rowcount

    if connection.dialect.supports_sane_rowcount:
        if rows != len(update):
            raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched." %
                    (table.description, len(update), rows))

    elif needs_version_id:
        util.warn("Dialect %s does not support updated rowcount "
                "- versioning cannot be verified." %
                c.dialect.dialect_description,
                stacklevel=12)

def _emit_insert_statements(base_mapper, uowtransaction,
                        cached_connections, table, insert):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands()."""

    statement = base_mapper._memo(('insert', table), table.insert)

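    # group the collected records so that rows sharing a connection,
    # parameter keys, absence of SQL expression values and a full set of
    # primary keys can be emitted via a single executemany(); see
    # _collect_insert_commands() for the record tuple layout
    # (rec[2]=params, rec[4]=connection, rec[5]=value_params,
    # rec[6]=has_all_pks)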
    for (connection, pkeys, hasvalue, has_all_pks), \
            records in groupby(insert,
                            lambda rec: (rec[4],
                                    rec[2].keys(),
                                    bool(rec[5]),
                                    rec[6])
                            ):
        if has_all_pks and not hasvalue:
            records = list(records)
            multiparams = [rec[2] for rec in records]
            c = cached_connections[connection].\
                                execute(statement, multiparams)

            for (state, state_dict, params, mapper,
                    conn, value_params, has_all_pks), \
                    last_inserted_params in \
                    zip(records, c.context.compiled_parameters):
                _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c.context.prefetch_cols,
                        c.context.postfetch_cols,
                        last_inserted_params,
                        value_params)

        else:
            for state, state_dict, params, mapper, \
                        connection, value_params, \
                        has_all_pks in records:

                if value_params:
                    result = connection.execute(
                                statement.values(value_params),
                                params)
                else:
                    result = cached_connections[connection].\
                                        execute(statement, params)

                primary_key = result.context.inserted_primary_key

                if primary_key is not None:
                    # set primary key attributes
                    for pk, col in zip(primary_key,
                                    mapper._pks_by_table[table]):
                        prop = mapper._columntoproperty[col]
                        if state_dict.get(prop.key) is None:
                            # TODO: would rather say:
                            # state_dict[prop.key] = pk
                            mapper._set_state_attr_by_column(
                                        state,
                                        state_dict,
                                        col, pk)

                _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        result.context.prefetch_cols,
                        result.context.postfetch_cols,
                        result.context.compiled_parameters[0],
                        value_params)


def _emit_post_update_statements(base_mapper, uowtransaction,
                            cached_connections, mapper, table, update):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands()."""

    def update_stmt():
        clause = sql.and_()

        for col in mapper._pks_by_table[table]:
            clause.clauses.append(col == sql.bindparam(col._label,
                                            type_=col.type))

        return table.update(clause)

    statement = base_mapper._memo(('post_update', table), update_stmt)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
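    # (in each record, rec[4] is the connection and rec[2] the
    # parameter dictionary.)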
    for key, grouper in groupby(
        update, lambda rec: (rec[4], rec[2].keys())
    ):
        connection = key[0]
        multiparams = [params for state, state_dict,
                                params, mapper, conn in grouper]
        cached_connections[connection].\
                            execute(statement, multiparams)


def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
                                    mapper, table, delete):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands()."""

    need_version_id = mapper.version_id_col is not None and \
        table.c.contains_column(mapper.version_id_col)

    def delete_stmt():
        clause = sql.and_()
        for col in mapper._pks_by_table[table]:
            clause.clauses.append(
                    col == sql.bindparam(col.key, type_=col.type))

        if need_version_id:
            clause.clauses.append(
                mapper.version_id_col ==
                sql.bindparam(
                        mapper.version_id_col.key,
                        type_=mapper.version_id_col.type
                )
            )

        return table.delete(clause)

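    # "delete" maps each connection to a list of primary key parameter
    # dictionaries; non-versioned deletes for a connection are emitted as
    # one executemany()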
    for connection, del_objects in delete.iteritems():
        statement = base_mapper._memo(('delete', table), delete_stmt)

        connection = cached_connections[connection]

        if need_version_id:
            # TODO: need test coverage for this [ticket:1761]
            if connection.dialect.supports_sane_rowcount:
                rows = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(statement, params)
                    rows += c.rowcount
                if rows != len(del_objects):
                    raise orm_exc.StaleDataError(
                        "DELETE statement on table '%s' expected to "
                        "delete %d row(s); %d were matched." %
                        (table.description, len(del_objects), rows)
                    )
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified." %
                    connection.dialect.dialect_description,
                    stacklevel=12)
                connection.execute(statement, del_objects)
        else:
            connection.execute(statement, del_objects)


def _finalize_insert_update_commands(base_mapper, uowtransaction,
                                    states_to_insert, states_to_update):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    """
    for state, state_dict, mapper, connection, has_identity, \
                        instance_key, row_switch in states_to_insert + \
                                                    states_to_update:

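        # expire readonly attributes (e.g. column_property() values computed
        # in the database) so that they are re-loaded on next access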
        if mapper._readonly_props:
            readonly = state.unmodified_intersection(
                [p.key for p in mapper._readonly_props
                if p.expire_on_flush or p.key not in state.dict]
            )
            if readonly:
                state.expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled,
        # refresh whatever has been expired.
        if base_mapper.eager_defaults and state.unloaded:
            state.key = base_mapper._identity_key_from_state(state)
            uowtransaction.session.query(base_mapper)._load_on_ident(
                state.key, refresh_state=state,
                only_load_props=state.unloaded)

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

def _postfetch(mapper, uowtransaction, table,
                state, dict_, prefetch_cols, postfetch_cols,
                            params, value_params):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state."""

    if mapper.version_id_col is not None:
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            mapper._set_state_attr_by_column(state, dict_, c, params[c.key])

    if postfetch_cols:
        state.expire_attributes(state.dict,
                            [mapper._columntoproperty[c].key
                            for c in postfetch_cols if c in
                            mapper._columntoproperty]
                            )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often. would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(state, m, state, m,
                                equated_pairs,
                                uowtransaction,
                                mapper.passive_updates)

def _connections_for_states(base_mapper, uowtransaction, states):
    """Return an iterator of (state, state.dict, mapper, connection).

    The states are sorted according to _sort_states, then paired
    with the connection they should be using for the given
    unit of work transaction.

    """
    # if session has a connection callable,
    # organize individual states with the connection
    # to use for update
    if uowtransaction.session.connection_callable:
        connection_callable = \
                uowtransaction.session.connection_callable
    else:
        connection = None
        connection_callable = None

    for state in _sort_states(states):
        if connection_callable:
            connection = connection_callable(base_mapper, state.obj())
        elif not connection:
            connection = uowtransaction.transaction.connection(
                                    base_mapper)

        mapper = _state_mapper(state)

        yield state, state.dict, mapper, connection

def _cached_connection_dict(base_mapper):
    # dictionary of connection->connection_with_cache_options.
    return util.PopulateDict(
        lambda conn: conn.execution_options(
            compiled_cache=base_mapper._compiled_cache
        ))

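# states without an identity (pending) are sorted by their insert_order,
# followed by persistent states sorted by primary key; this helps keep the
# order in which rows are accessed deterministic from flush to flush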
def _sort_states(states):
    pending = set(states)
    persistent = set(s for s in pending if s.key is not None)
    pending.difference_update(persistent)
    return sorted(pending, key=operator.attrgetter("insert_order")) + \
        sorted(persistent, key=lambda q: q.key[1])