PageRenderTime 24ms CodeModel.GetById 15ms RepoModel.GetById 7ms app.codeStats 0ms

/Lib/test/test_richcmp.py

https://gitlab.com/unofficial-mirrors/cpython
Python | 355 lines | 266 code | 63 blank | 26 comment | 36 complexity | fc3737e5205cb37a51a40b335117afdf MD5 | raw file
  1. # Tests for rich comparisons
  2. import unittest
  3. from test import support
  4. import operator
  5. class Number:
  6. def __init__(self, x):
  7. self.x = x
  8. def __lt__(self, other):
  9. return self.x < other
  10. def __le__(self, other):
  11. return self.x <= other
  12. def __eq__(self, other):
  13. return self.x == other
  14. def __ne__(self, other):
  15. return self.x != other
  16. def __gt__(self, other):
  17. return self.x > other
  18. def __ge__(self, other):
  19. return self.x >= other
  20. def __cmp__(self, other):
  21. raise support.TestFailed("Number.__cmp__() should not be called")
  22. def __repr__(self):
  23. return "Number(%r)" % (self.x, )
  24. class Vector:
  25. def __init__(self, data):
  26. self.data = data
  27. def __len__(self):
  28. return len(self.data)
  29. def __getitem__(self, i):
  30. return self.data[i]
  31. def __setitem__(self, i, v):
  32. self.data[i] = v
  33. __hash__ = None # Vectors cannot be hashed
  34. def __bool__(self):
  35. raise TypeError("Vectors cannot be used in Boolean contexts")
  36. def __cmp__(self, other):
  37. raise support.TestFailed("Vector.__cmp__() should not be called")
  38. def __repr__(self):
  39. return "Vector(%r)" % (self.data, )
  40. def __lt__(self, other):
  41. return Vector([a < b for a, b in zip(self.data, self.__cast(other))])
  42. def __le__(self, other):
  43. return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])
  44. def __eq__(self, other):
  45. return Vector([a == b for a, b in zip(self.data, self.__cast(other))])
  46. def __ne__(self, other):
  47. return Vector([a != b for a, b in zip(self.data, self.__cast(other))])
  48. def __gt__(self, other):
  49. return Vector([a > b for a, b in zip(self.data, self.__cast(other))])
  50. def __ge__(self, other):
  51. return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])
  52. def __cast(self, other):
  53. if isinstance(other, Vector):
  54. other = other.data
  55. if len(self.data) != len(other):
  56. raise ValueError("Cannot compare vectors of different length")
  57. return other
  58. opmap = {
  59. "lt": (lambda a,b: a< b, operator.lt, operator.__lt__),
  60. "le": (lambda a,b: a<=b, operator.le, operator.__le__),
  61. "eq": (lambda a,b: a==b, operator.eq, operator.__eq__),
  62. "ne": (lambda a,b: a!=b, operator.ne, operator.__ne__),
  63. "gt": (lambda a,b: a> b, operator.gt, operator.__gt__),
  64. "ge": (lambda a,b: a>=b, operator.ge, operator.__ge__)
  65. }
  66. class VectorTest(unittest.TestCase):
  67. def checkfail(self, error, opname, *args):
  68. for op in opmap[opname]:
  69. self.assertRaises(error, op, *args)
  70. def checkequal(self, opname, a, b, expres):
  71. for op in opmap[opname]:
  72. realres = op(a, b)
  73. # can't use assertEqual(realres, expres) here
  74. self.assertEqual(len(realres), len(expres))
  75. for i in range(len(realres)):
  76. # results are bool, so we can use "is" here
  77. self.assertTrue(realres[i] is expres[i])
  78. def test_mixed(self):
  79. # check that comparisons involving Vector objects
  80. # which return rich results (i.e. Vectors with itemwise
  81. # comparison results) work
  82. a = Vector(range(2))
  83. b = Vector(range(3))
  84. # all comparisons should fail for different length
  85. for opname in opmap:
  86. self.checkfail(ValueError, opname, a, b)
  87. a = list(range(5))
  88. b = 5 * [2]
  89. # try mixed arguments (but not (a, b) as that won't return a bool vector)
  90. args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
  91. for (a, b) in args:
  92. self.checkequal("lt", a, b, [True, True, False, False, False])
  93. self.checkequal("le", a, b, [True, True, True, False, False])
  94. self.checkequal("eq", a, b, [False, False, True, False, False])
  95. self.checkequal("ne", a, b, [True, True, False, True, True ])
  96. self.checkequal("gt", a, b, [False, False, False, True, True ])
  97. self.checkequal("ge", a, b, [False, False, True, True, True ])
  98. for ops in opmap.values():
  99. for op in ops:
  100. # calls __bool__, which should fail
  101. self.assertRaises(TypeError, bool, op(a, b))
  102. class NumberTest(unittest.TestCase):
  103. def test_basic(self):
  104. # Check that comparisons involving Number objects
  105. # give the same results give as comparing the
  106. # corresponding ints
  107. for a in range(3):
  108. for b in range(3):
  109. for typea in (int, Number):
  110. for typeb in (int, Number):
  111. if typea==typeb==int:
  112. continue # the combination int, int is useless
  113. ta = typea(a)
  114. tb = typeb(b)
  115. for ops in opmap.values():
  116. for op in ops:
  117. realoutcome = op(a, b)
  118. testoutcome = op(ta, tb)
  119. self.assertEqual(realoutcome, testoutcome)
  120. def checkvalue(self, opname, a, b, expres):
  121. for typea in (int, Number):
  122. for typeb in (int, Number):
  123. ta = typea(a)
  124. tb = typeb(b)
  125. for op in opmap[opname]:
  126. realres = op(ta, tb)
  127. realres = getattr(realres, "x", realres)
  128. self.assertTrue(realres is expres)
  129. def test_values(self):
  130. # check all operators and all comparison results
  131. self.checkvalue("lt", 0, 0, False)
  132. self.checkvalue("le", 0, 0, True )
  133. self.checkvalue("eq", 0, 0, True )
  134. self.checkvalue("ne", 0, 0, False)
  135. self.checkvalue("gt", 0, 0, False)
  136. self.checkvalue("ge", 0, 0, True )
  137. self.checkvalue("lt", 0, 1, True )
  138. self.checkvalue("le", 0, 1, True )
  139. self.checkvalue("eq", 0, 1, False)
  140. self.checkvalue("ne", 0, 1, True )
  141. self.checkvalue("gt", 0, 1, False)
  142. self.checkvalue("ge", 0, 1, False)
  143. self.checkvalue("lt", 1, 0, False)
  144. self.checkvalue("le", 1, 0, False)
  145. self.checkvalue("eq", 1, 0, False)
  146. self.checkvalue("ne", 1, 0, True )
  147. self.checkvalue("gt", 1, 0, True )
  148. self.checkvalue("ge", 1, 0, True )
  149. class MiscTest(unittest.TestCase):
  150. def test_misbehavin(self):
  151. class Misb:
  152. def __lt__(self_, other): return 0
  153. def __gt__(self_, other): return 0
  154. def __eq__(self_, other): return 0
  155. def __le__(self_, other): self.fail("This shouldn't happen")
  156. def __ge__(self_, other): self.fail("This shouldn't happen")
  157. def __ne__(self_, other): self.fail("This shouldn't happen")
  158. a = Misb()
  159. b = Misb()
  160. self.assertEqual(a<b, 0)
  161. self.assertEqual(a==b, 0)
  162. self.assertEqual(a>b, 0)
  163. def test_not(self):
  164. # Check that exceptions in __bool__ are properly
  165. # propagated by the not operator
  166. import operator
  167. class Exc(Exception):
  168. pass
  169. class Bad:
  170. def __bool__(self):
  171. raise Exc
  172. def do(bad):
  173. not bad
  174. for func in (do, operator.not_):
  175. self.assertRaises(Exc, func, Bad())
  176. @support.no_tracing
  177. def test_recursion(self):
  178. # Check that comparison for recursive objects fails gracefully
  179. from collections import UserList
  180. a = UserList()
  181. b = UserList()
  182. a.append(b)
  183. b.append(a)
  184. self.assertRaises(RecursionError, operator.eq, a, b)
  185. self.assertRaises(RecursionError, operator.ne, a, b)
  186. self.assertRaises(RecursionError, operator.lt, a, b)
  187. self.assertRaises(RecursionError, operator.le, a, b)
  188. self.assertRaises(RecursionError, operator.gt, a, b)
  189. self.assertRaises(RecursionError, operator.ge, a, b)
  190. b.append(17)
  191. # Even recursive lists of different lengths are different,
  192. # but they cannot be ordered
  193. self.assertTrue(not (a == b))
  194. self.assertTrue(a != b)
  195. self.assertRaises(RecursionError, operator.lt, a, b)
  196. self.assertRaises(RecursionError, operator.le, a, b)
  197. self.assertRaises(RecursionError, operator.gt, a, b)
  198. self.assertRaises(RecursionError, operator.ge, a, b)
  199. a.append(17)
  200. self.assertRaises(RecursionError, operator.eq, a, b)
  201. self.assertRaises(RecursionError, operator.ne, a, b)
  202. a.insert(0, 11)
  203. b.insert(0, 12)
  204. self.assertTrue(not (a == b))
  205. self.assertTrue(a != b)
  206. self.assertTrue(a < b)
  207. def test_exception_message(self):
  208. class Spam:
  209. pass
  210. tests = [
  211. (lambda: 42 < None, r"'<' .* of 'int' and 'NoneType'"),
  212. (lambda: None < 42, r"'<' .* of 'NoneType' and 'int'"),
  213. (lambda: 42 > None, r"'>' .* of 'int' and 'NoneType'"),
  214. (lambda: "foo" < None, r"'<' .* of 'str' and 'NoneType'"),
  215. (lambda: "foo" >= 666, r"'>=' .* of 'str' and 'int'"),
  216. (lambda: 42 <= None, r"'<=' .* of 'int' and 'NoneType'"),
  217. (lambda: 42 >= None, r"'>=' .* of 'int' and 'NoneType'"),
  218. (lambda: 42 < [], r"'<' .* of 'int' and 'list'"),
  219. (lambda: () > [], r"'>' .* of 'tuple' and 'list'"),
  220. (lambda: None >= None, r"'>=' .* of 'NoneType' and 'NoneType'"),
  221. (lambda: Spam() < 42, r"'<' .* of 'Spam' and 'int'"),
  222. (lambda: 42 < Spam(), r"'<' .* of 'int' and 'Spam'"),
  223. (lambda: Spam() <= Spam(), r"'<=' .* of 'Spam' and 'Spam'"),
  224. ]
  225. for i, test in enumerate(tests):
  226. with self.subTest(test=i):
  227. with self.assertRaisesRegex(TypeError, test[1]):
  228. test[0]()
  229. class DictTest(unittest.TestCase):
  230. def test_dicts(self):
  231. # Verify that __eq__ and __ne__ work for dicts even if the keys and
  232. # values don't support anything other than __eq__ and __ne__ (and
  233. # __hash__). Complex numbers are a fine example of that.
  234. import random
  235. imag1a = {}
  236. for i in range(50):
  237. imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
  238. items = list(imag1a.items())
  239. random.shuffle(items)
  240. imag1b = {}
  241. for k, v in items:
  242. imag1b[k] = v
  243. imag2 = imag1b.copy()
  244. imag2[k] = v + 1.0
  245. self.assertEqual(imag1a, imag1a)
  246. self.assertEqual(imag1a, imag1b)
  247. self.assertEqual(imag2, imag2)
  248. self.assertTrue(imag1a != imag2)
  249. for opname in ("lt", "le", "gt", "ge"):
  250. for op in opmap[opname]:
  251. self.assertRaises(TypeError, op, imag1a, imag2)
  252. class ListTest(unittest.TestCase):
  253. def test_coverage(self):
  254. # exercise all comparisons for lists
  255. x = [42]
  256. self.assertIs(x<x, False)
  257. self.assertIs(x<=x, True)
  258. self.assertIs(x==x, True)
  259. self.assertIs(x!=x, False)
  260. self.assertIs(x>x, False)
  261. self.assertIs(x>=x, True)
  262. y = [42, 42]
  263. self.assertIs(x<y, True)
  264. self.assertIs(x<=y, True)
  265. self.assertIs(x==y, False)
  266. self.assertIs(x!=y, True)
  267. self.assertIs(x>y, False)
  268. self.assertIs(x>=y, False)
  269. def test_badentry(self):
  270. # make sure that exceptions for item comparison are properly
  271. # propagated in list comparisons
  272. class Exc(Exception):
  273. pass
  274. class Bad:
  275. def __eq__(self, other):
  276. raise Exc
  277. x = [Bad()]
  278. y = [Bad()]
  279. for op in opmap["eq"]:
  280. self.assertRaises(Exc, op, x, y)
  281. def test_goodentry(self):
  282. # This test exercises the final call to PyObject_RichCompare()
  283. # in Objects/listobject.c::list_richcompare()
  284. class Good:
  285. def __lt__(self, other):
  286. return True
  287. x = [Good()]
  288. y = [Good()]
  289. for op in opmap["lt"]:
  290. self.assertIs(op(x, y), True)
# Run every test case in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()