/bnpy/learnalg/EMAlg.py

https://github.com/bnpy/bnpy · Python

import numpy as np

from bnpy.learnalg.LearnAlg import LearnAlg, makeDictOfAllWorkspaceVars


class EMAlg(LearnAlg):

    """ Implementation of expectation-maximization learning algorithm.

    Key Methods
    ------
    fit : fit a provided model object to data.

    Attributes
    ------
    See LearnAlg.py
    """

    def __init__(self, **kwargs):
        ''' Create EMAlg instance, subtype of generic LearnAlg
        '''
        super(type(self), self).__init__(**kwargs)

    def fit(self, hmodel, Data, LP=None):
        ''' Fit point estimates of global parameters of hmodel to Data

        Returns
        --------
        Info : dict of run information.

        Post Condition
        --------
        hmodel updated in place with improved global parameters.
        '''
        self.set_start_time_now()
        isConverged = False
        prev_loss = -np.inf

        # Save initial state
        self.saveParams(0, hmodel)

        # Custom func hook
        self.eval_custom_func(
            isInitial=1, **makeDictOfAllWorkspaceVars(**vars()))

        for iterid in range(1, self.algParams['nLap'] + 1):
            lap = self.algParams['startLap'] + iterid
            nLapsCompleted = lap - self.algParams['startLap']
            self.set_random_seed_at_lap(lap)

            # Local/E step
            LP = hmodel.calc_local_params(Data, LP, **self.algParamsLP)
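            # LP holds the local (per-item) parameters produced by the E step.
            # For mixture-style models in bnpy this dict typically includes a
            # 'resp' array of posterior responsibilities; the exact keys depend
            # on the allocation model (assumption, see calc_local_params).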

            # Summary step
            SS = hmodel.get_global_suff_stats(Data, LP)
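            # SS aggregates sufficient statistics over the whole dataset
            # (a bnpy SuffStatBag); the M step below reads only SS, not the
            # raw Data.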

            # ELBO calculation (needs to be BEFORE Mstep for EM)
            cur_loss = -1 * hmodel.calc_evidence(Data, SS, LP)
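            # cur_loss is the negative evidence (ELBO), so EM's guarantee that
            # the evidence never decreases becomes a monotone decrease in loss,
            # which is what verify_monotonic_decrease checks below.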

            if lap > 1.0:
                # Report warning if bound isn't increasing monotonically
                self.verify_monotonic_decrease(cur_loss, prev_loss)

            # Global/M step
            hmodel.update_global_params(SS)

            # Check convergence of expected counts
            countVec = SS.getCountVec()
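            # getCountVec returns the expected number of items assigned to each
            # component; convergence is declared when this vector stops changing
            # appreciably between laps (tolerance governed by algParams, an
            # assumption about LearnAlg.isCountVecConverged).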
            if lap > 1.0:
                isConverged = self.isCountVecConverged(countVec, prevCountVec)
                self.setStatus(lap, isConverged)

            # Display progress
            self.updateNumDataProcessed(Data.get_size())
            if self.isLogCheckpoint(lap, iterid):
                self.printStateToLog(hmodel, cur_loss, lap, iterid)

            # Save diagnostics and params
            if self.isSaveDiagnosticsCheckpoint(lap, iterid):
                self.saveDiagnostics(lap, SS, cur_loss)
            if self.isSaveParamsCheckpoint(lap, iterid):
                self.saveParams(lap, hmodel, SS)

            # Custom func hook
            self.eval_custom_func(**makeDictOfAllWorkspaceVars(**vars()))

            if nLapsCompleted >= self.algParams['minLaps'] and isConverged:
                break
            prev_loss = cur_loss
            prevCountVec = countVec.copy()
        # .... end loop over laps

        # Finished! Save, print and exit
        self.saveParams(lap, hmodel, SS)
        self.printStateToLog(hmodel, cur_loss, lap, iterid, isFinal=1)
        self.eval_custom_func(
            isFinal=1, **makeDictOfAllWorkspaceVars(**vars()))

        return self.buildRunInfo(Data=Data, loss=cur_loss, SS=SS, LP=LP)
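

# Minimal usage sketch: a run normally reaches EMAlg.fit through bnpy's
# top-level interface rather than by constructing EMAlg directly. This assumes
# the documented entry points bnpy.run and bnpy.data.XData; the dataset and
# keyword values below are illustrative only and may vary across bnpy versions.
if __name__ == '__main__':
    import bnpy

    X = np.random.randn(500, 2)  # toy data: 500 points in 2 dimensions
    Data = bnpy.data.XData(X)

    # 'EM' selects this point-estimate EM algorithm for a Gaussian mixture.
    trained_model, info = bnpy.run(
        Data, 'FiniteMixtureModel', 'Gauss', 'EM',
        K=5, nLap=50)
    # trained_model is the fitted HModel; info is a dict of run information
    # (see buildRunInfo above).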