
/pyspark/accumulators.py

https://gitlab.com/nexemjail/mir
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
1
>>> a.value = 2
>>> a.value
2
>>> a += 5
>>> a.value
7

>>> sc.accumulator(1.0).value
1.0

>>> sc.accumulator(1j).value
1j

>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
...     global a
...     a += x
>>> rdd.foreach(f)
>>> a.value
13

>>> b = sc.accumulator(0)
>>> def g(x):
...     b.add(x)
>>> rdd.foreach(g)
>>> b.value
6

>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
...     def zero(self, value):
...         return [0.0] * len(value)
...     def addInPlace(self, val1, val2):
...         for i in range(len(val1)):
...             val1[i] += val2[i]
...         return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
...     global va
...     va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]

>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> def h(x):
...     global a
...     a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
TypeError:...
"""

import sys
import select
import struct
if sys.version < '3':
    import SocketServer
else:
    import socketserver as SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, PickleSerializer


__all__ = ['Accumulator', 'AccumulatorParam']


pickleSer = PickleSerializer()

# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}


def _deserialize_accumulator(aid, zero_value, accum_param):
    from pyspark.accumulators import _accumulatorRegistry
    accum = Accumulator(aid, zero_value, accum_param)
    accum._deserialized = True
    _accumulatorRegistry[aid] = accum
    return accum


class Accumulator(object):

    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
    operator, but only the driver program is allowed to access its value, using C{value}.
    Updates from the workers get propagated automatically to the driver program.

    While C{SparkContext} supports accumulators for primitive data types like C{int} and
    C{float}, users can also define accumulators for custom types by providing a custom
    L{AccumulatorParam} object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Sets the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def add(self, term):
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self.add(term)
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
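

# Illustrative sketch (not part of the upstream file): what __reduce__ above means for a
# worker. Pickling an Accumulator ships only (aid, zero value, accum_param); the unpickled
# copy is flagged _deserialized, so reading .value on it raises, while add()/+= still work
# and feed _accumulatorRegistry for the end-of-task update. The function below is a
# hypothetical, never-called demonstration that relies only on names defined in this module.
def _example_worker_side_copy():
    import pickle
    driver_acc = Accumulator(42, 10, INT_ACCUMULATOR_PARAM)
    worker_acc = pickle.loads(pickle.dumps(driver_acc))  # arrives holding zero(10) == 0
    worker_acc += 3                                      # allowed on workers
    return worker_acc._value                             # 3; worker_acc.value would raise here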


class AccumulatorParam(object):

    """
    Helper object that defines how to accumulate values of a given type.
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update C{value1} in place and return it.
        """
        raise NotImplementedError
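

# Illustrative sketch (not part of the upstream file): a custom AccumulatorParam that
# merges dictionaries of counts, showing the zero()/addInPlace() contract described above.
# The class name is hypothetical; it would be passed to SparkContext.accumulator() on the
# driver, e.g. sc.accumulator({}, DictSumAccumulatorParam()).
class DictSumAccumulatorParam(AccumulatorParam):

    def zero(self, value):
        # The empty dict plays the role of a zero vector of matching shape.
        return {}

    def addInPlace(self, value1, value2):
        # Update value1 in place for efficiency and return it, as the contract allows.
        for key, count in value2.items():
            value1[key] = value1.get(key, 0) + count
        return value1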


class AddingAccumulatorParam(AccumulatorParam):

    """
    An AccumulatorParam that uses the + operators to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
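
# Note (added commentary, not in the upstream file): SparkContext.accumulator() falls back
# to these singletons when no AccumulatorParam is given, based on the initial value's type
# (int, float, or complex); other types, such as the list in the module doctest, raise
# TypeError unless a custom AccumulatorParam is supplied.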


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

    """
    This handler will keep polling updates from the same socket until the
    server is shutdown.
    """

    def handle(self):
        from pyspark.accumulators import _accumulatorRegistry
        while not self.server.server_shutdown:
            # Poll every 1 second for new data -- don't block in case of shutdown.
            r, _, _ = select.select([self.rfile], [], [], 1)
            if self.rfile in r:
                num_updates = read_int(self.rfile)
                for _ in range(num_updates):
                    (aid, update) = pickleSer._read_with_length(self.rfile)
                    _accumulatorRegistry[aid] += update
                # Write a byte in acknowledgement
                self.wfile.write(struct.pack("!b", 1))
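

# Illustrative sketch (not part of the upstream file): what a client of the handler above
# would send, assuming the framing used by read_int and _read_with_length -- a 4-byte
# big-endian count, then one length-prefixed pickled (aid, update) pair per update, with a
# single 1-byte acknowledgement per batch. The function name and the use of a raw socket
# are hypothetical; nothing in this module calls it.
def _example_send_updates(sock, updates):
    import pickle
    sock.sendall(struct.pack("!i", len(updates)))      # number of updates in this batch
    for aid, update in updates:
        payload = pickle.dumps((aid, update), 2)
        sock.sendall(struct.pack("!i", len(payload)))  # length prefix read by _read_with_length
        sock.sendall(payload)                          # pickled (aid, update) pair
    sock.recv(1)                                       # 1-byte acknowledgement for the batch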


class AccumulatorServer(SocketServer.TCPServer):

    """
    A simple TCP server that intercepts shutdown() in order to interrupt
    our continuous polling on the handler.
    """
    server_shutdown = False

    def shutdown(self):
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)
        self.server_close()


def _start_update_server():
    """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    return server
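

# Illustrative sketch (not part of the upstream file): a local round trip through the
# update server, combining _start_update_server() with the hypothetical
# _example_send_updates() defined above. Never called from this module; intended only to
# show how the pieces fit together.
def _example_local_round_trip():
    import socket
    acc = Accumulator(0, 0, INT_ACCUMULATOR_PARAM)          # registers itself under aid 0
    server = _start_update_server()
    sock = socket.create_connection(server.server_address)
    _example_send_updates(sock, [(0, 5), (0, 7)])           # handler applies += 5 and += 7
    server.shutdown()
    sock.close()
    return acc.value                                        # expected to be 12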


if __name__ == "__main__":
    import doctest
    (failure_count, test_count) = doctest.testmod()
    if failure_count:
        exit(-1)