/refinery/bnpy/bnpy-dev/bnpy/learnalg/StochasticOnlineVBLearnAlg.py

https://github.com/daeilkim/refinery
Python | 75 lines | 38 code | 12 blank | 25 comment | 3 complexity | 73b62d68a0c56dcd0a1ba420576b4878 MD5 | raw file
  1. '''
  2. StochasticOnlineVBLearnAlg.py
  3. Implementation of stochastic online VB (soVB) for bnpy models
  4. '''
  5. import numpy as np
  6. from LearnAlg import LearnAlg
  7. class StochasticOnlineVBLearnAlg(LearnAlg):
  8. def __init__(self, **kwargs):
  9. ''' Creates stochastic online learning algorithm,
  10. with fields rhodelay, rhoexp that define learning rate schedule.
  11. '''
  12. super(type(self),self).__init__(**kwargs)
  13. self.rhodelay = self.algParams['rhodelay']
  14. self.rhoexp = self.algParams['rhoexp']
  15. def fit(self, hmodel, DataIterator, SS=None):
  16. ''' Run soVB learning algorithm, fit global parameters of hmodel to Data
  17. Returns
  18. --------
  19. LP : local params from final pass of Data
  20. Info : dict of run information, with fields
  21. evBound : final ELBO evidence bound
  22. status : str message indicating reason for termination
  23. {'all data processed'}
  24. '''
  25. LP = None
  26. rho = 1.0 # Learning rate
  27. nBatch = float(DataIterator.nBatch)
  28. # Set-up progress-tracking variables
  29. iterid = -1
  30. lapFrac = np.maximum(0, self.algParams['startLap'] - 1.0/nBatch)
  31. if lapFrac > 0:
  32. # When restarting an existing run,
  33. # need to start with last update for final batch from previous lap
  34. DataIterator.lapID = int(np.ceil(lapFrac)) - 1
  35. DataIterator.curLapPos = nBatch - 2
  36. iterid = int(nBatch * lapFrac) - 1
  37. self.set_start_time_now()
  38. while DataIterator.has_next_batch():
  39. # Grab new data
  40. Dchunk = DataIterator.get_next_batch()
  41. # Update progress-tracking variables
  42. iterid += 1
  43. lapFrac += 1.0/nBatch
  44. self.set_random_seed_at_lap(lapFrac)
  45. # M step with learning rate
  46. if SS is not None:
  47. rho = (iterid + self.rhodelay) ** (-1.0 * self.rhoexp)
  48. hmodel.update_global_params(SS, rho)
  49. # E step
  50. LP = hmodel.calc_local_params(Dchunk)
  51. SS = hmodel.get_global_suff_stats(Dchunk, LP, doAmplify=True)
  52. # ELBO calculation
  53. evBound = hmodel.calc_evidence(Dchunk, SS, LP)
  54. # Save and display progress
  55. self.add_nObs(Dchunk.nObs)
  56. self.save_state(hmodel, iterid, lapFrac, evBound)
  57. self.print_state(hmodel, iterid, lapFrac, evBound)
  58. #Finally, save, print and exit
  59. status = "all data processed."
  60. self.save_state(hmodel,iterid, lapFrac, evBound, doFinal=True)
  61. self.print_state(hmodel, iterid, lapFrac, evBound, doFinal=True, status=status)
  62. return None, self.buildRunInfo(evBound, status)