/tests/zzz_deprecated_unmaintained/parallel/localstep-benchmark/LocalStepUtil_ParallelSharedMem.py

https://github.com/bnpy/bnpy

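"""Benchmark utilities for running bnpy's local + summary step across
parallel worker processes. Dataset and model parameters are shared via
shared memory; work is distributed through a job queue and results are
collected through a result queue.
"""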
import multiprocessing
import time

import numpy as np

import bnpy
from bnpy.util.ParallelUtil import sharedMemDictToNumpy

from RunBenchmark import sliceGenerator


def calcLocalParamsAndSummarize(
        JobQ, ResultQ, Data, hmodel, nWorker=0,
        LPkwargs=dict(),
        **kwargs):
    """ Execute local and summary steps via workers in parallel.
    """
    # Copy so we never mutate the caller's dict (or the mutable default)
    LPkwargs = dict(LPkwargs)
    LPkwargs.update(
        hmodel.obsModel.getSerializableParamsForLocalStep())
    LPkwargs.update(
        hmodel.allocModel.getSerializableParamsForLocalStep())
    # MAP step
    # Create several tasks (one per worker) and add to job queue
    for start, stop in sliceGenerator(Data, nWorker):
        JobQ.put((start, stop, LPkwargs))
    # Pause at this line until all jobs are marked complete.
    JobQ.join()
    # REDUCE step
    # Aggregate results across all workers
    SS, telapsed_max = ResultQ.get()
    while not ResultQ.empty():
        SSslice, telapsed_cur = ResultQ.get()
        SS += SSslice
        telapsed_max = np.maximum(telapsed_max, telapsed_cur)
    return SS, telapsed_max
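
# Note: sliceGenerator is imported from the benchmark driver RunBenchmark
# (not shown in this file). From its use above, it must yield
# (start, stop) index pairs that partition the dataset into one
# contiguous slice per worker. Below is a minimal sketch of that
# contract, assuming Data exposes a total count via Data.nObs; it is a
# hypothetical stand-in, not the RunBenchmark implementation.
def _sliceGenerator_sketch(Data, nWorker):
    for w in range(nWorker):
        start = w * Data.nObs // nWorker
        stop = (w + 1) * Data.nObs // nWorker
        yield start, stop
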
def setUpWorkers(
        Data=None, hmodel=None,
        nWorker=1, verbose=0, nRepsForMinDuration=1, **kwargs):
    ''' Create queues and launch all workers.

    Returns
    -------
    JobQ
    ResultQ
    '''
    # Create a JobQ (to hold tasks to be done)
    # and a ResultQ (to hold results of completed tasks)
    manager = multiprocessing.Manager()
    JobQ = manager.Queue()
    ResultQ = manager.Queue()
    # Create shared-memory representations of Data and hmodel
    dataSharedMem = Data.getRawDataAsSharedMemDict()
    aSharedMem = hmodel.allocModel.fillSharedMemDictForLocalStep()
    oSharedMem = hmodel.obsModel.fillSharedMemDictForLocalStep()
    ShMem = dict(dataSharedMem=dataSharedMem,
                 aSharedMem=aSharedMem,
                 oSharedMem=oSharedMem)
    # Get relevant function handles
    afuncHTuple = hmodel.allocModel.getLocalAndSummaryFunctionHandles()
    ofuncHTuple = hmodel.obsModel.getLocalAndSummaryFunctionHandles()
    funcH = dict(
        makeDataSliceFromSharedMem=Data.getDataSliceFunctionHandle(),
        a_calcLocalParams=afuncHTuple[0],
        a_calcSummaryStats=afuncHTuple[1],
        o_calcLocalParams=ofuncHTuple[0],
        o_calcSummaryStats=ofuncHTuple[1],
    )
    # Launch desired number of worker processes.
    # We don't need to store references to these processes;
    # we can get everything we need from JobQ and ResultQ.
    for uid in range(nWorker):
        workerProcess = Worker_SHMData_SHMModel(
            uid, JobQ, ResultQ,
            ShMem=ShMem,
            funcH=funcH,
            nReps=nRepsForMinDuration,
            verbose=verbose)
        workerProcess.start()
    return JobQ, ResultQ
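
# The ShMem dict built above maps field names to flat shared-memory
# arrays that every worker can view without copying; on the worker
# side, sharedMemDictToNumpy converts them back into numpy arrays.
# Below is a minimal sketch of that zero-copy round trip for a single
# 1-D float64 array. These are hypothetical helpers for illustration;
# the real bnpy utilities also track shapes and dtypes.
def _numpyToSharedMem_sketch(arr):
    # Allocate a raw shared buffer and copy the array contents into it
    shmem = multiprocessing.RawArray('d', arr.size)
    np.frombuffer(shmem, dtype=np.float64)[:] = arr.ravel()
    return shmem

def _sharedMemToNumpy_sketch(shmem):
    # np.frombuffer wraps the shared buffer as a view, without copying
    return np.frombuffer(shmem, dtype=np.float64)
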
def tearDownWorkers(JobQ=None, ResultQ=None, nWorker=1, **kwargs):
    ''' Shut down the pool of workers.
    '''
    for workerID in range(nWorker):
        # Passing None to JobQ is the shutdown signal
        JobQ.put(None)
    time.sleep(0.1)  # let workers all shut down before we quit
class Worker_SHMData_SHMModel(multiprocessing.Process):
    ''' Single "worker" process that processes tasks delivered via queues.

    Attributes
    ----------
    JobQ : multiprocessing.Queue
    ResultQ : multiprocessing.Queue
    '''

    def __init__(self, uid, JobQ, ResultQ,
                 ShMem=dict(),
                 funcH=dict(),
                 verbose=0, nReps=1):
        ''' Create single worker process, linked to provided queues.
        '''
        # Required super constructor call
        super(Worker_SHMData_SHMModel, self).__init__()
        self.uid = uid
        self.ShMem = ShMem
        self.funcH = funcH
        self.JobQ = JobQ
        self.ResultQ = ResultQ
        self.verbose = verbose
        self.nReps = nReps

    def run(self):
        ''' Perform calcLocalParamsAndSummarize on jobs in JobQ.

        Post Condition
        --------------
        Summary stats for each completed job are placed on ResultQ.
        '''
        # Construct iterator with sentinel value of None (for termination)
        jobIterator = iter(self.JobQ.get, None)
        for jobArgs in jobIterator:
            start, stop, LPkwargs = jobArgs
            Dslice = self.funcH['makeDataSliceFromSharedMem'](
                self.ShMem['dataSharedMem'], cslice=(start, stop))
            # Fill in params needed for local step
            LPkwargs.update(
                sharedMemDictToNumpy(self.ShMem['aSharedMem']))
            LPkwargs.update(
                sharedMemDictToNumpy(self.ShMem['oSharedMem']))
            tstart = time.time()
            for rep in range(self.nReps):
                # Do local step
                LP = self.funcH['o_calcLocalParams'](Dslice, **LPkwargs)
                LP = self.funcH['a_calcLocalParams'](Dslice, LP, **LPkwargs)
                # Do summary step
                SSslice = self.funcH['a_calcSummaryStats'](
                    Dslice, LP, **LPkwargs)
                SSslice = self.funcH['o_calcSummaryStats'](
                    Dslice, SSslice, LP, **LPkwargs)
            twork = time.time() - tstart
            self.ResultQ.put((SSslice, twork))
            self.JobQ.task_done()
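
# A minimal usage sketch, assuming a bnpy dataset `Data` and model
# `hmodel` that implement the shared-memory hooks used above
# (getRawDataAsSharedMemDict, fillSharedMemDictForLocalStep, ...).
# The whole-dataset local + summary step then runs as one map/reduce:
#
#     JobQ, ResultQ = setUpWorkers(Data=Data, hmodel=hmodel, nWorker=4)
#     SS, telapsed_max = calcLocalParamsAndSummarize(
#         JobQ, ResultQ, Data, hmodel, nWorker=4)
#     tearDownWorkers(JobQ=JobQ, ResultQ=ResultQ, nWorker=4)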