PageRenderTime 74ms CodeModel.GetById 30ms app.highlight 28ms RepoModel.GetById 12ms app.codeStats 0ms

/Lib/test/test_decorators.py

http://unladen-swallow.googlecode.com/
Python | 309 lines | 254 code | 34 blank | 21 comment | 6 complexity | b5043a4b22853c7305b3af7650ed4a03 MD5 | raw file
import functools
import unittest
from test import test_support
  3
  4def funcattrs(**kwds):
  5    def decorate(func):
  6        func.__dict__.update(kwds)
  7        return func
  8    return decorate
  9
 10class MiscDecorators (object):
 11    @staticmethod
 12    def author(name):
 13        def decorate(func):
 14            func.__dict__['author'] = name
 15            return func
 16        return decorate
 17
 18# -----------------------------------------------
 19
 20class DbcheckError (Exception):
 21    def __init__(self, exprstr, func, args, kwds):
 22        # A real version of this would set attributes here
 23        Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" %
 24                           (exprstr, func, args, kwds))
 25
 26
 27def dbcheck(exprstr, globals=None, locals=None):
 28    "Decorator to implement debugging assertions"
 29    def decorate(func):
 30        expr = compile(exprstr, "dbcheck-%s" % func.func_name, "eval")
 31        def check(*args, **kwds):
 32            if not eval(expr, globals, locals):
 33                raise DbcheckError(exprstr, func, args, kwds)
 34            return func(*args, **kwds)
 35        return check
 36    return decorate
 37
 38# -----------------------------------------------
 39
 40def countcalls(counts):
 41    "Decorator to count calls to a function"
 42    def decorate(func):
 43        func_name = func.func_name
 44        counts[func_name] = 0
 45        def call(*args, **kwds):
 46            counts[func_name] += 1
 47            return func(*args, **kwds)
 48        call.func_name = func_name
 49        return call
 50    return decorate
 51
 52# -----------------------------------------------
 53
 54def memoize(func):
 55    saved = {}
 56    def call(*args):
 57        try:
 58            return saved[args]
 59        except KeyError:
 60            res = func(*args)
 61            saved[args] = res
 62            return res
 63        except TypeError:
 64            # Unhashable argument
 65            return func(*args)
 66    call.func_name = func.func_name
 67    return call
 68
 69# -----------------------------------------------
 70
 71class TestDecorators(unittest.TestCase):
 72
 73    def test_single(self):
 74        class C(object):
 75            @staticmethod
 76            def foo(): return 42
 77        self.assertEqual(C.foo(), 42)
 78        self.assertEqual(C().foo(), 42)
 79
 80    def test_staticmethod_function(self):
 81        @staticmethod
 82        def notamethod(x):
 83            return x
 84        self.assertRaises(TypeError, notamethod, 1)
 85
 86    def test_dotted(self):
 87        decorators = MiscDecorators()
 88        @decorators.author('Cleese')
 89        def foo(): return 42
 90        self.assertEqual(foo(), 42)
 91        self.assertEqual(foo.author, 'Cleese')
 92
 93    def test_argforms(self):
 94        # A few tests of argument passing, as we use restricted form
 95        # of expressions for decorators.
 96
 97        def noteargs(*args, **kwds):
 98            def decorate(func):
 99                setattr(func, 'dbval', (args, kwds))
