/zss/tests/test_metricspace.py

https://gitlab.com/tadh/zhang-shasha
Python | 215 lines | 155 code | 37 blank | 23 comment | 31 complexity | 81cb165df8f71afabfd0d69f2e4316ec MD5 | raw file
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #Author: Tim Henderson
  4. #Email: tim.tadh@gmail.com
  5. #For licensing see the LICENSE file in the top level directory.
  6. from __future__ import absolute_import
  7. from six.moves import map
  8. from six.moves import range
  9. import copy
  10. import itertools
  11. import os
  12. import sys
  13. import random
  14. import unittest
  15. from random import randint, seed, shuffle
  16. from zss import (
  17. simple_distance,
  18. Node,
  19. )
  20. from zss.compare import strdist
  21. seed(os.urandom(15))
  22. N = 3
  23. def product(*args, **kwds):
  24. # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
  25. # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
  26. pools = list(map(tuple, args)) * kwds.get('repeat', 1)
  27. result = [[]]
  28. for pool in pools:
  29. result = [x+[y] for x in result for y in pool]
  30. for prod in result:
  31. yield tuple(prod)
  32. if not hasattr(itertools, 'product'):
  33. setattr(itertools, 'product', product)
  34. tree1_nodes = ['a','b','c','d','e','f']
  35. def tree1():
  36. return (
  37. Node("f")
  38. .addkid(Node("d")
  39. .addkid(Node("a"))
  40. .addkid(Node("c")
  41. .addkid(Node("b"))))
  42. .addkid(Node("e"))
  43. )
  44. tree2_nodes = ['a','b','c','d','e','f']
  45. def tree2():
  46. return (
  47. Node("a")
  48. .addkid(Node("c")
  49. .addkid(Node("d")
  50. .addkid(Node("b"))
  51. .addkid(Node("e"))))
  52. .addkid(Node("f"))
  53. )
  54. tree3_nodes = ['a','b','c','d','e','f']
  55. def tree3():
  56. return (
  57. Node("a")
  58. .addkid(Node("d")
  59. .addkid(Node("f"))
  60. .addkid(Node("c")
  61. .addkid(Node("b"))))
  62. .addkid(Node("e"))
  63. )
  64. tree4_nodes = ['q','b','c','d','e','f']
  65. def tree4():
  66. return (
  67. Node("f")
  68. .addkid(Node("d")
  69. .addkid(Node("q"))
  70. .addkid(Node("c")
  71. .addkid(Node("b"))))
  72. .addkid(Node("e"))
  73. )
  74. def randtree(depth=2, alpha='abcdefghijklmnopqrstuvwxyz', repeat=2, width=2):
  75. labels = [''.join(x) for x in itertools.product(alpha, repeat=repeat)]
  76. shuffle(labels)
  77. labels = (x for x in labels)
  78. root = Node("root")
  79. p = [root]
  80. c = list()
  81. for x in range(depth-1):
  82. for y in p:
  83. for z in range(randint(1,1+width)):
  84. n = Node(next(labels))
  85. y.addkid(n)
  86. c.append(n)
  87. p = c
  88. c = list()
  89. return root
  90. class TestTestNode(unittest.TestCase):
  91. def test_contains(self):
  92. root = tree1()
  93. self.assertTrue("a" in root)
  94. self.assertTrue("b" in root)
  95. self.assertTrue("c" in root)
  96. self.assertTrue("d" in root)
  97. self.assertTrue("e" in root)
  98. self.assertTrue("f" in root)
  99. self.assertFalse("q" in root)
  100. def test_get(self):
  101. root = tree1()
  102. self.assertEqual(root.get("a").label, "a")
  103. self.assertEqual(root.get("b").label, "b")
  104. self.assertEqual(root.get("c").label, "c")
  105. self.assertEqual(root.get("d").label, "d")
  106. self.assertEqual(root.get("e").label, "e")
  107. self.assertEqual(root.get("f").label, "f")
  108. self.assertNotEqual(root.get("a").label, "x")
  109. self.assertNotEqual(root.get("b").label, "x")
  110. self.assertNotEqual(root.get("c").label, "x")
  111. self.assertNotEqual(root.get("d").label, "x")
  112. self.assertNotEqual(root.get("e").label, "x")
  113. self.assertNotEqual(root.get("f").label, "x")
  114. self.assertEqual(root.get("x"), None)
  115. def test_iter(self):
  116. root = tree1()
  117. self.assertEqual(list(x.label for x in root.iter()), ['f','d','e','a','c','b'])
  118. class TestCompare(unittest.TestCase):
  119. def test_distance(self):
  120. trees = itertools.product([tree1(), tree2(), tree3(), tree4()], repeat=2)
  121. for a,b in trees:
  122. ab = simple_distance(a,b)
  123. ba = simple_distance(b,a)
  124. #print '-----------------------------'
  125. #print a
  126. #print '------'
  127. #print b
  128. #print '------'
  129. #print ab, ba
  130. self.assertEqual(ab,ba)
  131. self.assertTrue((ab == 0 and a is b) or a is not b)
  132. #break
  133. trees = itertools.product([tree1(), tree2(), tree3(), tree4()], repeat=3)
  134. for a,b,c in trees:
  135. ab = simple_distance(a,b)
  136. bc = simple_distance(b,c)
  137. ac = simple_distance(a,c)
  138. self.assertTrue(ac <= ab + bc)
  139. #break
  140. #def test_randtree(self):
  141. #print randtree(5, repeat=3, width=2)
  142. def test_symmetry(self):
  143. trees = itertools.product((randtree(5, repeat=3, width=2) for x in range(N)), repeat=2)
  144. for a,b in trees:
  145. ab = simple_distance(a,b)
  146. ba = simple_distance(b,a)
  147. #print '-----------------------------'
  148. #print ab, ba
  149. self.assertEqual(ab, ba)
  150. def test_nondegenercy(self):
  151. trees = itertools.product((randtree(5, repeat=3, width=2) for x in range(N)), repeat=2)
  152. for a,b in trees:
  153. d = simple_distance(a,b)
  154. #print '-----------------------------'
  155. #print d, a is b
  156. self.assertTrue((d == 0 and a is b) or a is not b)
  157. def test_triangle_inequality(self):
  158. trees = itertools.product((randtree(5, repeat=3, width=2) for x in range(N)), (randtree(5, repeat=3, width=2) for x in range(N)), (randtree(5, repeat=3, width=2) for x in range(N)))
  159. for a,b,c in trees:
  160. #print '--------------------------------'
  161. ab = simple_distance(a,b)
  162. bc = simple_distance(b,c)
  163. ac = simple_distance(a,c)
  164. #print ab, bc, ac
  165. self.assertTrue(ac <= ab + bc)
  166. def test_labelchange(self):
  167. for A in (randtree(5, repeat=3, width=2) for x in range(N*4)):
  168. B = copy.deepcopy(A)
  169. node = random.choice([n for n in B.iter()])
  170. old_label = str(node.label)
  171. node.label = 'xty'
  172. assert simple_distance(A, B) == strdist(old_label, node.label)
  173. if __name__ == '__main__':
  174. if len(sys.argv) > 1:
  175. import cProfile
  176. cProfile.run('unittest.main()', 'profile')
  177. else:
  178. unittest.main()