PageRenderTime 26ms CodeModel.GetById 1ms app.highlight 20ms RepoModel.GetById 1ms app.codeStats 0ms

/Lib/test/test_defaultdict.py

http://unladen-swallow.googlecode.com/
Python | 179 lines | 153 code | 18 blank | 8 comment | 14 complexity | a09ec84784c574285cd949757fa58ce0 MD5 | raw file
  1"""Unit tests for collections.defaultdict."""
  2
  3import os
  4import copy
  5import tempfile
  6import unittest
  7from test import test_support
  8
  9from collections import defaultdict
 10
 11def foobar():
 12    return list
 13
 14class TestDefaultDict(unittest.TestCase):
 15
 16    def test_basic(self):
 17        d1 = defaultdict()
 18        self.assertEqual(d1.default_factory, None)
 19        d1.default_factory = list
 20        d1[12].append(42)
 21        self.assertEqual(d1, {12: [42]})
 22        d1[12].append(24)
 23        self.assertEqual(d1, {12: [42, 24]})
 24        d1[13]
 25        d1[14]
 26        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
 27        self.assert_(d1[12] is not d1[13] is not d1[14])
 28        d2 = defaultdict(list, foo=1, bar=2)
 29        self.assertEqual(d2.default_factory, list)
 30        self.assertEqual(d2, {"foo": 1, "bar": 2})
 31        self.assertEqual(d2["foo"], 1)
 32        self.assertEqual(d2["bar"], 2)
 33        self.assertEqual(d2[42], [])
 34        self.assert_("foo" in d2)
 35        self.assert_("foo" in d2.keys())
 36        self.assert_("bar" in d2)
 37        self.assert_("bar" in d2.keys())
 38        self.assert_(42 in d2)
 39        self.assert_(42 in d2.keys())
 40        self.assert_(12 not in d2)
 41        self.assert_(12 not in d2.keys())
 42        d2.default_factory = None
 43        self.assertEqual(d2.default_factory, None)
 44        try:
 45            d2[15]
 46        except KeyError, err:
 47            self.assertEqual(err.args, (15,))
 48        else:
 49            self.fail("d2[15] didn't raise KeyError")
 50        self.assertRaises(TypeError, defaultdict, 1)
 51
 52    def test_missing(self):
 53        d1 = defaultdict()
 54        self.assertRaises(KeyError, d1.__missing__, 42)
 55        d1.default_factory = list
 56        self.assertEqual(d1.__missing__(42), [])
 57
 58    def test_repr(self):
 59        d1 = defaultdict()
 60        self.assertEqual(d1.default_factory, None)
 61        self.assertEqual(repr(d1), "defaultdict(None, {})")
 62        self.assertEqual(eval(repr(d1)), d1)
 63        d1[11] = 41
 64        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
 65        d2 = defaultdict(int)
 66        self.assertEqual(d2.default_factory, int)
 67        d2[12] = 42
 68        self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
 69        def foo(): return 43
 70        d3 = defaultdict(foo)
 71        self.assert_(d3.default_factory is foo)
 72        d3[13]
 73        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
 74
 75    def test_print(self):
 76        d1 = defaultdict()
 77        def foo(): return 42
 78        d2 = defaultdict(foo, {1: 2})
 79        # NOTE: We can't use tempfile.[Named]TemporaryFile since this
 80        # code must exercise the tp_print C code, which only gets
 81        # invoked for *real* files.
 82        tfn = tempfile.mktemp()
 83        try:
 84            f = open(tfn, "w+")
 85            try:
 86                print >>f, d1
 87                print >>f, d2
 88                f.seek(0)
 89                self.assertEqual(f.readline(), repr(d1) + "\n")
 90                self.assertEqual(f.readline(), repr(d2) + "\n")
 91            finally:
 92                f.close()
 93        finally:
 94            os.remove(tfn)
 95
 96    def test_copy(self):
 97        d1 = defaultdict()
 98        d2 = d1.copy()
 99        self.assertEqual(type(d2), defaultdict)
100        self.assertEqual(d2.default_factory, None)
101        self.assertEqual(d2, {})
102        d1.default_factory = list
103        d3 = d1.copy()
104        self.assertEqual(type(d3), defaultdict)
105        self.assertEqual(d3.default_factory, list)
106        self.assertEqual(d3, {})
107        d1[42]
108        d4 = d1.copy()
109        self.assertEqual(type(d4), defaultdict)
110        self.assertEqual(d4.default_factory, list)
111        self.assertEqual(d4, {42: []})
112        d4[12]
113        self.assertEqual(d4, {42: [], 12: []})
114
115        # Issue 6637: Copy fails for empty default dict
116        d = defaultdict()
117        d['a'] = 42
118        e = d.copy()
119        self.assertEqual(e['a'], 42)
120
121    def test_shallow_copy(self):
122        d1 = defaultdict(foobar, {1: 1})
123        d2 = copy.copy(d1)
124        self.assertEqual(d2.default_factory, foobar)
125        self.assertEqual(d2, d1)
126        d1.default_factory = list
127        d2 = copy.copy(d1)
128        self.assertEqual(d2.default_factory, list)
129        self.assertEqual(d2, d1)
130
131    def test_deep_copy(self):
132        d1 = defaultdict(foobar, {1: [1]})
133        d2 = copy.deepcopy(d1)
134        self.assertEqual(d2.default_factory, foobar)
135        self.assertEqual(d2, d1)
136        self.assert_(d1[1] is not d2[1])
137        d1.default_factory = list
138        d2 = copy.deepcopy(d1)
139        self.assertEqual(d2.default_factory, list)
140        self.assertEqual(d2, d1)
141
142    def test_keyerror_without_factory(self):
143        d1 = defaultdict()
144        try:
145            d1[(1,)]
146        except KeyError, err:
147            self.assertEqual(err.args[0], (1,))
148        else:
149            self.fail("expected KeyError")
150
151    def test_recursive_repr(self):
152        # Issue2045: stack overflow when default_factory is a bound method
153        class sub(defaultdict):
154            def __init__(self):
155                self.default_factory = self._factory
156            def _factory(self):
157                return []
158        d = sub()
159        self.assert_(repr(d).startswith(
160            "defaultdict(<bound method sub._factory of defaultdict(..."))
161
162        # NOTE: printing a subclass of a builtin type does not call its
163        # tp_print slot. So this part is essentially the same test as above.
164        tfn = tempfile.mktemp()
165        try:
166            f = open(tfn, "w+")
167            try:
168                print >>f, d
169            finally:
170                f.close()
171        finally:
172            os.remove(tfn)
173
174
def test_main():
    """Run the TestDefaultDict suite through the regression-test support runner."""
    test_classes = [TestDefaultDict]
    test_support.run_unittest(*test_classes)


if __name__ == "__main__":
    # Only runs when this file is executed directly, not when imported.
    test_main()