PageRenderTime 28ms CodeModel.GetById 0ms RepoModel.GetById 0ms app.codeStats 0ms

/skink/lib/sqlalchemy/orm/scoping.py

http://github.com/heynemann/skink
Python | 208 lines | 150 code | 28 blank | 30 comment | 32 complexity | 9f9f99e05ec95c817290630d39c13fca MD5 | raw file
  1. # scoping.py
  2. # Copyright (C) the SQLAlchemy authors and contributors
  3. #
  4. # This module is part of SQLAlchemy and is released under
  5. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  6. import sqlalchemy.exceptions as sa_exc
  7. from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \
  8. to_list, get_cls_kwargs, deprecated
  9. from sqlalchemy.orm import (
  10. EXT_CONTINUE, MapperExtension, class_mapper, object_session
  11. )
  12. from sqlalchemy.orm import exc as orm_exc
  13. from sqlalchemy.orm.session import Session
# Public API of this module: ``from ... import *`` exports only ScopedSession.
__all__ = ['ScopedSession']
  15. class ScopedSession(object):
  16. """Provides thread-local management of Sessions.
  17. Usage::
  18. Session = scoped_session(sessionmaker(autoflush=True))
  19. ... use session normally.
  20. """
  21. def __init__(self, session_factory, scopefunc=None):
  22. self.session_factory = session_factory
  23. if scopefunc:
  24. self.registry = ScopedRegistry(session_factory, scopefunc)
  25. else:
  26. self.registry = ThreadLocalRegistry(session_factory)
  27. self.extension = _ScopedExt(self)
  28. def __call__(self, **kwargs):
  29. if kwargs:
  30. scope = kwargs.pop('scope', False)
  31. if scope is not None:
  32. if self.registry.has():
  33. raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
  34. else:
  35. sess = self.session_factory(**kwargs)
  36. self.registry.set(sess)
  37. return sess
  38. else:
  39. return self.session_factory(**kwargs)
  40. else:
  41. return self.registry()
  42. def remove(self):
  43. if self.registry.has():
  44. self.registry().close()
  45. self.registry.clear()
  46. @deprecated("Session.mapper is deprecated. "
  47. "Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper "
  48. "for information on how to replicate its behavior.")
  49. def mapper(self, *args, **kwargs):
  50. """return a mapper() function which associates this ScopedSession with the Mapper.
  51. DEPRECATED.
  52. """
  53. from sqlalchemy.orm import mapper
  54. extension_args = dict((arg, kwargs.pop(arg))
  55. for arg in get_cls_kwargs(_ScopedExt)
  56. if arg in kwargs)
  57. kwargs['extension'] = extension = to_list(kwargs.get('extension', []))
  58. if extension_args:
  59. extension.append(self.extension.configure(**extension_args))
  60. else:
  61. extension.append(self.extension)
  62. return mapper(*args, **kwargs)
  63. def configure(self, **kwargs):
  64. """reconfigure the sessionmaker used by this ScopedSession."""
  65. self.session_factory.configure(**kwargs)
  66. def query_property(self, query_cls=None):
  67. """return a class property which produces a `Query` object against the
  68. class when called.
  69. e.g.::
  70. Session = scoped_session(sessionmaker())
  71. class MyClass(object):
  72. query = Session.query_property()
  73. # after mappers are defined
  74. result = MyClass.query.filter(MyClass.name=='foo').all()
  75. Produces instances of the session's configured query class by
  76. default. To override and use a custom implementation, provide
  77. a ``query_cls`` callable. The callable will be invoked with
  78. the class's mapper as a positional argument and a session
  79. keyword argument.
  80. There is no limit to the number of query properties placed on
  81. a class.
  82. """
  83. class query(object):
  84. def __get__(s, instance, owner):
  85. try:
  86. mapper = class_mapper(owner)
  87. if mapper:
  88. if query_cls:
  89. # custom query class
  90. return query_cls(mapper, session=self.registry())
  91. else:
  92. # session's configured query class
  93. return self.registry().query(mapper)
  94. except orm_exc.UnmappedClassError:
  95. return None
  96. return query()
  97. def instrument(name):
  98. def do(self, *args, **kwargs):
  99. return getattr(self.registry(), name)(*args, **kwargs)
  100. return do
# Proxy every public Session method onto ScopedSession: each generated method
# calls the same-named method on the Session returned by self.registry().
for meth in Session.public_methods:
    setattr(ScopedSession, meth, instrument(meth))
  103. def makeprop(name):
  104. def set(self, attr):
  105. setattr(self.registry(), name, attr)
  106. def get(self):
  107. return getattr(self.registry(), name)
  108. return property(get, set)
# Expose these Session attributes on ScopedSession as read/write properties
# that delegate to the Session returned by self.registry().
for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active'):
    setattr(ScopedSession, prop, makeprop(prop))
  111. def clslevel(name):
  112. def do(cls, *args, **kwargs):
  113. return getattr(Session, name)(*args, **kwargs)
  114. return classmethod(do)
# Class-level helpers: ScopedSession.<name>(...) forwards directly to
# Session.<name>(...) without touching the registry.
for prop in ('close_all', 'object_session', 'identity_key'):
    setattr(ScopedSession, prop, clslevel(prop))
  117. class _ScopedExt(MapperExtension):
  118. def __init__(self, context, validate=False, save_on_init=True):
  119. self.context = context
  120. self.validate = validate
  121. self.save_on_init = save_on_init
  122. self.set_kwargs_on_init = True
  123. def validating(self):
  124. return _ScopedExt(self.context, validate=True)
  125. def configure(self, **kwargs):
  126. return _ScopedExt(self.context, **kwargs)
  127. def instrument_class(self, mapper, class_):
  128. class query(object):
  129. def __getattr__(s, key):
  130. return getattr(self.context.registry().query(class_), key)
  131. def __call__(s):
  132. return self.context.registry().query(class_)
  133. def __get__(self, instance, cls):
  134. return self
  135. if not 'query' in class_.__dict__:
  136. class_.query = query()
  137. if self.set_kwargs_on_init and class_.__init__ is object.__init__:
  138. class_.__init__ = self._default__init__(mapper)
  139. def _default__init__(ext, mapper):
  140. def __init__(self, **kwargs):
  141. for key, value in kwargs.items():
  142. if ext.validate:
  143. if not mapper.get_property(key, resolve_synonyms=False,
  144. raiseerr=False):
  145. raise sa_exc.ArgumentError(
  146. "Invalid __init__ argument: '%s'" % key)
  147. setattr(self, key, value)
  148. return __init__
  149. def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
  150. if self.save_on_init:
  151. session = kwargs.pop('_sa_session', None)
  152. if session is None:
  153. session = self.context.registry()
  154. session._save_without_cascade(instance)
  155. return EXT_CONTINUE
  156. def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
  157. sess = object_session(instance)
  158. if sess:
  159. sess.expunge(instance)
  160. return EXT_CONTINUE
  161. def dispose_class(self, mapper, class_):
  162. if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
  163. if class_.__init__._oldinit is not None:
  164. class_.__init__ = class_.__init__._oldinit
  165. else:
  166. delattr(class_, '__init__')
  167. if hasattr(class_, 'query'):
  168. delattr(class_, 'query')