/bnpy/learnalg/zzzdeprecated/GSAlg.py

https://github.com/bnpy/bnpy
Python | 63 lines | 28 code | 10 blank | 25 comment | 1 complexity | f181b489a03e3e377cc4c62ac08416fb MD5 | raw file
  1. '''
  2. GSAlg.py
  3. Implementation of Gibbs Sampling for bnpy models
  4. For more info, see the documentation [TODO]
  5. '''
  6. import numpy as np
  7. import scipy.sparse as sp
  8. from collections import defaultdict
  9. from LearnAlg import LearnAlg
  10. class GSAlg(LearnAlg):
  11. def __init__(self, **kwargs):
  12. ''' Create GSAlg, subtype of generic LearnAlg
  13. '''
  14. super(type(self), self).__init__(**kwargs)
  15. def fit(self, hmodel, Data):
  16. ''' Run Gibbs sampling to fit hmodel to data
  17. Returns
  18. --------
  19. Info : dict of run information.
  20. Post Condition
  21. --------
  22. hmodel updated in place with improved global parameters.
  23. '''
  24. # get initial allocations and corresponding suff stats
  25. LP = hmodel.calc_local_params(Data)
  26. LP = hmodel.allocModel.make_hard_asgn_local_params(LP)
  27. SS = hmodel.get_global_suff_stats(Data, LP)
  28. self.set_start_time_now()
  29. for iterid in range(self.algParams['nLap'] + 1):
  30. lap = self.algParams['startLap'] + iterid
  31. self.set_random_seed_at_lap(lap)
  32. # sample posterior allocations
  33. LP, SS = hmodel.allocModel.sample_local_params(hmodel.obsModel,
  34. Data, SS, LP,
  35. self.PRNG,
  36. **self.algParams)
  37. # Make posterior params
  38. hmodel.update_global_params(SS)
  39. # Log prob of total sampler state
  40. ll = hmodel.calcLogLikCollapsedSamplerState(SS)
  41. # Save and display progress
  42. self.add_nObs(Data.nObsTotal)
  43. self.save_state(hmodel, iterid, lap, ll)
  44. self.print_state(hmodel, iterid, lap, ll)
  45. # Finally, save, print and exit
  46. status = "max passes thru data exceeded."
  47. self.save_state(hmodel, iterid, lap, ll, doFinal=True)
  48. self.print_state(hmodel, iterid, lap, ll, doFinal=True, status=status)
  49. return LP, self.buildRunInfo(ll, status)