PageRenderTime 62ms CodeModel.GetById 20ms RepoModel.GetById 1ms app.codeStats 0ms

/test/python_SUITE_data/src/base.py

https://github.com/rabbitmq/rabbitmq-stomp
Python | 259 lines | 233 code | 15 blank | 11 comment | 33 complexity | ca9cfe5c285f6f347d796cb747acbd73 MD5 | raw file
  1. ## This Source Code Form is subject to the terms of the Mozilla Public
  2. ## License, v. 2.0. If a copy of the MPL was not distributed with this
  3. ## file, You can obtain one at https://mozilla.org/MPL/2.0/.
  4. ##
  5. ## Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
  6. ##
  7. import unittest
  8. import stomp
  9. import sys
  10. import threading
  11. import os
  12. class BaseTest(unittest.TestCase):
  13. def create_connection_obj(self, version='1.0', vhost='/', heartbeats=(0, 0)):
  14. if version == '1.0':
  15. conn = stomp.StompConnection10(host_and_ports=[('localhost', int(os.environ["STOMP_PORT"]))])
  16. self.ack_id_source_header = 'message-id'
  17. self.ack_id_header = 'message-id'
  18. elif version == '1.1':
  19. conn = stomp.StompConnection11(host_and_ports=[('localhost', int(os.environ["STOMP_PORT"]))],
  20. vhost=vhost,
  21. heartbeats=heartbeats)
  22. self.ack_id_source_header = 'message-id'
  23. self.ack_id_header = 'message-id'
  24. elif version == '1.2':
  25. conn = stomp.StompConnection12(host_and_ports=[('localhost', int(os.environ["STOMP_PORT"]))],
  26. vhost=vhost,
  27. heartbeats=heartbeats)
  28. self.ack_id_source_header = 'ack'
  29. self.ack_id_header = 'id'
  30. else:
  31. conn = stomp.StompConnection12(host_and_ports=[('localhost', int(os.environ["STOMP_PORT"]))],
  32. vhost=vhost,
  33. heartbeats=heartbeats)
  34. conn.version = version
  35. return conn
  36. def create_connection(self, user='guest', passcode='guest', wait=True, **kwargs):
  37. conn = self.create_connection_obj(**kwargs)
  38. conn.start()
  39. conn.connect(user, passcode, wait=wait)
  40. return conn
  41. def subscribe_dest(self, conn, destination, sub_id, **kwargs):
  42. if type(conn) is stomp.StompConnection10:
  43. # 'id' is optional in STOMP 1.0.
  44. if sub_id != None:
  45. kwargs['id'] = sub_id
  46. conn.subscribe(destination, **kwargs)
  47. else:
  48. # 'id' is required in STOMP 1.1+.
  49. if sub_id == None:
  50. sub_id = 'ctag'
  51. conn.subscribe(destination, sub_id, **kwargs)
  52. def unsubscribe_dest(self, conn, destination, sub_id, **kwargs):
  53. if type(conn) is stomp.StompConnection10:
  54. # 'id' is optional in STOMP 1.0.
  55. if sub_id != None:
  56. conn.unsubscribe(id=sub_id, **kwargs)
  57. else:
  58. conn.unsubscribe(destination=destination, **kwargs)
  59. else:
  60. # 'id' is required in STOMP 1.1+.
  61. if sub_id == None:
  62. sub_id = 'ctag'
  63. conn.unsubscribe(sub_id, **kwargs)
  64. def ack_message(self, conn, msg_id, sub_id, **kwargs):
  65. if type(conn) is stomp.StompConnection10:
  66. conn.ack(msg_id, **kwargs)
  67. elif type(conn) is stomp.StompConnection11:
  68. if sub_id == None:
  69. sub_id = 'ctag'
  70. conn.ack(msg_id, sub_id, **kwargs)
  71. elif type(conn) is stomp.StompConnection12:
  72. conn.ack(msg_id, **kwargs)
  73. def nack_message(self, conn, msg_id, sub_id, **kwargs):
  74. if type(conn) is stomp.StompConnection10:
  75. # Normally unsupported by STOMP 1.0.
  76. conn.send_frame("NACK", {"message-id": msg_id})
  77. elif type(conn) is stomp.StompConnection11:
  78. if sub_id == None:
  79. sub_id = 'ctag'
  80. conn.nack(msg_id, sub_id, **kwargs)
  81. elif type(conn) is stomp.StompConnection12:
  82. conn.nack(msg_id, **kwargs)
  83. def create_subscriber_connection(self, dest):
  84. conn = self.create_connection()
  85. listener = WaitableListener()
  86. conn.set_listener('', listener)
  87. self.subscribe_dest(conn, dest, None, receipt="sub.receipt")
  88. listener.wait()
  89. self.assertEquals(1, len(listener.receipts))
  90. listener.reset()
  91. return conn, listener
  92. def setUp(self):
  93. # Note: useful for debugging
  94. # import stomp.listener
  95. self.conn = self.create_connection()
  96. self.listener = WaitableListener()
  97. self.conn.set_listener('waitable', self.listener)
  98. # Note: useful for debugging
  99. # self.printing_listener = stomp.listener.PrintingListener()
  100. # self.conn.set_listener('printing', self.printing_listener)
  101. def tearDown(self):
  102. if self.conn.is_connected():
  103. self.conn.disconnect()
  104. self.conn.stop()
  105. def simple_test_send_rec(self, dest, headers={}):
  106. self.listener.reset()
  107. self.subscribe_dest(self.conn, dest, None)
  108. self.conn.send(dest, "foo", headers=headers)
  109. self.assertTrue(self.listener.wait(), "Timeout, no message received")
  110. # assert no errors
  111. if len(self.listener.errors) > 0:
  112. self.fail(self.listener.errors[0]['message'])
  113. # check header content
  114. msg = self.listener.messages[0]
  115. self.assertEquals("foo", msg['message'])
  116. self.assertEquals(dest, msg['headers']['destination'])
  117. return msg['headers']
  118. def assertListener(self, errMsg, numMsgs=0, numErrs=0, numRcts=0, timeout=10):
  119. if numMsgs + numErrs + numRcts > 0:
  120. self._assertTrue(self.listener.wait(timeout), errMsg + " (#awaiting)")
  121. else:
  122. self._assertFalse(self.listener.wait(timeout), errMsg + " (#awaiting)")
  123. self._assertEquals(numMsgs, len(self.listener.messages), errMsg + " (#messages)")
  124. self._assertEquals(numErrs, len(self.listener.errors), errMsg + " (#errors)")
  125. self._assertEquals(numRcts, len(self.listener.receipts), errMsg + " (#receipts)")
  126. def _assertTrue(self, bool, msg):
  127. if not bool:
  128. self.listener.print_state(msg, True)
  129. self.assertTrue(bool, msg)
  130. def _assertFalse(self, bool, msg):
  131. if bool:
  132. self.listener.print_state(msg, True)
  133. self.assertFalse(bool, msg)
  134. def _assertEquals(self, expected, actual, msg):
  135. if expected != actual:
  136. self.listener.print_state(msg, True)
  137. self.assertEquals(expected, actual, msg)
  138. def assertListenerAfter(self, verb, errMsg="", numMsgs=0, numErrs=0, numRcts=0, timeout=5):
  139. num = numMsgs + numErrs + numRcts
  140. self.listener.reset(num if num>0 else 1)
  141. verb()
  142. self.assertListener(errMsg=errMsg, numMsgs=numMsgs, numErrs=numErrs, numRcts=numRcts, timeout=timeout)
  143. class WaitableListener(object):
  144. def __init__(self):
  145. self.debug = False
  146. if self.debug:
  147. print('(listener) init')
  148. self.messages = []
  149. self.errors = []
  150. self.receipts = []
  151. self.latch = Latch(1)
  152. self.msg_no = 0
  153. def _next_msg_no(self):
  154. self.msg_no += 1
  155. return self.msg_no
  156. def _append(self, array, msg, hdrs):
  157. mno = self._next_msg_no()
  158. array.append({'message' : msg, 'headers' : hdrs, 'msg_no' : mno})
  159. self.latch.countdown()
  160. def on_receipt(self, headers, message):
  161. if self.debug:
  162. print('(on_receipt) message: {}, headers: {}'.format(message, headers))
  163. self._append(self.receipts, message, headers)
  164. def on_error(self, headers, message):
  165. if self.debug:
  166. print('(on_error) message: {}, headers: {}'.format(message, headers))
  167. self._append(self.errors, message, headers)
  168. def on_message(self, headers, message):
  169. if self.debug:
  170. print('(on_message) message: {}, headers: {}'.format(message, headers))
  171. self._append(self.messages, message, headers)
  172. def reset(self, count=1):
  173. if self.debug:
  174. self.print_state('(reset listener--old state)')
  175. self.messages = []
  176. self.errors = []
  177. self.receipts = []
  178. self.latch = Latch(count)
  179. self.msg_no = 0
  180. if self.debug:
  181. self.print_state('(reset listener--new state)')
  182. def wait(self, timeout=10):
  183. return self.latch.wait(timeout)
  184. def print_state(self, hdr="", full=False):
  185. print(hdr)
  186. print('#messages: {}'.format(len(self.messages)))
  187. print('#errors: {}', len(self.errors))
  188. print('#receipts: {}'.format(len(self.receipts)))
  189. print('Remaining count: {}'.format(self.latch.get_count()))
  190. if full:
  191. if len(self.messages) != 0: print('Messages: {}'.format(self.messages))
  192. if len(self.errors) != 0: print('Messages: {}'.format(self.errors))
  193. if len(self.receipts) != 0: print('Messages: {}'.format(self.receipts))
  194. class Latch(object):
  195. def __init__(self, count=1):
  196. self.cond = threading.Condition()
  197. self.cond.acquire()
  198. self.count = count
  199. self.cond.release()
  200. def countdown(self):
  201. self.cond.acquire()
  202. if self.count > 0:
  203. self.count -= 1
  204. if self.count == 0:
  205. self.cond.notify_all()
  206. self.cond.release()
  207. def wait(self, timeout=None):
  208. try:
  209. self.cond.acquire()
  210. if self.count == 0:
  211. return True
  212. else:
  213. self.cond.wait(timeout)
  214. return self.count == 0
  215. finally:
  216. self.cond.release()
  217. def get_count(self):
  218. try:
  219. self.cond.acquire()
  220. return self.count
  221. finally:
  222. self.cond.release()