PageRenderTime 66ms CodeModel.GetById 26ms RepoModel.GetById 0ms app.codeStats 1ms

/Lib/site-packages/gensim/models/lda_worker.py

https://gitlab.com/pierreEffiScience/TwitterClustering
Python | 134 lines | 99 code | 21 blank | 14 comment | 11 complexity | e7c5df76a19cd383fee48a2c5bf9ce98 MD5 | raw file
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
  5. # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
  6. """
  7. USAGE: %(program)s
  8. Worker ("slave") process used in computing distributed LDA. Run this script \
  9. on every node in your cluster. If you wish, you may even run it multiple times \
  10. on a single machine, to make better use of multiple cores (just beware that \
  11. memory footprint increases accordingly).
  12. Example: python -m gensim.models.lda_worker
  13. """
  14. from __future__ import with_statement
  15. import os, sys, logging
  16. import threading
  17. import tempfile
  18. try:
  19. import Queue
  20. except ImportError:
  21. import queue as Queue
  22. import Pyro4
  23. from gensim.models import ldamodel
  24. from gensim import utils
  25. logger = logging.getLogger('gensim.models.lda_worker')
  26. # periodically save intermediate models after every SAVE_DEBUG updates (0 for never)
  27. SAVE_DEBUG = 0
  28. class Worker(object):
  29. def __init__(self):
  30. self.model = None
  31. def initialize(self, myid, dispatcher, **model_params):
  32. self.lock_update = threading.Lock()
  33. self.jobsdone = 0 # how many jobs has this worker completed?
  34. self.myid = myid # id of this worker in the dispatcher; just a convenience var for easy access/logging TODO remove?
  35. self.dispatcher = dispatcher
  36. self.finished = False
  37. logger.info("initializing worker #%s" % myid)
  38. self.model = ldamodel.LdaModel(**model_params)
  39. @Pyro4.oneway
  40. def requestjob(self):
  41. """
  42. Request jobs from the dispatcher, in a perpetual loop until `getstate()` is called.
  43. """
  44. if self.model is None:
  45. raise RuntimeError("worker must be initialized before receiving jobs")
  46. job = None
  47. while job is None and not self.finished:
  48. try:
  49. job = self.dispatcher.getjob(self.myid)
  50. except Queue.Empty:
  51. # no new job: try again, unless we're finished with all work
  52. continue
  53. if job is not None:
  54. logger.info("worker #%s received job #%i" % (self.myid, self.jobsdone))
  55. self.processjob(job)
  56. self.dispatcher.jobdone(self.myid)
  57. else:
  58. logger.info("worker #%i stopping asking for jobs" % self.myid)
  59. @utils.synchronous('lock_update')
  60. def processjob(self, job):
  61. logger.debug("starting to process job #%i" % self.jobsdone)
  62. self.model.do_estep(job)
  63. self.jobsdone += 1
  64. if SAVE_DEBUG and self.jobsdone % SAVE_DEBUG == 0:
  65. fname = os.path.join(tempfile.gettempdir(), 'lda_worker.pkl')
  66. self.model.save(fname)
  67. logger.info("finished processing job #%i" % (self.jobsdone - 1))
  68. @utils.synchronous('lock_update')
  69. def getstate(self):
  70. logger.info("worker #%i returning its state after %s jobs" %
  71. (self.myid, self.jobsdone))
  72. result = self.model.state
  73. assert isinstance(result, ldamodel.LdaState)
  74. self.model.clear() # free up mem in-between two EM cycles
  75. self.finished = True
  76. return result
  77. @utils.synchronous('lock_update')
  78. def reset(self, state):
  79. assert state is not None
  80. logger.info("resetting worker #%i" % self.myid)
  81. self.model.state = state
  82. self.model.sync_state()
  83. self.model.state.reset()
  84. self.finished = False
  85. @Pyro4.oneway
  86. def exit(self):
  87. logger.info("terminating worker #%i" % self.myid)
  88. os._exit(0)
  89. #endclass Worker
  90. def main():
  91. logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
  92. logger.info("running %s" % " ".join(sys.argv))
  93. program = os.path.basename(sys.argv[0])
  94. # make sure we have enough cmd line parameters
  95. if len(sys.argv) < 1:
  96. print(globals()["__doc__"] % locals())
  97. sys.exit(1)
  98. utils.pyro_daemon('gensim.lda_worker', Worker(), random_suffix=True)
  99. logger.info("finished running %s" % program)
  100. if __name__ == '__main__':
  101. main()