PageRenderTime 66ms CodeModel.GetById 28ms RepoModel.GetById 1ms app.codeStats 0ms

/zerver/lib/test_helpers.py

https://gitlab.com/EnLab/zulip
Python | 350 lines | 282 code | 51 blank | 17 comment | 14 complexity | c656413c06c467d3f0a2d74a2cd62c08 MD5 | raw file
  1. from django.test import TestCase
  2. from zerver.lib.initial_password import initial_password
  3. from zerver.lib.db import TimeTrackingCursor
  4. from zerver.lib import cache
  5. from zerver.lib import event_queue
  6. from zerver.worker import queue_processors
  7. from zerver.lib.actions import (
  8. check_send_message, create_stream_if_needed, do_add_subscription,
  9. get_display_recipient, get_user_profile_by_email,
  10. )
  11. from zerver.models import (
  12. resolve_email_to_domain,
  13. Client,
  14. Message,
  15. Realm,
  16. Recipient,
  17. Stream,
  18. Subscription,
  19. UserMessage,
  20. )
  21. import base64
  22. import os
  23. import re
  24. import time
  25. import ujson
  26. import urllib
  27. from contextlib import contextmanager
  28. API_KEYS = {}
  29. @contextmanager
  30. def stub(obj, name, f):
  31. old_f = getattr(obj, name)
  32. setattr(obj, name, f)
  33. yield
  34. setattr(obj, name, old_f)
  35. @contextmanager
  36. def simulated_queue_client(client):
  37. real_SimpleQueueClient = queue_processors.SimpleQueueClient
  38. queue_processors.SimpleQueueClient = client
  39. yield
  40. queue_processors.SimpleQueueClient = real_SimpleQueueClient
  41. @contextmanager
  42. def tornado_redirected_to_list(lst):
  43. real_event_queue_process_notification = event_queue.process_notification
  44. event_queue.process_notification = lst.append
  45. yield
  46. event_queue.process_notification = real_event_queue_process_notification
  47. @contextmanager
  48. def simulated_empty_cache():
  49. cache_queries = []
  50. def my_cache_get(key, cache_name=None):
  51. cache_queries.append(('get', key, cache_name))
  52. return None
  53. def my_cache_get_many(keys, cache_name=None):
  54. cache_queries.append(('getmany', keys, cache_name))
  55. return None
  56. old_get = cache.cache_get
  57. old_get_many = cache.cache_get_many
  58. cache.cache_get = my_cache_get
  59. cache.cache_get_many = my_cache_get_many
  60. yield cache_queries
  61. cache.cache_get = old_get
  62. cache.cache_get_many = old_get_many
  63. @contextmanager
  64. def queries_captured():
  65. '''
  66. Allow a user to capture just the queries executed during
  67. the with statement.
  68. '''
  69. queries = []
  70. def wrapper_execute(self, action, sql, params=()):
  71. start = time.time()
  72. try:
  73. return action(sql, params)
  74. finally:
  75. stop = time.time()
  76. duration = stop - start
  77. queries.append({
  78. 'sql': self.mogrify(sql, params),
  79. 'time': "%.3f" % duration,
  80. })
  81. old_execute = TimeTrackingCursor.execute
  82. old_executemany = TimeTrackingCursor.executemany
  83. def cursor_execute(self, sql, params=()):
  84. return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
  85. TimeTrackingCursor.execute = cursor_execute
  86. def cursor_executemany(self, sql, params=()):
  87. return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params)
  88. TimeTrackingCursor.executemany = cursor_executemany
  89. yield queries
  90. TimeTrackingCursor.execute = old_execute
  91. TimeTrackingCursor.executemany = old_executemany
  92. def find_key_by_email(address):
  93. from django.core.mail import outbox
  94. key_regex = re.compile("accounts/do_confirm/([a-f0-9]{40})>")
  95. for message in reversed(outbox):
  96. if address in message.to:
  97. return key_regex.search(message.body).groups()[0]
  98. def message_ids(result):
  99. return set(message['id'] for message in result['messages'])
  100. def message_stream_count(user_profile):
  101. return UserMessage.objects. \
  102. select_related("message"). \
  103. filter(user_profile=user_profile). \
  104. count()
  105. def most_recent_usermessage(user_profile):
  106. query = UserMessage.objects. \
  107. select_related("message"). \
  108. filter(user_profile=user_profile). \
  109. order_by('-message')
  110. return query[0] # Django does LIMIT here
  111. def most_recent_message(user_profile):
  112. usermessage = most_recent_usermessage(user_profile)
  113. return usermessage.message
  114. def get_user_messages(user_profile):
  115. query = UserMessage.objects. \
  116. select_related("message"). \
  117. filter(user_profile=user_profile). \
  118. order_by('message')
  119. return [um.message for um in query]
  120. class DummyObject:
  121. pass
  122. class DummyTornadoRequest:
  123. def __init__(self):
  124. self.connection = DummyObject()
  125. self.connection.stream = DummyStream()
  126. class DummyHandler(object):
  127. def __init__(self, assert_callback):
  128. self.assert_callback = assert_callback
  129. self.request = DummyTornadoRequest()
  130. # Mocks RequestHandler.async_callback, which wraps a callback to
  131. # handle exceptions. We return the callback as-is.
  132. def async_callback(self, cb):
  133. return cb
  134. def write(self, response):
  135. raise NotImplemented
  136. def zulip_finish(self, response, *ignore):
  137. if self.assert_callback:
  138. self.assert_callback(response)
  139. class DummySession(object):
  140. session_key = "0"
  141. class DummyStream:
  142. def closed(self):
  143. return False
  144. class POSTRequestMock(object):
  145. method = "POST"
  146. def __init__(self, post_data, user_profile, assert_callback=None):
  147. self.REQUEST = self.POST = post_data
  148. self.user = user_profile
  149. self._tornado_handler = DummyHandler(assert_callback)
  150. self.session = DummySession()
  151. self._log_data = {}
  152. self.META = {'PATH_INFO': 'test'}
  153. self._log_data = {}
  154. class AuthedTestCase(TestCase):
  155. # Helper because self.client.patch annoying requires you to urlencode
  156. def client_patch(self, url, info={}, **kwargs):
  157. info = urllib.urlencode(info)
  158. return self.client.patch(url, info, **kwargs)
  159. def client_put(self, url, info={}, **kwargs):
  160. info = urllib.urlencode(info)
  161. return self.client.put(url, info, **kwargs)
  162. def client_delete(self, url, info={}, **kwargs):
  163. info = urllib.urlencode(info)
  164. return self.client.delete(url, info, **kwargs)
  165. def login(self, email, password=None):
  166. if password is None:
  167. password = initial_password(email)
  168. return self.client.post('/accounts/login/',
  169. {'username':email, 'password':password})
  170. def register(self, username, password, domain="zulip.com"):
  171. self.client.post('/accounts/home/',
  172. {'email': username + "@" + domain})
  173. return self.submit_reg_form_for_user(username, password, domain=domain)
  174. def submit_reg_form_for_user(self, username, password, domain="zulip.com"):
  175. """
  176. Stage two of the two-step registration process.
  177. If things are working correctly the account should be fully
  178. registered after this call.
  179. """
  180. return self.client.post('/accounts/register/',
  181. {'full_name': username, 'password': password,
  182. 'key': find_key_by_email(username + '@' + domain),
  183. 'terms': True})
  184. def get_api_key(self, email):
  185. if email not in API_KEYS:
  186. API_KEYS[email] = get_user_profile_by_email(email).api_key
  187. return API_KEYS[email]
  188. def api_auth(self, email):
  189. credentials = "%s:%s" % (email, self.get_api_key(email))
  190. return {
  191. 'HTTP_AUTHORIZATION': 'Basic ' + base64.b64encode(credentials)
  192. }
  193. def get_streams(self, email):
  194. """
  195. Helper function to get the stream names for a user
  196. """
  197. user_profile = get_user_profile_by_email(email)
  198. subs = Subscription.objects.filter(
  199. user_profile = user_profile,
  200. active = True,
  201. recipient__type = Recipient.STREAM)
  202. return [get_display_recipient(sub.recipient) for sub in subs]
  203. def send_message(self, sender_name, recipient_list, message_type,
  204. content="test content", subject="test", **kwargs):
  205. sender = get_user_profile_by_email(sender_name)
  206. if message_type == Recipient.PERSONAL:
  207. message_type_name = "private"
  208. else:
  209. message_type_name = "stream"
  210. if isinstance(recipient_list, basestring):
  211. recipient_list = [recipient_list]
  212. (sending_client, _) = Client.objects.get_or_create(name="test suite")
  213. return check_send_message(
  214. sender, sending_client, message_type_name, recipient_list, subject,
  215. content, forged=False, forged_timestamp=None,
  216. forwarder_user_profile=sender, realm=sender.realm, **kwargs)
  217. def get_old_messages(self, anchor=1, num_before=100, num_after=100):
  218. post_params = {"anchor": anchor, "num_before": num_before,
  219. "num_after": num_after}
  220. result = self.client.post("/json/get_old_messages", dict(post_params))
  221. data = ujson.loads(result.content)
  222. return data['messages']
  223. def users_subscribed_to_stream(self, stream_name, realm_domain):
  224. realm = Realm.objects.get(domain=realm_domain)
  225. stream = Stream.objects.get(name=stream_name, realm=realm)
  226. recipient = Recipient.objects.get(type_id=stream.id, type=Recipient.STREAM)
  227. subscriptions = Subscription.objects.filter(recipient=recipient, active=True)
  228. return [subscription.user_profile for subscription in subscriptions]
  229. def assert_json_success(self, result):
  230. """
  231. Successful POSTs return a 200 and JSON of the form {"result": "success",
  232. "msg": ""}.
  233. """
  234. self.assertEqual(result.status_code, 200, result)
  235. json = ujson.loads(result.content)
  236. self.assertEqual(json.get("result"), "success")
  237. # We have a msg key for consistency with errors, but it typically has an
  238. # empty value.
  239. self.assertIn("msg", json)
  240. return json
  241. def get_json_error(self, result, status_code=400):
  242. self.assertEqual(result.status_code, status_code)
  243. json = ujson.loads(result.content)
  244. self.assertEqual(json.get("result"), "error")
  245. return json['msg']
  246. def assert_json_error(self, result, msg, status_code=400):
  247. """
  248. Invalid POSTs return an error status code and JSON of the form
  249. {"result": "error", "msg": "reason"}.
  250. """
  251. self.assertEqual(self.get_json_error(result, status_code=status_code), msg)
  252. def assert_length(self, queries, count, exact=False):
  253. actual_count = len(queries)
  254. if exact:
  255. return self.assertTrue(actual_count == count,
  256. "len(%s) == %s, != %s" % (queries, actual_count, count))
  257. return self.assertTrue(actual_count <= count,
  258. "len(%s) == %s, > %s" % (queries, actual_count, count))
  259. def assert_json_error_contains(self, result, msg_substring):
  260. self.assertIn(msg_substring, self.get_json_error(result))
  261. def fixture_data(self, type, action, file_type='json'):
  262. return open(os.path.join(os.path.dirname(__file__),
  263. "../fixtures/%s/%s_%s.%s" % (type, type, action,file_type))).read()
  264. # Subscribe to a stream directly
  265. def subscribe_to_stream(self, email, stream_name, realm=None):
  266. realm = Realm.objects.get(domain=resolve_email_to_domain(email))
  267. stream, _ = create_stream_if_needed(realm, stream_name)
  268. user_profile = get_user_profile_by_email(email)
  269. do_add_subscription(user_profile, stream, no_log=True)
  270. # Subscribe to a stream by making an API request
  271. def common_subscribe_to_streams(self, email, streams, extra_post_data = {}, invite_only=False):
  272. post_data = {'subscriptions': ujson.dumps([{"name": stream} for stream in streams]),
  273. 'invite_only': ujson.dumps(invite_only)}
  274. post_data.update(extra_post_data)
  275. result = self.client.post("/api/v1/users/me/subscriptions", post_data, **self.api_auth(email))
  276. return result
  277. def send_json_payload(self, email, url, payload, stream_name=None, **post_params):
  278. if stream_name != None:
  279. self.subscribe_to_stream(email, stream_name)
  280. result = self.client.post(url, payload, **post_params)
  281. self.assert_json_success(result)
  282. # Check the correct message was sent
  283. msg = Message.objects.filter().order_by('-id')[0]
  284. self.assertEqual(msg.sender.email, email)
  285. self.assertEqual(get_display_recipient(msg.recipient), stream_name)
  286. return msg