/tests/zzz_deprecated_unmaintained/allocmodel/hmm/TestSummaryAlg.py

https://github.com/bnpy/bnpy · Python · 155 lines · 111 code · 33 blank · 11 comment · 10 complexity · 8590466db21a583c975677385a7464e1 MD5 · raw file

  1. import unittest
  2. import numpy as np
  3. from bnpy.allocmodel.hmm.HMMUtil import FwdAlg_py, BwdAlg_py, SummaryAlg_py
  4. from bnpy.allocmodel.hmm.HMMUtil import SummaryAlg_cpp, calcRespPair_fast
  5. from bnpy.allocmodel.hmm.HMMUtil import calc_sub_Htable_forMergePair
  6. from bnpy.init.FromTruth import convertLPFromHardToSoft
  7. class TestSummaryAlg_K4T2(unittest.TestCase):
  8. def shortDescription(self):
  9. return None
  10. def setUp(self, K=4, T=2):
  11. initPi = 1.0 / K * np.ones(K)
  12. transPi = 1.0 / K * np.ones((K, K))
  13. SoftEv = 10 * np.ones((T, K)) + np.random.rand(T, K)
  14. self._setUpFromParams(initPi, transPi, SoftEv)
  15. def _setUpFromParams(self, initPi, transPi, SoftEv):
  16. fMsg, margPrObs = FwdAlg_py(initPi, transPi, SoftEv)
  17. bMsg = BwdAlg_py(initPi, transPi, SoftEv, margPrObs)
  18. self.initPi = initPi
  19. self.transPi = transPi
  20. self.SoftEv = SoftEv
  21. self.fMsg = fMsg
  22. self.bMsg = bMsg
  23. self.margPrObs = margPrObs
  24. self.K = initPi.size
  25. self.T = SoftEv.shape[0]
  26. def test_python_equals_cpp(self):
  27. ''' Test both versions of C++ and python, verify same value returned
  28. '''
  29. print('')
  30. print('-------- python')
  31. T1, H1, _ = SummaryAlg_py(self.initPi, self.transPi, self.SoftEv,
  32. self.margPrObs, self.fMsg, self.bMsg)
  33. if self.K < 5:
  34. print(H1)
  35. else:
  36. print(H1[:5, :5])
  37. print('-------- cpp')
  38. T2, H2, _ = SummaryAlg_cpp(self.initPi, self.transPi, self.SoftEv,
  39. self.margPrObs, self.fMsg, self.bMsg)
  40. if self.K < 5:
  41. print(H2)
  42. else:
  43. print(H2[:5, :5])
  44. assert np.allclose(T1, T2)
  45. assert np.allclose(H1, H2)
  46. def test_all_possible_single_merges(self):
  47. ''' Iterate over all possible pairs (kA, kB), verify merge Htable correct.
  48. '''
  49. print('')
  50. for kA in range(self.K):
  51. for kB in range(kA + 1, self.K):
  52. self.test_single_merge__python_equals_cpp(kA=kA, kB=kB)
  53. def test_single_merge__python_equals_cpp(self, kA=0, kB=1):
  54. ''' Test both versions of C++ and python, verify same value returned
  55. '''
  56. print('')
  57. mPairIDs = [(kA, kB)]
  58. print('-------- python')
  59. _, _, mH1 = SummaryAlg_py(self.initPi, self.transPi, self.SoftEv,
  60. self.margPrObs, self.fMsg, self.bMsg, mPairIDs)
  61. print(mH1[:5, :5])
  62. print('-------- cpp')
  63. _, _, mH2 = SummaryAlg_cpp(self.initPi, self.transPi, self.SoftEv,
  64. self.margPrObs, self.fMsg, self.bMsg, mPairIDs)
  65. print(mH2[:5, :5])
  66. assert np.allclose(mH1, mH2)
  67. def test_many_possible_multiple_merges(self):
  68. for M in range(5, 10):
  69. for seed in range(3):
  70. self.test_tracking_multiple_merges__python_equals_cpp(
  71. M=M,
  72. seed=seed)
  73. def test_tracking_multiple_merges__python_equals_cpp(self, M=3, seed=0):
  74. ''' Test both versions of C++ and python, verify same value returned
  75. Here, we track M pairs simultaneously
  76. Chosen by random shuffling from all possible valid pairs (kA < kB)
  77. '''
  78. print('')
  79. mPairIDs = list()
  80. for kA in range(self.K):
  81. for kB in range(kA + 1, self.K):
  82. mPairIDs.append((kA, kB))
  83. PRNG = np.random.RandomState(seed)
  84. PRNG.shuffle(mPairIDs)
  85. mPairIDs = mPairIDs[:M]
  86. print('mPairIDs:', mPairIDs)
  87. print('-------- python')
  88. _, _, mH1 = SummaryAlg_py(self.initPi, self.transPi, self.SoftEv,
  89. self.margPrObs, self.fMsg, self.bMsg, mPairIDs)
  90. print(mH1[:10, :5])
  91. print('-------- cpp')
  92. _, _, mH2 = SummaryAlg_cpp(self.initPi, self.transPi, self.SoftEv,
  93. self.margPrObs, self.fMsg, self.bMsg, mPairIDs)
  94. print(mH2[:10, :5])
  95. print('MaxError: ', np.max(np.abs(mH1 - mH2)))
  96. assert np.allclose(mH1, mH2, atol=1e-6, rtol=0)
  class TestSummaryAlg_K4T100(TestSummaryAlg_K4T2):
      ''' Same python/C++ comparisons with K=4 states over T=100 timesteps.
      '''

      def setUp(self, K=4, T=100):
          ''' Delegate to the base setUp with this case's problem size. '''
          # Name the class explicitly. super(type(self), self) resolves to
          # this very class for any subclass, which recurses forever.
          super(TestSummaryAlg_K4T100, self).setUp(K, T)
  class TestSummaryAlg_K22T55(TestSummaryAlg_K4T2):
      ''' Same python/C++ comparisons with K=22 states over T=55 timesteps.
      '''

      def setUp(self, K=22, T=55):
          ''' Delegate to the base setUp with this case's problem size. '''
          # Name the class explicitly. super(type(self), self) resolves to
          # this very class for any subclass, which recurses forever.
          super(TestSummaryAlg_K22T55, self).setUp(K, T)
  class TestSummaryAlg_ToyData(TestSummaryAlg_K4T2):
      ''' Run the python/C++ comparisons on one sequence from DDToyHMM.

      Uses the dataset's own initPi and transPi rather than uniform ones.
      Soft evidence is derived from the dataset's true state labels.
      '''

      def setUp(self, T=3000):
          ''' Build soft evidence from DDToyHMM's true state sequence.

          Args
          ----
          T : int
              Length of the single generated sequence.
          '''
          # Deferred import: only this test case needs DDToyHMM, so the other
          # cases in this file stay runnable if it cannot be imported.
          import DDToyHMM
          Data = DDToyHMM.get_data(seed=0, nDocTotal=1, T=T)
          initPi = DDToyHMM.initPi
          transPi = DDToyHMM.transPi
          LP = dict(Z=Data.TrueParams['Z'])
          LP = convertLPFromHardToSoft(LP, Data)
          assert LP['resp'].shape[0] == T
          assert LP['Z'].shape[0] == T
          K = initPi.size
          # resp may have fewer columns (Keff) than the model has states
          # (presumably only states present in Z -- TODO confirm). The
          # remaining columns start at zero.
          Keff = LP['resp'].shape[1]
          assert Keff <= K, \
              'resp has %d columns but model has only %d states' % (Keff, K)
          SoftEv = np.zeros((T, K))
          SoftEv[:, :Keff] = LP['resp']
          # Small floor so no state has exactly zero evidence at any timestep.
          SoftEv += 0.05
          self._setUpFromParams(initPi, transPi, SoftEv)