PageRenderTime 42ms CodeModel.GetById 15ms app.highlight 24ms RepoModel.GetById 1ms app.codeStats 0ms

/johnny/transaction.py

https://bitbucket.org/jmoiron/johnny-cache/
Python | 321 lines | 249 code | 36 blank | 36 comment | 71 complexity | 506ff4a2421b067095bf656445d4756c MD5 | raw file
  1import django
  2from django.db import transaction as django_transaction
  3from django.db import connection
  4try:
  5    from django.db import DEFAULT_DB_ALIAS
  6except:
  7    DEFUALT_DB_ALIAS = None
  8
  9from johnny.decorators import wraps, available_attrs
 10
 11
 12class TransactionManager(object):
 13    """
 14    TransactionManager is a wrapper around a cache_backend that is
 15    transaction aware.
 16
 17    If we are in a transaction, it will return the locally cached version.
 18
 19      * On rollback, it will flush all local caches
 20      * On commit, it will push them up to the real shared cache backend
 21        (ex. memcached).
 22    """
 23    _patched_var = False
 24
 25    def __init__(self, cache_backend, keygen):
 26        from johnny import cache, settings
 27
 28        self.timeout = settings.MIDDLEWARE_SECONDS
 29        self.prefix = settings.MIDDLEWARE_KEY_PREFIX
 30
 31        self.cache_backend = cache_backend
 32        self.local = cache.local
 33        self.keygen = keygen(self.prefix)
 34        self._originals = {}
 35        self._dirty_backup = {}
 36
 37        self.local['trans_sids'] = {}
 38
 39    def _get_sid(self, using=None):
 40        if 'trans_sids' not in self.local:
 41            self.local['trans_sids'] = {}
 42        d = self.local['trans_sids']
 43        if self.has_multi_db():
 44            if using is None:
 45                using = DEFAULT_DB_ALIAS
 46        else:
 47            using = 'default'
 48        if using not in d:
 49            d[using] = []
 50        return d[using]
 51
 52    def _clear_sid_stack(self, using=None):
 53        if self.has_multi_db():
 54            if using is None:
 55                using = DEFAULT_DB_ALIAS
 56        else:
 57            using = 'default'
 58        if using in self.local.get('trans_sids', {}):
 59            del self.local['trans_sids']
 60
 61    def has_multi_db(self):
 62        if django.VERSION[:2] > (1, 1):
 63            return True
 64        return False
 65
 66    def is_managed(self):
 67        return django_transaction.is_managed()
 68
 69    def get(self, key, default=None, using=None):
 70        if self.is_managed() and self._patched_var:
 71            val = self.local.get(key, None)
 72            if val:
 73                return val
 74            if self._uses_savepoints():
 75                val = self._get_from_savepoints(key, using)
 76                if val:
 77                    return val
 78        return self.cache_backend.get(key, default)
 79
 80    def _get_from_savepoints(self, key, using=None):
 81        sids = self._get_sid(using)
 82        cp = list(sids)
 83        cp.reverse()
 84        for sid in cp:
 85            if key in self.local[sid]:
 86                return self.local[sid][key]
 87
 88    def _trunc_using(self, using):
 89        if self.has_multi_db():
 90            if using is None:
 91                using = DEFAULT_DB_ALIAS
 92        else:
 93            using = 'default'
 94        if len(using) > 100:
 95            using = using[0:68] + self.keygen.gen_key(using[68:])
 96        return using
 97
 98    def set(self, key, val, timeout=None, using=None):
 99        """
100        Set will be using the generational key, so if another thread
101        bumps this key, the localstore version will still be invalid.
102        If the key is bumped during a transaction it will be new
103        to the global cache on commit, so it will still be a bump.
104        """
105        if timeout is None:
106            timeout = self.timeout
107        if self.is_managed() and self._patched_var:
108            self.local[key] = val
109        else:
110            self.cache_backend.set(key, val, timeout)
111
112    def _clear(self, using=None):
113        if self.has_multi_db():
114            self.local.clear('%s_%s_*' %
115                             (self.prefix, self._trunc_using(using)))
116        else:
117            self.local.clear('%s_*' % self.prefix)
118
119    def _flush(self, commit=True, using=None):
120        """
121        Flushes the internal cache, either to the memcache or rolls back
122        """
123        if commit:
124            # XXX: multi-set?
125            if self._uses_savepoints():
126                self._commit_all_savepoints(using)
127            if self.has_multi_db():
128                c = self.local.mget('%s_%s_*' %
129                                    (self.prefix, self._trunc_using(using)))
130            else:
131                c = self.local.mget('%s_*' % self.prefix)
132            for key, value in c.iteritems():
133                self.cache_backend.set(key, value, self.timeout)
134        else:
135            if self._uses_savepoints():
136                self._rollback_all_savepoints(using)
137        self._clear(using)
138        self._clear_sid_stack(using)
139
140    def _patched(self, original, commit=True):
141        @wraps(original, assigned=available_attrs(original))
142        def newfun(using=None):
143            #1.2 version
144            original(using=using)
145            self._flush(commit=commit, using=using)
146
147        @wraps(original, assigned=available_attrs(original))
148        def newfun11():
149            #1.1 version
150            original()
151            self._flush(commit=commit)
152
153        if django.VERSION[:2] == (1, 1):
154            return newfun11
155        elif django.VERSION[:2] > (1, 1):
156            return newfun
157        return original
158
159    def _uses_savepoints(self):
160        return connection.features.uses_savepoints
161
162    def _sid_key(self, sid, using=None):
163        if using is not None:
164            prefix = 'trans_savepoint_%s' % using
165        else:
166            prefix = 'trans_savepoint'
167
168        if sid.startswith(prefix):
169            return sid
170        return '%s_%s'%(prefix, sid)
171
172    def _create_savepoint(self, sid, using=None):
173        key = self._sid_key(sid, using)
174
175        #get all local dirty items
176        if self.has_multi_db():
177            c = self.local.mget('%s_%s_*' %
178                                (self.prefix, self._trunc_using(using)))
179        else:
180            c = self.local.mget('%s_*' % self.prefix)
181        #store them to a dictionary in the localstore
182        if key not in self.local:
183            self.local[key] = {}
184        for k, v in c.iteritems():
185            self.local[key][k] = v
186        #clear the dirty
187        self._clear(using)
188        #append the key to the savepoint stack
189        sids = self._get_sid(using)
190        sids.append(key)
191
192    def _rollback_savepoint(self, sid, using=None):
193        sids = self._get_sid(using)
194        key = self._sid_key(sid, using)
195        stack = []
196        try:
197            popped = None
198            while popped != key:
199                popped = sids.pop()
200                stack.insert(0, popped)
201            #delete items from localstore
202            for i in stack:
203                del self.local[i]
204            #clear dirty
205            self._clear(using)
206        except IndexError:
207            #key not found, don't delete from localstore, restore sid stack
208            for i in stack:
209                sids.insert(0, i)
210
211    def _commit_savepoint(self, sid, using=None):
212        # commit is not a commit but is in reality just a clear back to that
213        # savepoint and adds the items back to the dirty transaction.
214        key = self._sid_key(sid, using)
215        sids = self._get_sid(using)
216        stack = []
217        try:
218            popped = None
219            while popped != key:
220                popped = sids.pop()
221                stack.insert(0, popped)
222            self._store_dirty(using)
223            for i in stack:
224                for k, v in self.local[i].iteritems():
225                    self.local[k] = v
226                del self.local[i]
227            self._restore_dirty(using)
228        except IndexError:
229            for i in stack:
230                sids.insert(0, i)
231
232    def _commit_all_savepoints(self, using=None):
233        sids = self._get_sid(using)
234        if sids:
235            self._commit_savepoint(sids[0], using)
236
237    def _rollback_all_savepoints(self, using=None):
238        sids = self._get_sid(using)
239        if sids:
240            self._rollback_savepoint(sids[0], using)
241
242    def _store_dirty(self, using=None):
243        if self.has_multi_db():
244            c = self.local.mget('%s_%s_*' %
245                                (self.prefix, self._trunc_using(using)))
246        else:
247            c = self.local.mget('%s_*' % self.prefix)
248        backup = 'trans_dirty_store_%s' % self._trunc_using(using)
249        self.local[backup] = {}
250        for k, v in c.iteritems():
251            self.local[backup][k] = v
252        self._clear(using)
253
254    def _restore_dirty(self, using=None):
255        backup = 'trans_dirty_store_%s' % self._trunc_using(using)
256        for k, v in self.local.get(backup, {}).iteritems():
257            self.local[k] = v
258        del self.local[backup]
259
260    def _savepoint(self, original):
261        @wraps(original, assigned=available_attrs(original))
262        def newfun(using=None):
263            if using != None:
264                sid = original(using=using)
265            else:
266                sid = original()
267            if self._uses_savepoints():
268                self._create_savepoint(sid, using)
269            return sid
270        return newfun
271
272    def _savepoint_rollback(self, original):
273        def newfun(sid, *args, **kwargs):
274            original(sid, *args, **kwargs)
275            if self._uses_savepoints():
276                if len(args) == 2:
277                    using = args[1]
278                else:
279                    using = kwargs.get('using', None)
280                self._rollback_savepoint(sid, using)
281        return newfun
282
283    def _savepoint_commit(self, original):
284        def newfun(sid, *args, **kwargs):
285            original(sid, *args, **kwargs)
286            if self._uses_savepoints():
287                if len(args) == 1:
288                    using = args[0]
289                else:
290                    using = kwargs.get('using', None)
291                self._commit_savepoint(sid, using)
292        return newfun
293
294    def _getreal(self, name):
295        return getattr(django_transaction, 'real_%s' % name,
296                getattr(django_transaction, name))
297
298    def patch(self):
299        """
300        This function monkey patches commit and rollback
301        writes to the cache should not happen until commit (unless our state
302        isn't managed). It does not yet support savepoints.
303        """
304        if not self._patched_var:
305            self._originals['rollback'] = self._getreal('rollback')
306            self._originals['commit'] = self._getreal('commit')
307            self._originals['savepoint'] = self._getreal('savepoint')
308            self._originals['savepoint_rollback'] = self._getreal('savepoint_rollback')
309            self._originals['savepoint_commit'] = self._getreal('savepoint_commit')
310            django_transaction.rollback = self._patched(django_transaction.rollback, False)
311            django_transaction.commit = self._patched(django_transaction.commit, True)
312            django_transaction.savepoint = self._savepoint(django_transaction.savepoint)
313            django_transaction.savepoint_rollback = self._savepoint_rollback(django_transaction.savepoint_rollback)
314            django_transaction.savepoint_commit = self._savepoint_commit(django_transaction.savepoint_commit)
315
316            self._patched_var = True
317
318    def unpatch(self):
319        for fun in self._originals:
320            setattr(django_transaction, fun, self._originals[fun])
321        self._patched_var = False