100                return func
101            return decorate
102
103        args = ( 'Now', 'is', 'the', 'time' )
104        kwds = dict(one=1, two=2)
105        @noteargs(*args, **kwds)
106        def f1(): return 42
107        self.assertEqual(f1(), 42)
108        self.assertEqual(f1.dbval, (args, kwds))
109
110        @noteargs('terry', 'gilliam', eric='idle', john='cleese')
111        def f2(): return 84
112        self.assertEqual(f2(), 84)
113        self.assertEqual(f2.dbval, (('terry', 'gilliam'),
114                                     dict(eric='idle', john='cleese')))
115
116        @noteargs(1, 2,)
117        def f3(): pass
118        self.assertEqual(f3.dbval, ((1, 2), {}))
119
120    def test_dbcheck(self):
121        @dbcheck('args[1] is not None')
122        def f(a, b):
123            return a + b
124        self.assertEqual(f(1, 2), 3)
125        self.assertRaises(DbcheckError, f, 1, None)
126
127    def test_memoize(self):
128        counts = {}
129
130        @memoize
131        @countcalls(counts)
132        def double(x):
133            return x * 2
134        self.assertEqual(double.func_name, 'double')
135
136        self.assertEqual(counts, dict(double=0))
137
138        # Only the first call with a given argument bumps the call count:
139        #
140        self.assertEqual(double(2), 4)
141        self.assertEqual(counts['double'], 1)
142        self.assertEqual(double(2), 4)
143        self.assertEqual(counts['double'], 1)
144        self.assertEqual(double(3), 6)
145        self.assertEqual(counts['double'], 2)
146
147        # Unhashable arguments do not get memoized:
148        #
149        self.assertEqual(double([10]), [10, 10])
150        self.assertEqual(counts['double'], 3)
151        self.assertEqual(double([10]), [10, 10])
152        self.assertEqual(counts['double'], 4)
153
154    def test_errors(self):
155        # Test syntax restrictions - these are all compile-time errors:
156        #
157        for expr in [ "1+2", "x[3]", "(1, 2)" ]:
158            # Sanity check: is expr is a valid expression by itself?
159            compile(expr, "testexpr", "exec")
160
161            codestr = "@%s\ndef f(): pass" % expr
162            self.assertRaises(SyntaxError, compile, codestr, "test", "exec")
163
164        # You can't put multiple decorators on a single line:
165        #
166        self.assertRaises(SyntaxError, compile,
167                          "@f1 @f2\ndef f(): pass", "test", "exec")
168
169        # Test runtime errors
170
171        def unimp(func):
172            raise NotImplementedError
173        context = dict(nullval=None, unimp=unimp)
174
175        for expr, exc in [ ("undef", NameError),
176                           ("nullval", TypeError),
177                           ("nullval.attr", AttributeError),
178                           ("unimp", NotImplementedError)]:
179            codestr = "@%s\ndef f(): pass\nassert f() is None" % expr
180            code = compile(codestr, "test", "exec")
181            self.assertRaises(exc, eval, code, context)
182
183    def test_double(self):
184        class C(object):
185            @funcattrs(abc=1, xyz="haha")
186            @funcattrs(booh=42)
187            def foo(self): return 42
188        self.assertEqual(C().foo(), 42)
189        self.assertEqual(C.foo.abc, 1)
190        self.assertEqual(C.foo.xyz, "haha")
191        self.assertEqual(C.foo.booh, 42)
192
193    def test_order(self):
194        # Test that decorators are applied in the proper order to the function
195        # they are decorating.
196        def callnum(num):
197            """Decorator factory that returns a decorator that replaces the
198            passed-in function with one that returns the value of 'num'"""
199            def deco(func):
200                return lambda: num
201            return deco
202        @callnum(2)
203        @callnum(1)
204        def foo(): return 42
205        self.assertEqual(foo(), 2,
206                            "Application order of decorators is incorrect")
207
208    def test_eval_order(self):
209        # Evaluating a decorated function involves four steps for each
210        # decorator-maker (the function that returns a decorator):
211        #
212        #    1: Evaluate the decorator-maker name
213        #    2: Evaluate the decorator-maker arguments (if any)
214        #    3: Call the decorator-maker to make a decorator
215        #    4: Call the decorator
216        #
217        # When there are multiple decorators, these steps should be
218        # performed in the above order for each decorator, but we should
219        # iterate through the decorators in the reverse of the order they
220        # appear in the source.
221
222        actions = []
223
224        def make_decorator(tag):
225            actions.append('makedec' + tag)
226            def decorate(func):
227                actions.append('calldec' + tag)
228                return func
229            return decorate
230
231        class NameLookupTracer (object):
232            def __init__(self, index):
233                self.index = index
234
235            def __getattr__(self, fname):
236                if fname == 'make_decorator':
237                    opname, res = ('evalname', make_decorator)
238                elif fname == 'arg':
239                    opname, res = ('evalargs', str(self.index))
240                else:
241                    assert False, "Unknown attrname %s" % fname
242                actions.append('%s%d' % (opname, self.index))
243                return res
244
245        c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ])
246
247        expected_actions = [ 'evalname1', 'evalargs1', 'makedec1',
248                             'evalname2', 'evalargs2', 'makedec2',
249                             'evalname3', 'evalargs3', 'makedec3',
250                             'calldec3', 'calldec2', 'calldec1' ]
251
252        actions = []
253        @c1.make_decorator(c1.arg)
254        @c2.make_decorator(c2.arg)
255        @c3.make_decorator(c3.arg)
256        def foo(): return 42
257        self.assertEqual(foo(), 42)
258
259        self.assertEqual(actions, expected_actions)
260
261        # Test the equivalence claim in chapter 7 of the reference manual.
262        #
263        actions = []
264        def bar(): return 42
265        bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar)))
266        self.assertEqual(bar(), 42)
267        self.assertEqual(actions, expected_actions)
268
class TestClassDecorators(unittest.TestCase):
    """Tests for decorators applied to ``class`` statements."""

    def test_simple(self):
        def plain(cls):
            setattr(cls, 'extra', 'Hello')
            return cls

        @plain
        class C(object):
            pass

        self.assertEqual(C.extra, 'Hello')

    def test_double(self):
        def ten(cls):
            cls.extra = 10
            return cls

        def add_five(cls):
            cls.extra = cls.extra + 5
            return cls

        # The bottom decorator (ten) runs first, then add_five.
        @add_five
        @ten
        class C(object):
            pass

        self.assertEqual(C.extra, 15)

    def test_order(self):
        def setting_extra(value):
            # Build a class decorator that overwrites `extra` with `value`.
            def deco(cls):
                cls.extra = value
                return cls
            return deco

        applied_first = setting_extra('first')
        applied_second = setting_extra('second')

        @applied_second
        @applied_first
        class C(object):
            pass

        self.assertEqual(C.extra, 'second')
303
def test_main():
    """Run each TestCase class through the regression-test runner, one
    class per run_unittest call."""
    for case_class in (TestDecorators, TestClassDecorators):
        test_support.run_unittest(case_class)
307
# Run the whole suite when this file is executed directly as a script.
if "__main__" == __name__:
    test_main()