/johnny/transaction.py

https://bitbucket.org/jmoiron/johnny-cache/ · Python · 321 lines · 249 code · 36 blank · 36 comment · 70 complexity · 506ff4a2421b067095bf656445d4756c MD5 · raw file

  1. import django
  2. from django.db import transaction as django_transaction
  3. from django.db import connection
  4. try:
  5. from django.db import DEFAULT_DB_ALIAS
  6. except:
  7. DEFUALT_DB_ALIAS = None
  8. from johnny.decorators import wraps, available_attrs
  9. class TransactionManager(object):
  10. """
  11. TransactionManager is a wrapper around a cache_backend that is
  12. transaction aware.
  13. If we are in a transaction, it will return the locally cached version.
  14. * On rollback, it will flush all local caches
  15. * On commit, it will push them up to the real shared cache backend
  16. (ex. memcached).
  17. """
  18. _patched_var = False
  19. def __init__(self, cache_backend, keygen):
  20. from johnny import cache, settings
  21. self.timeout = settings.MIDDLEWARE_SECONDS
  22. self.prefix = settings.MIDDLEWARE_KEY_PREFIX
  23. self.cache_backend = cache_backend
  24. self.local = cache.local
  25. self.keygen = keygen(self.prefix)
  26. self._originals = {}
  27. self._dirty_backup = {}
  28. self.local['trans_sids'] = {}
  29. def _get_sid(self, using=None):
  30. if 'trans_sids' not in self.local:
  31. self.local['trans_sids'] = {}
  32. d = self.local['trans_sids']
  33. if self.has_multi_db():
  34. if using is None:
  35. using = DEFAULT_DB_ALIAS
  36. else:
  37. using = 'default'
  38. if using not in d:
  39. d[using] = []
  40. return d[using]
  41. def _clear_sid_stack(self, using=None):
  42. if self.has_multi_db():
  43. if using is None:
  44. using = DEFAULT_DB_ALIAS
  45. else:
  46. using = 'default'
  47. if using in self.local.get('trans_sids', {}):
  48. del self.local['trans_sids']
  49. def has_multi_db(self):
  50. if django.VERSION[:2] > (1, 1):
  51. return True
  52. return False
  53. def is_managed(self):
  54. return django_transaction.is_managed()
  55. def get(self, key, default=None, using=None):
  56. if self.is_managed() and self._patched_var:
  57. val = self.local.get(key, None)
  58. if val:
  59. return val
  60. if self._uses_savepoints():
  61. val = self._get_from_savepoints(key, using)
  62. if val:
  63. return val
  64. return self.cache_backend.get(key, default)
  65. def _get_from_savepoints(self, key, using=None):
  66. sids = self._get_sid(using)
  67. cp = list(sids)
  68. cp.reverse()
  69. for sid in cp:
  70. if key in self.local[sid]:
  71. return self.local[sid][key]
  72. def _trunc_using(self, using):
  73. if self.has_multi_db():
  74. if using is None:
  75. using = DEFAULT_DB_ALIAS
  76. else:
  77. using = 'default'
  78. if len(using) > 100:
  79. using = using[0:68] + self.keygen.gen_key(using[68:])
  80. return using
  81. def set(self, key, val, timeout=None, using=None):
  82. """
  83. Set will be using the generational key, so if another thread
  84. bumps this key, the localstore version will still be invalid.
  85. If the key is bumped during a transaction it will be new
  86. to the global cache on commit, so it will still be a bump.
  87. """
  88. if timeout is None:
  89. timeout = self.timeout
  90. if self.is_managed() and self._patched_var:
  91. self.local[key] = val
  92. else:
  93. self.cache_backend.set(key, val, timeout)
  94. def _clear(self, using=None):
  95. if self.has_multi_db():
  96. self.local.clear('%s_%s_*' %
  97. (self.prefix, self._trunc_using(using)))
  98. else:
  99. self.local.clear('%s_*' % self.prefix)
  100. def _flush(self, commit=True, using=None):
  101. """
  102. Flushes the internal cache, either to the memcache or rolls back
  103. """
  104. if commit:
  105. # XXX: multi-set?
  106. if self._uses_savepoints():
  107. self._commit_all_savepoints(using)
  108. if self.has_multi_db():
  109. c = self.local.mget('%s_%s_*' %
  110. (self.prefix, self._trunc_using(using)))
  111. else:
  112. c = self.local.mget('%s_*' % self.prefix)
  113. for key, value in c.iteritems():
  114. self.cache_backend.set(key, value, self.timeout)
  115. else:
  116. if self._uses_savepoints():
  117. self._rollback_all_savepoints(using)
  118. self._clear(using)
  119. self._clear_sid_stack(using)
  120. def _patched(self, original, commit=True):
  121. @wraps(original, assigned=available_attrs(original))
  122. def newfun(using=None):
  123. #1.2 version
  124. original(using=using)
  125. self._flush(commit=commit, using=using)
  126. @wraps(original, assigned=available_attrs(original))
  127. def newfun11():
  128. #1.1 version
  129. original()
  130. self._flush(commit=commit)
  131. if django.VERSION[:2] == (1, 1):
  132. return newfun11
  133. elif django.VERSION[:2] > (1, 1):
  134. return newfun
  135. return original
  136. def _uses_savepoints(self):
  137. return connection.features.uses_savepoints
  138. def _sid_key(self, sid, using=None):
  139. if using is not None:
  140. prefix = 'trans_savepoint_%s' % using
  141. else:
  142. prefix = 'trans_savepoint'
  143. if sid.startswith(prefix):
  144. return sid
  145. return '%s_%s'%(prefix, sid)
  146. def _create_savepoint(self, sid, using=None):
  147. key = self._sid_key(sid, using)
  148. #get all local dirty items
  149. if self.has_multi_db():
  150. c = self.local.mget('%s_%s_*' %
  151. (self.prefix, self._trunc_using(using)))
  152. else:
  153. c = self.local.mget('%s_*' % self.prefix)
  154. #store them to a dictionary in the localstore
  155. if key not in self.local:
  156. self.local[key] = {}
  157. for k, v in c.iteritems():
  158. self.local[key][k] = v
  159. #clear the dirty
  160. self._clear(using)
  161. #append the key to the savepoint stack
  162. sids = self._get_sid(using)
  163. sids.append(key)
  164. def _rollback_savepoint(self, sid, using=None):
  165. sids = self._get_sid(using)
  166. key = self._sid_key(sid, using)
  167. stack = []
  168. try:
  169. popped = None
  170. while popped != key:
  171. popped = sids.pop()
  172. stack.insert(0, popped)
  173. #delete items from localstore
  174. for i in stack:
  175. del self.local[i]
  176. #clear dirty
  177. self._clear(using)
  178. except IndexError:
  179. #key not found, don't delete from localstore, restore sid stack
  180. for i in stack:
  181. sids.insert(0, i)
  182. def _commit_savepoint(self, sid, using=None):
  183. # commit is not a commit but is in reality just a clear back to that
  184. # savepoint and adds the items back to the dirty transaction.
  185. key = self._sid_key(sid, using)
  186. sids = self._get_sid(using)
  187. stack = []
  188. try:
  189. popped = None
  190. while popped != key:
  191. popped = sids.pop()
  192. stack.insert(0, popped)
  193. self._store_dirty(using)
  194. for i in stack:
  195. for k, v in self.local[i].iteritems():
  196. self.local[k] = v
  197. del self.local[i]
  198. self._restore_dirty(using)
  199. except IndexError:
  200. for i in stack:
  201. sids.insert(0, i)
  202. def _commit_all_savepoints(self, using=None):
  203. sids = self._get_sid(using)
  204. if sids:
  205. self._commit_savepoint(sids[0], using)
  206. def _rollback_all_savepoints(self, using=None):
  207. sids = self._get_sid(using)
  208. if sids:
  209. self._rollback_savepoint(sids[0], using)
  210. def _store_dirty(self, using=None):
  211. if self.has_multi_db():
  212. c = self.local.mget('%s_%s_*' %
  213. (self.prefix, self._trunc_using(using)))
  214. else:
  215. c = self.local.mget('%s_*' % self.prefix)
  216. backup = 'trans_dirty_store_%s' % self._trunc_using(using)
  217. self.local[backup] = {}
  218. for k, v in c.iteritems():
  219. self.local[backup][k] = v
  220. self._clear(using)
  221. def _restore_dirty(self, using=None):
  222. backup = 'trans_dirty_store_%s' % self._trunc_using(using)
  223. for k, v in self.local.get(backup, {}).iteritems():
  224. self.local[k] = v
  225. del self.local[backup]
  226. def _savepoint(self, original):
  227. @wraps(original, assigned=available_attrs(original))
  228. def newfun(using=None):
  229. if using != None:
  230. sid = original(using=using)
  231. else:
  232. sid = original()
  233. if self._uses_savepoints():
  234. self._create_savepoint(sid, using)
  235. return sid
  236. return newfun
  237. def _savepoint_rollback(self, original):
  238. def newfun(sid, *args, **kwargs):
  239. original(sid, *args, **kwargs)
  240. if self._uses_savepoints():
  241. if len(args) == 2:
  242. using = args[1]
  243. else:
  244. using = kwargs.get('using', None)
  245. self._rollback_savepoint(sid, using)
  246. return newfun
  247. def _savepoint_commit(self, original):
  248. def newfun(sid, *args, **kwargs):
  249. original(sid, *args, **kwargs)
  250. if self._uses_savepoints():
  251. if len(args) == 1:
  252. using = args[0]
  253. else:
  254. using = kwargs.get('using', None)
  255. self._commit_savepoint(sid, using)
  256. return newfun
  257. def _getreal(self, name):
  258. return getattr(django_transaction, 'real_%s' % name,
  259. getattr(django_transaction, name))
  260. def patch(self):
  261. """
  262. This function monkey patches commit and rollback
  263. writes to the cache should not happen until commit (unless our state
  264. isn't managed). It does not yet support savepoints.
  265. """
  266. if not self._patched_var:
  267. self._originals['rollback'] = self._getreal('rollback')
  268. self._originals['commit'] = self._getreal('commit')
  269. self._originals['savepoint'] = self._getreal('savepoint')
  270. self._originals['savepoint_rollback'] = self._getreal('savepoint_rollback')
  271. self._originals['savepoint_commit'] = self._getreal('savepoint_commit')
  272. django_transaction.rollback = self._patched(django_transaction.rollback, False)
  273. django_transaction.commit = self._patched(django_transaction.commit, True)
  274. django_transaction.savepoint = self._savepoint(django_transaction.savepoint)
  275. django_transaction.savepoint_rollback = self._savepoint_rollback(django_transaction.savepoint_rollback)
  276. django_transaction.savepoint_commit = self._savepoint_commit(django_transaction.savepoint_commit)
  277. self._patched_var = True
  278. def unpatch(self):
  279. for fun in self._originals:
  280. setattr(django_transaction, fun, self._originals[fun])
  281. self._patched_var = False