/Lib/test/test_enumerate.py

http://unladen-swallow.googlecode.com/ · Python · 230 lines · 180 code · 43 blank · 7 comment · 16 complexity · 6844d8c0a91e51cc5888b69206d49343 MD5 · raw file

  1. import unittest
  2. import sys
  3. from test import test_support
  class G:
      'Sequence using __getitem__'
      # Exposes indexing only (no __iter__), so any iteration over an
      # instance has to go through __getitem__.
      def __init__(self, seqn):
          self.seqn = seqn

      def __getitem__(self, index):
          # Delegate straight to the wrapped sequence; its IndexError
          # (if any) propagates unchanged.
          return self.seqn[index]
  class I:
      'Sequence using iterator protocol'
      # __iter__ returns the instance itself; next() hands out the items of
      # self.seqn by position and raises StopIteration once the position
      # reaches len(self.seqn).
      def __init__(self, seqn):
          self.seqn = seqn
          self.i = 0  # position of the next item to return

      def __iter__(self):
          return self

      def next(self):
          pos = self.i
          if pos < len(self.seqn):
              item = self.seqn[pos]
              self.i = pos + 1
              return item
          raise StopIteration
  class Ig:
      'Sequence using iterator protocol defined with a generator'
      # Each call to __iter__ builds a fresh generator over self.seqn.
      # self.i is initialised for parity with the other helpers but never read.
      def __init__(self, seqn):
          self.i = 0
          self.seqn = seqn

      def __iter__(self):
          for element in self.seqn:
              yield element
  class X:
      'Missing __getitem__ and __iter__'
      # Has a next() method but neither __iter__ nor __getitem__, so it is not
      # iterable; test_noniterable expects enumerate() to reject it with
      # TypeError.
      def __init__(self, seqn):
          self.seqn = seqn
          self.i = 0  # position of the next item next() would return

      def next(self):
          pos = self.i
          if pos < len(self.seqn):
              item = self.seqn[pos]
              self.i = pos + 1
              return item
          raise StopIteration
  class E:
      'Test propagation of exceptions'
      # A self-iterator whose very first next() call raises
      # ZeroDivisionError, so callers can check that the error escapes.
      def __init__(self, seqn):
          self.seqn = seqn
          self.i = 0

      def __iter__(self):
          return self

      def next(self):
          return 3 // 0
  class N:
      'Iterator missing next()'
      # __iter__ returns the instance, but the instance defines no next()
      # method, making it an ill-formed iterator.
      def __init__(self, seqn):
          self.i = 0
          self.seqn = seqn

      def __iter__(self):
          return self
  56. class EnumerateTestCase(unittest.TestCase):
  57. enum = enumerate
  58. seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]
  59. def test_basicfunction(self):
  60. self.assertEqual(type(self.enum(self.seq)), self.enum)
  61. e = self.enum(self.seq)
  62. self.assertEqual(iter(e), e)
  63. self.assertEqual(list(self.enum(self.seq)), self.res)
  64. self.enum.__doc__
  65. def test_getitemseqn(self):
  66. self.assertEqual(list(self.enum(G(self.seq))), self.res)
  67. e = self.enum(G(''))
  68. self.assertRaises(StopIteration, e.next)
  69. def test_iteratorseqn(self):
  70. self.assertEqual(list(self.enum(I(self.seq))), self.res)
  71. e = self.enum(I(''))
  72. self.assertRaises(StopIteration, e.next)
  73. def test_iteratorgenerator(self):
  74. self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
  75. e = self.enum(Ig(''))
  76. self.assertRaises(StopIteration, e.next)
  77. def test_noniterable(self):
  78. self.assertRaises(TypeError, self.enum, X(self.seq))
  79. def test_illformediterable(self):
  80. self.assertRaises(TypeError, list, self.enum(N(self.seq)))
  81. def test_exception_propagation(self):
  82. self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))
  83. def test_argumentcheck(self):
  84. self.assertRaises(TypeError, self.enum) # no arguments
  85. self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
  86. self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type
  87. self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments
  88. def test_tuple_reuse(self):
  89. # Tests an implementation detail where tuple is reused
  90. # whenever nothing else holds a reference to it
  91. self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
  92. self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq)))
  class MyEnum(enumerate):
      # Trivial subclass; SubclassTestCase reruns every enumerate test
      # against it.
      pass
  class SubclassTestCase(EnumerateTestCase):
      # Same suite, with the enumerate subclass as the callable under test.
      enum = MyEnum
  class TestEmpty(EnumerateTestCase):
      # Empty input: enumerate must produce no pairs at all.
      seq = ''
      res = []
  class TestBig(EnumerateTestCase):
      # A larger input: 10, 12, ..., 19998 paired with indices 0, 1, 2, ...
      seq = range(10, 20000, 2)
      # zip stops at the shorter argument, so pairing with range(len(seq))
      # yields the same list as pairing with any longer range.
      res = zip(range(len(seq)), seq)
  102. class TestReversed(unittest.TestCase):
  103. def test_simple(self):
  104. class A:
  105. def __getitem__(self, i):
  106. if i < 5:
  107. return str(i)
  108. raise StopIteration
  109. def __len__(self):
  110. return 5
  111. for data in 'abc', range(5), tuple(enumerate('abc')), A(), xrange(1,17,5):
  112. self.assertEqual(list(data)[::-1], list(reversed(data)))
  113. self.assertRaises(TypeError, reversed, {})
  114. # don't allow keyword arguments
  115. self.assertRaises(TypeError, reversed, [], a=1)
  116. def test_xrange_optimization(self):
  117. x = xrange(1)
  118. self.assertEqual(type(reversed(x)), type(iter(x)))
  119. def test_len(self):
  120. # This is an implementation detail, not an interface requirement
  121. from test.test_iterlen import len
  122. for s in ('hello', tuple('hello'), list('hello'), xrange(5)):
  123. self.assertEqual(len(reversed(s)), len(s))
  124. r = reversed(s)
  125. list(r)
  126. self.assertEqual(len(r), 0)
  127. class SeqWithWeirdLen:
  128. called = False
  129. def __len__(self):
  130. if not self.called:
  131. self.called = True
  132. return 10
  133. raise ZeroDivisionError
  134. def __getitem__(self, index):
  135. return index
  136. r = reversed(SeqWithWeirdLen())
  137. self.assertRaises(ZeroDivisionError, len, r)
  138. def test_gc(self):
  139. class Seq:
  140. def __len__(self):
  141. return 10
  142. def __getitem__(self, index):
  143. return index
  144. s = Seq()
  145. r = reversed(s)
  146. s.r = r
  147. def test_args(self):
  148. self.assertRaises(TypeError, reversed)
  149. self.assertRaises(TypeError, reversed, [], 'extra')
  150. def test_bug1229429(self):
  151. # this bug was never in reversed, it was in
  152. # PyObject_CallMethod, and reversed_new calls that sometimes.
  153. if not hasattr(sys, "getrefcount"):
  154. return
  155. def f():
  156. pass
  157. r = f.__reversed__ = object()
  158. rc = sys.getrefcount(r)
  159. for i in range(10):
  160. try:
  161. reversed(f)
  162. except TypeError:
  163. pass
  164. else:
  165. self.fail("non-callable __reversed__ didn't raise!")
  166. self.assertEqual(rc, sys.getrefcount(r))
  class EnumerateStartTestCase(EnumerateTestCase):
      # Base for enumerate(..., start=...) tests.  Subclasses define `enum` as
      # a method wrapping enumerate().  Looked up on an instance, it is a bound
      # method rather than a type, so the inherited check
      # `type(self.enum(seq)) == self.enum` cannot hold; it is replaced here
      # by the remaining checks.

      def test_basicfunction(self):
          e = self.enum(self.seq)
          self.assertEqual(iter(e), e)
          self.assertEqual(list(self.enum(self.seq)), self.res)


  class TestStart(EnumerateStartTestCase):
      # `enum` takes `self` explicitly.  A function stored on the class is
      # bound on instance access, so a one-argument lambda would receive
      # (self, iterable) and fail with TypeError.
      def enum(self, iterable):
          return enumerate(iterable, start=11)

      seq, res = 'abc', [(11, 'a'), (12, 'b'), (13, 'c')]


  class TestLongStart(EnumerateStartTestCase):
      # Start just beyond sys.maxint, so every produced index exceeds the
      # largest plain int.
      def enum(self, iterable):
          return enumerate(iterable, start=sys.maxint + 1)

      seq, res = 'abc', [(sys.maxint + 1, 'a'), (sys.maxint + 2, 'b'),
                         (sys.maxint + 3, 'c')]


  def test_main(verbose=None):
      # Run every test class of this module.  When `verbose` is true and the
      # interpreter provides sys.gettotalrefcount, rerun the suite five times
      # and print the total reference count after each pass, presumably so a
      # steadily growing total reveals a reference leak.
      testclasses = (EnumerateTestCase, SubclassTestCase, TestEmpty, TestBig,
                     TestReversed, TestStart, TestLongStart)
      test_support.run_unittest(*testclasses)
      # verify reference counting (uses the module-level `sys` import)
      if verbose and hasattr(sys, "gettotalrefcount"):
          # Presumably preallocated so that storing each result does not add
          # net references between passes.
          counts = [None] * 5
          for i in xrange(len(counts)):
              test_support.run_unittest(*testclasses)
              counts[i] = sys.gettotalrefcount()
          print(counts)
  if __name__ == "__main__":
      # Only when executed as a script: enable verbose mode, which in
      # test_main also triggers the repeated reference-count passes when
      # sys.gettotalrefcount is available.
      test_main(True)