
/frontera/contrib/messagebus/zeromq/__init__.py

https://gitlab.com/e0/frontera
# -*- coding: utf-8 -*-
from time import time, sleep
from struct import pack, unpack
from logging import getLogger

import zmq
import six

from frontera.core.messagebus import BaseMessageBus, BaseSpiderLogStream, BaseStreamConsumer, \
    BaseSpiderFeedStream, BaseScoringLogStream
from frontera.contrib.backends.partitioners import FingerprintPartitioner, Crc32NamePartitioner
from frontera.contrib.messagebus.zeromq.socket_config import SocketConfig


class Consumer(BaseStreamConsumer):
    def __init__(self, context, location, partition_id, identity, seq_warnings=False, hwm=1000):
        self.subscriber = context.zeromq.socket(zmq.SUB)
        self.subscriber.connect(location)
        self.subscriber.set(zmq.RCVHWM, hwm)

        # Subscribe either to a single partition of the topic or to the whole topic.
        filter = identity + pack('>B', partition_id) if partition_id is not None else identity
        self.subscriber.setsockopt(zmq.SUBSCRIBE, filter)
        self.counter = 0
        self.count_global = partition_id is None
        self.logger = getLogger("distributed_frontera.messagebus.zeromq.Consumer(%s-%s)" % (identity, partition_id))
        self.seq_warnings = seq_warnings

        self.stats = context.stats
        self.stat_key = "consumer-%s" % identity
        self.stats[self.stat_key] = 0

    def get_messages(self, timeout=0.1, count=1):
        started = time()
        sleep_time = timeout / 10.0
        while count:
            try:
                msg = self.subscriber.recv_multipart(copy=True, flags=zmq.NOBLOCK)
            except zmq.Again:
                if time() - started > timeout:
                    break
                sleep(sleep_time)
            else:
                partition_seqno, global_seqno = unpack(">II", msg[2])
                seqno = global_seqno if self.count_global else partition_seqno
                if not self.counter:
                    self.counter = seqno
                elif self.counter != seqno:
                    if self.seq_warnings:
                        self.logger.warning("Sequence counter mismatch: expected %d, got %d. Check if system "
                                            "isn't missing messages." % (self.counter, seqno))
                    self.counter = None
                yield msg[1]
                count -= 1
                if self.counter:
                    self.counter += 1
                self.stats[self.stat_key] += 1

    def get_offset(self):
        return self.counter


class Producer(object):
    def __init__(self, context, location, identity):
        self.identity = identity
        self.sender = context.zeromq.socket(zmq.PUB)
        self.sender.connect(location)
        self.counters = {}
        self.global_counter = 0
        self.stats = context.stats
        self.stat_key = "producer-%s" % identity
        self.stats[self.stat_key] = 0

    def send(self, key, *messages):
        # Guarantee that msg is actually a list or tuple (should always be true)
        if not isinstance(messages, (list, tuple)):
            raise TypeError("msg is not a list or tuple!")

        # Raise TypeError if any message is not encoded as bytes
        if any(not isinstance(m, six.binary_type) for m in messages):
            raise TypeError("all produce message payloads must be type bytes")

        # The partitioner is set by subclasses (SpiderLogProducer, SpiderFeedProducer).
        partition = self.partitioner.partition(key)
        counter = self.counters.get(partition, 0)
        for msg in messages:
            self.sender.send_multipart([self.identity + pack(">B", partition), msg,
                                        pack(">II", counter, self.global_counter)])
            counter += 1
            self.global_counter += 1
            # Wrap the 32-bit sequence counters around.
            if counter == 4294967296:
                counter = 0
            if self.global_counter == 4294967296:
                self.global_counter = 0
            self.stats[self.stat_key] += 1
        self.counters[partition] = counter

    def flush(self):
        pass

    def get_offset(self, partition_id):
        return self.counters[partition_id]


class SpiderLogProducer(Producer):
    def __init__(self, context, location, partitions):
        super(SpiderLogProducer, self).__init__(context, location, 'sl')
        self.partitioner = FingerprintPartitioner(partitions)


class SpiderLogStream(BaseSpiderLogStream):
    def __init__(self, messagebus):
        self.context = messagebus.context
        self.sw_in_location = messagebus.socket_config.sw_in()
        self.db_in_location = messagebus.socket_config.db_in()
        self.out_location = messagebus.socket_config.spiders_out()
        self.partitions = messagebus.spider_log_partitions

    def producer(self):
        return SpiderLogProducer(self.context, self.out_location, self.partitions)

    def consumer(self, partition_id, type):
        location = self.sw_in_location if type == 'sw' else self.db_in_location
        return Consumer(self.context, location, partition_id, 'sl')


class UpdateScoreProducer(Producer):
    def __init__(self, context, location):
        super(UpdateScoreProducer, self).__init__(context, location, 'us')

    def send(self, key, *messages):
        # Guarantee that msg is actually a list or tuple (should always be true)
        if not isinstance(messages, (list, tuple)):
            raise TypeError("msg is not a list or tuple!")

        # Raise TypeError if any message is not encoded as bytes
        if any(not isinstance(m, six.binary_type) for m in messages):
            raise TypeError("all produce message payloads must be type bytes")

        # The scoring log is not partitioned: everything goes to the single 'us' topic.
        counter = self.counters.get(0, 0)
        for msg in messages:
            self.sender.send_multipart([self.identity, msg, pack(">II", counter, counter)])
            counter += 1
            if counter == 4294967296:
                counter = 0
            self.stats[self.stat_key] += 1
        self.counters[0] = counter


class ScorinLogStream(BaseScoringLogStream):
    def __init__(self, messagebus):
        self.context = messagebus.context
        self.in_location = messagebus.socket_config.sw_out()
        self.out_location = messagebus.socket_config.db_in()

    def consumer(self):
        return Consumer(self.context, self.out_location, None, 'us')

    def producer(self):
        return UpdateScoreProducer(self.context, self.in_location)


class SpiderFeedProducer(Producer):
    def __init__(self, context, location, partitions, hwm, hostname_partitioning):
        super(SpiderFeedProducer, self).__init__(context, location, 'sf')
        self.partitioner = Crc32NamePartitioner(partitions) if hostname_partitioning else \
            FingerprintPartitioner(partitions)
        self.sender.set(zmq.SNDHWM, hwm)


class SpiderFeedStream(BaseSpiderFeedStream):
    def __init__(self, messagebus):
        self.context = messagebus.context
        self.in_location = messagebus.socket_config.db_out()
        self.out_location = messagebus.socket_config.spiders_in()
        self.partitions = messagebus.spider_feed_partitions
        self.ready_partitions = set(self.partitions)
        self.consumer_hwm = messagebus.spider_feed_rcvhwm
        self.producer_hwm = messagebus.spider_feed_sndhwm
        self.hostname_partitioning = messagebus.hostname_partitioning

    def consumer(self, partition_id):
        return Consumer(self.context, self.out_location, partition_id, 'sf', seq_warnings=True,
                        hwm=self.consumer_hwm)

    def producer(self):
        return SpiderFeedProducer(self.context, self.in_location, self.partitions,
                                  self.producer_hwm, self.hostname_partitioning)

    def available_partitions(self):
        return self.ready_partitions

    def mark_ready(self, partition_id):
        self.ready_partitions.add(partition_id)

    def mark_busy(self, partition_id):
        self.ready_partitions.discard(partition_id)


class Context(object):
    zeromq = zmq.Context()
    stats = {}


class MessageBus(BaseMessageBus):
    def __init__(self, settings):
        self.context = Context()
        self.socket_config = SocketConfig(settings.get('ZMQ_ADDRESS'),
                                          settings.get('ZMQ_BASE_PORT'))
        self.spider_log_partitions = [i for i in range(settings.get('SPIDER_LOG_PARTITIONS'))]
        self.spider_feed_partitions = [i for i in range(settings.get('SPIDER_FEED_PARTITIONS'))]
        self.spider_feed_sndhwm = int(settings.get('MAX_NEXT_REQUESTS') * len(self.spider_feed_partitions) * 1.2)
        self.spider_feed_rcvhwm = int(settings.get('MAX_NEXT_REQUESTS') * 2.0)
        self.hostname_partitioning = settings.get('QUEUE_HOSTNAME_PARTITIONING')
        if self.socket_config.is_ipv6:
            self.context.zeromq.setsockopt(zmq.IPV6, True)

    def spider_log(self):
        return SpiderLogStream(self)

    def scoring_log(self):
        return ScorinLogStream(self)

    def spider_feed(self):
        return SpiderFeedStream(self)
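
Below is a minimal usage sketch, not part of the module above. It assumes the frontera ZeroMQ broker (frontera.contrib.messagebus.zeromq.broker) is already running on the configured address and ports, and it substitutes a hypothetical _Settings stand-in for frontera's real Settings object; the setting values and the example fingerprint are illustrative only. Since the module concatenates text identities with packed bytes, the sketch also assumes a Python 2 environment.

from time import sleep
from hashlib import sha1

from frontera.contrib.messagebus.zeromq import MessageBus


class _Settings(object):
    """Hypothetical stand-in for frontera's Settings; only .get() is needed here."""
    def __init__(self, values):
        self._values = values

    def get(self, key):
        return self._values[key]


settings = _Settings({
    'ZMQ_ADDRESS': '127.0.0.1',        # illustrative values
    'ZMQ_BASE_PORT': 5570,
    'SPIDER_LOG_PARTITIONS': 1,
    'SPIDER_FEED_PARTITIONS': 1,
    'MAX_NEXT_REQUESTS': 256,
    'QUEUE_HOSTNAME_PARTITIONING': False,
})

mb = MessageBus(settings)
spider_log = mb.spider_log()

# Spider-side producer and a strategy-worker consumer on partition 0 of the
# same spider log stream.
producer = spider_log.producer()
consumer = spider_log.consumer(partition_id=0, type='sw')

# Give the PUB/SUB connections to the broker a moment to establish before sending.
sleep(0.1)

# Keys are SHA1 hex fingerprints, matching what FingerprintPartitioner expects.
fingerprint = sha1(b'http://example.com/').hexdigest()
producer.send(fingerprint, b'serialized spider log event')

for message in consumer.get_messages(timeout=1.0, count=1):
    print(message)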