/Lib/multiprocessing/pool.py

Source: http://unladen-swallow.googlecode.com/

#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2007-2008, R Oudkerk --- see COPYING.txt
#

__all__ = ['Pool']

#
# Imports
#

import threading
import Queue
import itertools
import collections
import time

from multiprocessing import Process, cpu_count, TimeoutError
from multiprocessing.util import Finalize, debug

#
# Constants representing the state of a pool
#

RUN = 0
CLOSE = 1
TERMINATE = 2

#
# Miscellaneous
#

job_counter = itertools.count()

def mapstar(args):
    return map(*args)
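A quick aside on `mapstar`: `Pool._get_tasks` (further down) batches the input into `(func, chunk)` pairs, and `mapstar` unpacks one such pair back into a `map()` call inside the worker. A standalone sketch under Python 2 (`square` is illustrative, not from the module):

def mapstar(args):
    return map(*args)

def square(n):          # example payload function
    return n * n

# one task batch of chunksize 3 -> one list of 3 results
print mapstar((square, (1, 2, 3)))      # [1, 4, 9]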
#
# Code run by worker processes
#

def worker(inqueue, outqueue, initializer=None, initargs=()):
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    while 1:
        try:
            task = get()
        except (EOFError, IOError):
            debug('worker got EOFError or IOError -- exiting')
            break

        if task is None:
            debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception, e:
            result = (False, e)
        put((job, i, result))
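To make the worker protocol concrete, here is a sketch that drives `worker()` inline in a single process with plain `Queue.Queue` objects (which have no `_writer`, so the pipe-closing branch is skipped). Each task is a `(job, i, func, args, kwds)` tuple, each result is `(job, i, (success, value))`, and `None` is the shutdown sentinel. The `add` helper is illustrative:

from Queue import Queue
from multiprocessing.pool import worker

def add(a, b):
    return a + b

inq, outq = Queue(), Queue()
inq.put((0, 0, add, (2, 3), {}))    # (job, i, func, args, kwds)
inq.put(None)                       # sentinel -- makes the loop exit
worker(inq, outq)                   # runs the loop to completion, inline
print outq.get()                    # (0, 0, (True, 5))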
#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of the `apply()` builtin
    '''
    Process = Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        self._setup_queues()
        self._taskqueue = Queue.Queue()
        self._cache = {}
        self._state = RUN

        if processes is None:
            try:
                processes = cpu_count()
            except NotImplementedError:
                processes = 1

        self._pool = []
        for i in range(processes):
            w = self.Process(
                target=worker,
                args=(self._inqueue, self._outqueue, initializer, initargs)
                )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()

        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue, self._pool)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        self._terminate = Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._task_handler, self._result_handler, self._cache),
            exitpriority=15
            )
    def _setup_queues(self):
        from .queues import SimpleQueue
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        '''
        Equivalent of `map()` builtin
        '''
        assert self._state == RUN
        return self.map_async(func, iterable, chunksize).get()

    def imap(self, func, iterable, chunksize=1):
        '''
        Equivalent of `itertools.imap()` -- can be MUCH slower than `Pool.map()`
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                        for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                    for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                        for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                    for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def apply_async(self, func, args=(), kwds={}, callback=None):
        '''
        Asynchronous equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        result = ApplyResult(self._cache, callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result

    def map_async(self, func, iterable, chunksize=None, callback=None):
        '''
        Asynchronous equivalent of `map()` builtin
        '''
        assert self._state == RUN
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback)
        self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result
    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool):
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            i = -1
            for i, task in enumerate(taskseq):
                if thread._state:
                    debug('task handler found thread._state != RUN')
                    break
                try:
                    put(task)
                except IOError:
                    debug('could not put task on queue')
                    break
            else:
                if set_length:
                    debug('doing set_length()')
                    set_length(i+1)
                continue
            break
        else:
            debug('task handler got sentinel')

        try:
            # tell result handler to finish when cache is empty
            debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except IOError:
            debug('task handler got IOError when sending sentinels')

        debug('task handler exiting')

    @staticmethod
    def _handle_results(outqueue, get, cache):
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if task is None:
                debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (IOError, EOFError):
                pass

        debug('result handler exiting: len(cache)=%s, thread._state=%s',
              len(cache), thread._state)
    @staticmethod
    def _get_tasks(func, it, size):
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)

    def __reduce__(self):
        raise NotImplementedError(
            'pool objects cannot be passed between processes or pickled'
            )

    def close(self):
        debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._taskqueue.put(None)

    def terminate(self):
        debug('terminating pool')
        self._state = TERMINATE
        self._terminate()

    def join(self):
        debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # task_handler may be blocked trying to put items on inqueue
        debug('removing tasks from inqueue until task handler finished')
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        task_handler, result_handler, cache):
        # this is guaranteed to only be called once
        debug('finalizing pool')

        task_handler._state = TERMINATE
        taskqueue.put(None)                 # sentinel

        debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        if pool and hasattr(pool[0], 'terminate'):
            debug('terminating workers')
            for p in pool:
                p.terminate()

        debug('joining task handler')
        task_handler.join(1e100)

        debug('joining result handler')
        result_handler.join(1e100)

        if pool and hasattr(pool[0], 'terminate'):
            debug('joining pool workers')
            for p in pool:
                p.join()
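A minimal usage sketch for the class above, via the standard `from multiprocessing import Pool` entry point. The `__main__` guard matters on platforms where child processes import the parent module (e.g. Windows); `square` is illustrative:

from multiprocessing import Pool

def square(n):
    return n * n

if __name__ == '__main__':
    pool = Pool(processes=4)
    print pool.map(square, range(10))   # [0, 1, 4, ..., 81], blocks until done
    print pool.apply(square, (5,))      # 25 -- same as apply_async(...).get()
    pool.close()                        # no more tasks; workers drain and exit
    pool.join()                         # wait for handler threads and workers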
#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):

    def __init__(self, cache, callback):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._ready = False
        self._callback = callback
        cache[self._job] = self

    def ready(self):
        return self._ready

    def successful(self):
        assert self._ready
        return self._success

    def wait(self, timeout=None):
        self._cond.acquire()
        try:
            if not self._ready:
                self._cond.wait(timeout)
        finally:
            self._cond.release()

    def get(self, timeout=None):
        self.wait(timeout)
        if not self._ready:
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        self._cond.acquire()
        try:
            self._ready = True
            self._cond.notify()
        finally:
            self._cond.release()
        del self._cache[self._job]
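A sketch of the `ApplyResult` lifecycle as seen by callers. Since `_set()` invokes the callback on the result-handler thread, callbacks should be quick and must not block; `slow_square` is illustrative:

import time
from multiprocessing import Pool, TimeoutError

def slow_square(n):
    time.sleep(0.2)
    return n * n

if __name__ == '__main__':
    pool = Pool(2)
    res = pool.apply_async(slow_square, (7,))
    print res.ready()              # almost certainly False; the worker is busy
    try:
        print res.get(timeout=5)   # 49; wait() plus TimeoutError if not ready
    except TimeoutError:
        print 'timed out'
    print res.successful()         # True; only meaningful once ready() is True
    pool.close()
    pool.join()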
#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):

    def __init__(self, cache, chunksize, length, callback):
        ApplyResult.__init__(self, cache, callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            self._number_left = 0
            self._ready = True
        else:
            self._number_left = length//chunksize + bool(length % chunksize)

    def _set(self, i, success_result):
        success, result = success_result
        if success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            self._number_left -= 1
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._cond.acquire()
                try:
                    self._ready = True
                    self._cond.notify()
                finally:
                    self._cond.release()
        else:
            self._success = False
            self._value = result
            del self._cache[self._job]
            self._cond.acquire()
            try:
                self._ready = True
                self._cond.notify()
            finally:
                self._cond.release()
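The bookkeeping above is easier to see with numbers. This mirrors the chunksize heuristic in `map_async()` (roughly four chunks per worker) and the `_number_left` computation in `MapResult.__init__`; the figures are a worked example, not taken from the source:

length, nworkers = 100, 3
chunksize, extra = divmod(length, nworkers * 4)    # divmod(100, 12) -> (8, 4)
if extra:
    chunksize += 1                                 # -> chunksize == 9
number_left = length // chunksize + bool(length % chunksize)
print chunksize, number_left    # 9 12 -- eleven chunks of 9 plus one of size 1

Each `_set()` call retires one chunk; the callback and the `_ready` flip happen only when the count reaches zero, and a single failed chunk poisons the whole map.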
#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._items = collections.deque()
        self._index = 0
        self._length = None
        self._unsorted = {}
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        self._cond.acquire()
        try:
            try:
                item = self._items.popleft()
            except IndexError:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                try:
                    item = self._items.popleft()
                except IndexError:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError
        finally:
            self._cond.release()

        success, value = item
        if success:
            return value
        raise value

    __next__ = next                    # XXX

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            if self._index == i:
                self._items.append(obj)
                self._index += 1
                while self._index in self._unsorted:
                    obj = self._unsorted.pop(self._index)
                    self._items.append(obj)
                    self._index += 1
                self._cond.notify()
            else:
                self._unsorted[i] = obj

            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

    def _set_length(self, length):
        self._cond.acquire()
        try:
            self._length = length
            if self._index == self._length:
                self._cond.notify()
                del self._cache[self._job]
        finally:
            self._cond.release()
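Callers normally just write `for x in pool.imap(...)`, but `next()` also accepts a per-item timeout, which the plain iterator protocol cannot express. A sketch, with `square` illustrative:

from multiprocessing import Pool

def square(n):
    return n * n

if __name__ == '__main__':
    pool = Pool(2)
    it = pool.imap(square, xrange(5))    # results delivered in input order
    while True:
        try:
            print it.next(timeout=10)    # raises TimeoutError if nothing arrives
        except StopIteration:
            break
    pool.close()
    pool.join()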
#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()
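The unordered variant appends each result as soon as it arrives, so output order tracks completion order rather than input order. A sketch where later inputs deliberately finish first (`delayed` is illustrative; the printed order may vary):

import time
from multiprocessing import Pool

def delayed(n):
    time.sleep(0.05 * (4 - n))    # later inputs sleep less, so they finish sooner
    return n

if __name__ == '__main__':
    pool = Pool(4)
    print list(pool.imap_unordered(delayed, range(5)))   # e.g. [4, 3, 2, 1, 0]
    pool.close()
    pool.join()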
#
#
#

class ThreadPool(Pool):

    from .dummy import Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        self._inqueue = Queue.Queue()
        self._outqueue = Queue.Queue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # put sentinels at head of inqueue to make workers finish
        inqueue.not_empty.acquire()
        try:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()
        finally:
            inqueue.not_empty.release()
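`ThreadPool` reuses the whole machinery above with threads (`multiprocessing.dummy.Process`) and plain `Queue.Queue` objects in place of worker processes and pipes, so it suits I/O-bound work where the blocking call releases the GIL. A sketch; `io_bound` is illustrative:

import time
from multiprocessing.pool import ThreadPool

def io_bound(n):
    time.sleep(0.1)       # stands in for a blocking I/O call
    return n * n

pool = ThreadPool(4)      # four worker threads in this process; no __main__ guard needed
print pool.map(io_bound, range(8))    # same API as Pool
pool.close()
pool.join()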