/johnny/transaction.py

https://github.com/GoodCloud/johnny-cache
Python | 318 lines | 245 code | 37 blank | 36 comment | 70 complexity | 2572c83735367a2e2ce07c9a1db16c42 MD5 | raw file
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from django.db import transaction as django_transaction
  4. from django.db import connection
  5. try:
  6. from django.db import DEFAULT_DB_ALIAS
  7. except:
  8. DEFUALT_DB_ALIAS = None
  9. try:
  10. from functools import wraps
  11. except ImportError:
  12. from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
  13. import django
  14. class TransactionManager(object):
  15. """TransactionManager is a wrapper around a cache_backend that is
  16. transaction aware.
  17. If we are in a transaction, it will return the locally cached version.
  18. * On rollback, it will flush all local caches
  19. * On commit, it will push them up to the real shared cache backend
  20. (ex. memcached).
  21. """
  22. _patched_var = False
  23. def __init__(self, cache_backend, keygen):
  24. from johnny import cache, settings
  25. self.timeout = settings.MIDDLEWARE_SECONDS
  26. self.prefix = settings.MIDDLEWARE_KEY_PREFIX
  27. self.cache_backend = cache_backend
  28. self.local = cache.local
  29. self.keygen = keygen(self.prefix)
  30. self._originals = {}
  31. self._dirty_backup = {}
  32. self.local['trans_sids'] = {}
  33. def _get_sid(self, using=None):
  34. if 'trans_sids' not in self.local:
  35. self.local['trans_sids'] = {}
  36. d = self.local['trans_sids']
  37. if self.has_multi_db():
  38. if using is None:
  39. using = DEFAULT_DB_ALIAS
  40. else:
  41. using = 'default'
  42. if using not in d:
  43. d[using] = []
  44. return d[using]
  45. def _clear_sid_stack(self, using=None):
  46. if self.has_multi_db():
  47. if using is None:
  48. using = DEFAULT_DB_ALIAS
  49. else:
  50. using = 'default'
  51. if using in self.local.get('trans_sids', {}):
  52. del self.local['trans_sids']
  53. def has_multi_db(self):
  54. if django.VERSION[:2] in ((1, 2), (1, 3)):
  55. return True
  56. return False
  57. def is_managed(self):
  58. return django_transaction.is_managed()
  59. def get(self, key, default=None, using=None):
  60. if self.is_managed() and self._patched_var:
  61. val = self.local.get(key, None)
  62. if val: return val
  63. if self._uses_savepoints():
  64. val = self._get_from_savepoints(key, using)
  65. if val: return val
  66. return self.cache_backend.get(key, default)
  67. def _get_from_savepoints(self, key, using=None):
  68. sids = self._get_sid(using)
  69. cp = list(sids)
  70. cp.reverse()
  71. for sid in cp:
  72. if key in self.local[sid]:
  73. return self.local[sid][key]
  74. def _trunc_using(self, using):
  75. if self.has_multi_db():
  76. if using is None:
  77. using = DEFAULT_DB_ALIAS
  78. else:
  79. using = 'default'
  80. if len(using) > 100:
  81. using = using[0:68] + self.keygen.gen_key(using[68:])
  82. return using
  83. def set(self, key, val, timeout=None, using=None):
  84. """
  85. Set will be using the generational key, so if another thread
  86. bumps this key, the localstore version will still be invalid.
  87. If the key is bumped during a transaction it will be new
  88. to the global cache on commit, so it will still be a bump.
  89. """
  90. if timeout is None:
  91. timeout = self.timeout
  92. if self.is_managed() and self._patched_var:
  93. self.local[key] = val
  94. else:
  95. self.cache_backend.set(key, val, timeout)
  96. def _clear(self, using=None):
  97. if self.has_multi_db():
  98. self.local.clear('%s_%s_*'%(self.prefix, self._trunc_using(using)))
  99. else:
  100. self.local.clear('%s_*'%self.prefix)
  101. def _flush(self, commit=True, using=None):
  102. """
  103. Flushes the internal cache, either to the memcache or rolls back
  104. """
  105. if commit:
  106. # XXX: multi-set?
  107. if self._uses_savepoints():
  108. self._commit_all_savepoints(using)
  109. if self.has_multi_db():
  110. c = self.local.mget('%s_%s_*'%(self.prefix, self._trunc_using(using)))
  111. else:
  112. c = self.local.mget('%s_*'%self.prefix)
  113. for key, value in c.iteritems():
  114. self.cache_backend.set(key, value, self.timeout)
  115. else:
  116. if self._uses_savepoints():
  117. self._rollback_all_savepoints(using)
  118. self._clear(using)
  119. self._clear_sid_stack(using)
  120. def _patched(self, original, commit=True):
  121. @wraps(original)
  122. def newfun(using=None):
  123. #1.2 version
  124. original(using=using)
  125. self._flush(commit=commit, using=using)
  126. @wraps(original)
  127. def newfun11():
  128. #1.1 version
  129. original()
  130. self._flush(commit=commit)
  131. if django.VERSION[:2] == (1,1):
  132. return newfun11
  133. elif django.VERSION[:2] in ((1,2), (1,3)):
  134. return newfun
  135. return original
  136. def _uses_savepoints(self):
  137. return connection.features.uses_savepoints
  138. def _sid_key(self, sid, using=None):
  139. if using != None:
  140. return 'trans_savepoint_%s_%s'%(using, sid)
  141. return 'trans_savepoint_%s'%sid
  142. def _create_savepoint(self, sid, using=None):
  143. key = self._sid_key(sid, using)
  144. #get all local dirty items
  145. if self.has_multi_db():
  146. c = self.local.mget('%s_%s_*'%(self.prefix, self._trunc_using(using)))
  147. else:
  148. c = self.local.mget('%s_*'%self.prefix)
  149. #store them to a dictionary in the localstore
  150. if key not in self.local:
  151. self.local[key] = {}
  152. for k, v in c.iteritems():
  153. self.local[key][k] = v
  154. #clear the dirty
  155. self._clear(using)
  156. #append the key to the savepoint stack
  157. sids = self._get_sid(using)
  158. sids.append(key)
  159. def _rollback_savepoint(self, sid, using=None):
  160. sids = self._get_sid(using)
  161. key = self._sid_key(sid, using)
  162. stack = []
  163. try:
  164. popped = None
  165. while popped != key:
  166. popped = sids.pop()
  167. stack.insert(0, popped)
  168. #delete items from localstore
  169. for i in stack:
  170. del self.local[i]
  171. #clear dirty
  172. self._clear(using)
  173. except IndexError, e:
  174. #key not found, don't delete from localstore, restore sid stack
  175. for i in stack:
  176. sids.insert(0, i)
  177. def _commit_savepoint(self, sid, using=None):
  178. #commit is not a commit but is in reality just a clear back to that savepoint
  179. #and adds the items back to the dirty transaction.
  180. key = self._sid_key(sid, using)
  181. sids = self._get_sid(using)
  182. stack = []
  183. try:
  184. popped = None
  185. while popped != key:
  186. popped = sids.pop()
  187. stack.insert(0, popped)
  188. self._store_dirty(using)
  189. for i in stack:
  190. for k, v in self.local[i].iteritems():
  191. self.local[k] = v
  192. del self.local[i]
  193. self._restore_dirty(using)
  194. except IndexError, e:
  195. for i in stack:
  196. sids.insert(0, i)
  197. def _commit_all_savepoints(self, using=None):
  198. sids = self._get_sid(using)
  199. if sids:
  200. self._commit_savepoint(sids[0], using)
  201. def _rollback_all_savepoints(self, using=None):
  202. sids = self._get_sid(using)
  203. if sids:
  204. self._rollback_savepoint(sids[0], using)
  205. def _store_dirty(self, using=None):
  206. if self.has_multi_db():
  207. c = self.local.mget('%s_%s_*'%(self.prefix, self._trunc_using(using)))
  208. else:
  209. c = self.local.mget('%s_*'%self.prefix)
  210. backup = 'trans_dirty_store_%s'%self._trunc_using(using)
  211. self.local[backup] = {}
  212. for k, v in c.iteritems():
  213. self.local[backup][k] = v
  214. self._clear(using)
  215. def _restore_dirty(self, using=None):
  216. backup = 'trans_dirty_store_%s'%self._trunc_using(using)
  217. for k, v in self.local.get(backup, {}).iteritems():
  218. self.local[k] = v
  219. del self.local[backup]
  220. def _savepoint(self, original):
  221. @wraps(original)
  222. def newfun(using=None):
  223. if using != None:
  224. sid = original(using=using)
  225. else:
  226. sid = original()
  227. if self._uses_savepoints():
  228. self._create_savepoint(sid, using)
  229. return sid
  230. return newfun
  231. def _savepoint_rollback(self, original):
  232. def newfun(sid, *args, **kwargs):
  233. original(sid, *args, **kwargs)
  234. if self._uses_savepoints():
  235. if len(args) == 2:
  236. using = args[1]
  237. else:
  238. using = kwargs.get('using', None)
  239. self._rollback_savepoint(sid, using)
  240. return newfun
  241. def _savepoint_commit(self, original):
  242. def newfun(sid, *args, **kwargs):
  243. original(sid, *args, **kwargs)
  244. if self._uses_savepoints():
  245. if len(args) == 1:
  246. using = args[0]
  247. else:
  248. using = kwargs.get('using', None)
  249. self._commit_savepoint(sid, using)
  250. return newfun
  251. def _getreal(self, name):
  252. return getattr(django_transaction, 'real_%s' % name,
  253. getattr(django_transaction, name))
  254. def patch(self):
  255. """
  256. This function monkey patches commit and rollback
  257. writes to the cache should not happen until commit (unless our state isn't managed).
  258. It does not yet support savepoints.
  259. """
  260. if not self._patched_var:
  261. self._originals['rollback'] = self._getreal('rollback')
  262. self._originals['commit'] = self._getreal('commit')
  263. self._originals['savepoint'] = self._getreal('savepoint')
  264. self._originals['savepoint_rollback'] = self._getreal('savepoint_rollback')
  265. self._originals['savepoint_commit'] = self._getreal('savepoint_commit')
  266. django_transaction.rollback = self._patched(django_transaction.rollback, False)
  267. django_transaction.commit = self._patched(django_transaction.commit, True)
  268. django_transaction.savepoint = self._savepoint(django_transaction.savepoint)
  269. django_transaction.savepoint_rollback = self._savepoint_rollback(django_transaction.savepoint_rollback)
  270. django_transaction.savepoint_commit = self._savepoint_commit(django_transaction.savepoint_commit)
  271. self._patched_var = True
  272. def unpatch(self):
  273. for fun in self._originals:
  274. setattr(django_transaction, fun, self._originals[fun])
  275. self._patched_var = False