PageRenderTime 71ms CodeModel.GetById 29ms RepoModel.GetById 0ms app.codeStats 0ms

/lib-python/2.5.2/test/test_heapq.py

https://github.com/thepian/pypy
Python | 286 lines | 238 code | 36 blank | 12 comment | 47 complexity | ea7447faf39cb06f49d0c2396b91a2d5 MD5 | raw file
  1. """Unittests for heapq."""
  2. from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest
  3. import random
  4. import unittest
  5. from test import test_support
  6. import sys
  7. def heapiter(heap):
  8. # An iterator returning a heap's elements, smallest-first.
  9. try:
  10. while 1:
  11. yield heappop(heap)
  12. except IndexError:
  13. pass
  14. class TestHeap(unittest.TestCase):
  15. def test_push_pop(self):
  16. # 1) Push 256 random numbers and pop them off, verifying all's OK.
  17. heap = []
  18. data = []
  19. self.check_invariant(heap)
  20. for i in range(256):
  21. item = random.random()
  22. data.append(item)
  23. heappush(heap, item)
  24. self.check_invariant(heap)
  25. results = []
  26. while heap:
  27. item = heappop(heap)
  28. self.check_invariant(heap)
  29. results.append(item)
  30. data_sorted = data[:]
  31. data_sorted.sort()
  32. self.assertEqual(data_sorted, results)
  33. # 2) Check that the invariant holds for a sorted array
  34. self.check_invariant(results)
  35. self.assertRaises(TypeError, heappush, [])
  36. try:
  37. self.assertRaises(TypeError, heappush, None, None)
  38. self.assertRaises(TypeError, heappop, None)
  39. except AttributeError:
  40. pass
  41. def check_invariant(self, heap):
  42. # Check the heap invariant.
  43. for pos, item in enumerate(heap):
  44. if pos: # pos 0 has no parent
  45. parentpos = (pos-1) >> 1
  46. self.assert_(heap[parentpos] <= item)
  47. def test_heapify(self):
  48. for size in range(30):
  49. heap = [random.random() for dummy in range(size)]
  50. heapify(heap)
  51. self.check_invariant(heap)
  52. self.assertRaises(TypeError, heapify, None)
  53. def test_naive_nbest(self):
  54. data = [random.randrange(2000) for i in range(1000)]
  55. heap = []
  56. for item in data:
  57. heappush(heap, item)
  58. if len(heap) > 10:
  59. heappop(heap)
  60. heap.sort()
  61. self.assertEqual(heap, sorted(data)[-10:])
  62. def test_nbest(self):
  63. # Less-naive "N-best" algorithm, much faster (if len(data) is big
  64. # enough <wink>) than sorting all of data. However, if we had a max
  65. # heap instead of a min heap, it could go faster still via
  66. # heapify'ing all of data (linear time), then doing 10 heappops
  67. # (10 log-time steps).
  68. data = [random.randrange(2000) for i in range(1000)]
  69. heap = data[:10]
  70. heapify(heap)
  71. for item in data[10:]:
  72. if item > heap[0]: # this gets rarer the longer we run
  73. heapreplace(heap, item)
  74. self.assertEqual(list(heapiter(heap)), sorted(data)[-10:])
  75. self.assertRaises(TypeError, heapreplace, None)
  76. self.assertRaises(TypeError, heapreplace, None, None)
  77. self.assertRaises(IndexError, heapreplace, [], None)
  78. def test_heapsort(self):
  79. # Exercise everything with repeated heapsort checks
  80. for trial in xrange(100):
  81. size = random.randrange(50)
  82. data = [random.randrange(25) for i in range(size)]
  83. if trial & 1: # Half of the time, use heapify
  84. heap = data[:]
  85. heapify(heap)
  86. else: # The rest of the time, use heappush
  87. heap = []
  88. for item in data:
  89. heappush(heap, item)
  90. heap_sorted = [heappop(heap) for i in range(size)]
  91. self.assertEqual(heap_sorted, sorted(data))
  92. def test_nsmallest(self):
  93. data = [(random.randrange(2000), i) for i in range(1000)]
  94. for f in (None, lambda x: x[0] * 547 % 2000):
  95. for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  96. self.assertEqual(nsmallest(n, data), sorted(data)[:n])
  97. self.assertEqual(nsmallest(n, data, key=f),
  98. sorted(data, key=f)[:n])
  99. def test_nlargest(self):
  100. data = [(random.randrange(2000), i) for i in range(1000)]
  101. for f in (None, lambda x: x[0] * 547 % 2000):
  102. for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
  103. self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n])
  104. self.assertEqual(nlargest(n, data, key=f),
  105. sorted(data, key=f, reverse=True)[:n])
  106. #==============================================================================
  107. class LenOnly:
  108. "Dummy sequence class defining __len__ but not __getitem__."
  109. def __len__(self):
  110. return 10
  111. class GetOnly:
  112. "Dummy sequence class defining __getitem__ but not __len__."
  113. def __getitem__(self, ndx):
  114. return 10
  115. class CmpErr:
  116. "Dummy element that always raises an error during comparison"
  117. def __cmp__(self, other):
  118. raise ZeroDivisionError
  119. def R(seqn):
  120. 'Regular generator'
  121. for i in seqn:
  122. yield i
  123. class G:
  124. 'Sequence using __getitem__'
  125. def __init__(self, seqn):
  126. self.seqn = seqn
  127. def __getitem__(self, i):
  128. return self.seqn[i]
  129. class I:
  130. 'Sequence using iterator protocol'
  131. def __init__(self, seqn):
  132. self.seqn = seqn
  133. self.i = 0
  134. def __iter__(self):
  135. return self
  136. def next(self):
  137. if self.i >= len(self.seqn): raise StopIteration
  138. v = self.seqn[self.i]
  139. self.i += 1
  140. return v
  141. class Ig:
  142. 'Sequence using iterator protocol defined with a generator'
  143. def __init__(self, seqn):
  144. self.seqn = seqn
  145. self.i = 0
  146. def __iter__(self):
  147. for val in self.seqn:
  148. yield val
  149. class X:
  150. 'Missing __getitem__ and __iter__'
  151. def __init__(self, seqn):
  152. self.seqn = seqn
  153. self.i = 0
  154. def next(self):
  155. if self.i >= len(self.seqn): raise StopIteration
  156. v = self.seqn[self.i]
  157. self.i += 1
  158. return v
  159. class N:
  160. 'Iterator missing next()'
  161. def __init__(self, seqn):
  162. self.seqn = seqn
  163. self.i = 0
  164. def __iter__(self):
  165. return self
  166. class E:
  167. 'Test propagation of exceptions'
  168. def __init__(self, seqn):
  169. self.seqn = seqn
  170. self.i = 0
  171. def __iter__(self):
  172. return self
  173. def next(self):
  174. 3 // 0
  175. class S:
  176. 'Test immediate stop'
  177. def __init__(self, seqn):
  178. pass
  179. def __iter__(self):
  180. return self
  181. def next(self):
  182. raise StopIteration
  183. from itertools import chain, imap
  184. def L(seqn):
  185. 'Test multiple tiers of iterators'
  186. return chain(imap(lambda x:x, R(Ig(G(seqn)))))
  187. class TestErrorHandling(unittest.TestCase):
  188. def test_non_sequence(self):
  189. for f in (heapify, heappop):
  190. self.assertRaises(TypeError, f, 10)
  191. for f in (heappush, heapreplace, nlargest, nsmallest):
  192. self.assertRaises(TypeError, f, 10, 10)
  193. def test_len_only(self):
  194. for f in (heapify, heappop):
  195. self.assertRaises(TypeError, f, LenOnly())
  196. for f in (heappush, heapreplace):
  197. self.assertRaises(TypeError, f, LenOnly(), 10)
  198. for f in (nlargest, nsmallest):
  199. self.assertRaises(TypeError, f, 2, LenOnly())
  200. def test_get_only(self):
  201. for f in (heapify, heappop):
  202. self.assertRaises(TypeError, f, GetOnly())
  203. for f in (heappush, heapreplace):
  204. self.assertRaises(TypeError, f, GetOnly(), 10)
  205. for f in (nlargest, nsmallest):
  206. self.assertRaises(TypeError, f, 2, GetOnly())
  207. def test_get_only(self):
  208. seq = [CmpErr(), CmpErr(), CmpErr()]
  209. for f in (heapify, heappop):
  210. self.assertRaises(ZeroDivisionError, f, seq)
  211. for f in (heappush, heapreplace):
  212. self.assertRaises(ZeroDivisionError, f, seq, 10)
  213. for f in (nlargest, nsmallest):
  214. self.assertRaises(ZeroDivisionError, f, 2, seq)
  215. def test_arg_parsing(self):
  216. for f in (heapify, heappop, heappush, heapreplace, nlargest, nsmallest):
  217. self.assertRaises(TypeError, f, 10)
  218. def test_iterable_args(self):
  219. for f in (nlargest, nsmallest):
  220. for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
  221. for g in (G, I, Ig, L, R):
  222. self.assertEqual(f(2, g(s)), f(2,s))
  223. self.assertEqual(f(2, S(s)), [])
  224. self.assertRaises(TypeError, f, 2, X(s))
  225. self.assertRaises(TypeError, f, 2, N(s))
  226. self.assertRaises(ZeroDivisionError, f, 2, E(s))
  227. #==============================================================================
  228. def test_main(verbose=None):
  229. from types import BuiltinFunctionType
  230. test_classes = [TestHeap]
  231. if isinstance(heapify, BuiltinFunctionType):
  232. test_classes.append(TestErrorHandling)
  233. test_support.run_unittest(*test_classes)
  234. # verify reference counting
  235. if verbose and hasattr(sys, "gettotalrefcount"):
  236. import gc
  237. counts = [None] * 5
  238. for i in xrange(len(counts)):
  239. test_support.run_unittest(*test_classes)
  240. gc.collect()
  241. counts[i] = sys.gettotalrefcount()
  242. print counts
  243. if __name__ == "__main__":
  244. test_main(verbose=True)