PageRenderTime 478ms CodeModel.GetById 43ms RepoModel.GetById 2ms app.codeStats 0ms

/packages/celery/celery/tests/utils.py

https://github.com/mozilla/kuma-lib
Python | 485 lines | 447 code | 31 blank | 7 comment | 26 complexity | 57ebeef0541231237dc978639891c704 MD5 | raw file
  1. from __future__ import absolute_import
  2. try:
  3. import unittest
  4. unittest.skip
  5. from unittest.util import safe_repr, unorderable_list_difference
  6. except AttributeError:
  7. import unittest2 as unittest
  8. from unittest2.util import safe_repr, unorderable_list_difference # noqa
  9. import importlib
  10. import logging
  11. import os
  12. import re
  13. import sys
  14. import time
  15. import warnings
  16. try:
  17. import __builtin__ as builtins
  18. except ImportError: # py3k
  19. import builtins # noqa
  20. from functools import wraps
  21. from contextlib import contextmanager
  22. import mock
  23. from nose import SkipTest
  24. from ..app import app_or_default
  25. from ..utils import noop
  26. from ..utils.compat import WhateverIO, LoggerAdapter
  27. from .compat import catch_warnings
  28. class Mock(mock.Mock):
  29. def __init__(self, *args, **kwargs):
  30. attrs = kwargs.pop("attrs", None) or {}
  31. super(Mock, self).__init__(*args, **kwargs)
  32. for attr_name, attr_value in attrs.items():
  33. setattr(self, attr_name, attr_value)
  34. def skip_unless_module(module):
  35. def _inner(fun):
  36. @wraps(fun)
  37. def __inner(*args, **kwargs):
  38. try:
  39. importlib.import_module(module)
  40. except ImportError:
  41. raise SkipTest("Does not have %s" % (module, ))
  42. return fun(*args, **kwargs)
  43. return __inner
  44. return _inner
  45. # -- adds assertWarns from recent unittest2, not in Python 2.7.
  46. class _AssertRaisesBaseContext(object):
  47. def __init__(self, expected, test_case, callable_obj=None,
  48. expected_regex=None):
  49. self.expected = expected
  50. self.failureException = test_case.failureException
  51. self.obj_name = None
  52. if isinstance(expected_regex, basestring):
  53. expected_regex = re.compile(expected_regex)
  54. self.expected_regex = expected_regex
  55. class _AssertWarnsContext(_AssertRaisesBaseContext):
  56. """A context manager used to implement TestCase.assertWarns* methods."""
  57. def __enter__(self):
  58. # The __warningregistry__'s need to be in a pristine state for tests
  59. # to work properly.
  60. warnings.resetwarnings()
  61. for v in sys.modules.values():
  62. if getattr(v, '__warningregistry__', None):
  63. v.__warningregistry__ = {}
  64. self.warnings_manager = catch_warnings(record=True)
  65. self.warnings = self.warnings_manager.__enter__()
  66. warnings.simplefilter("always", self.expected)
  67. return self
  68. def __exit__(self, exc_type, exc_value, tb):
  69. self.warnings_manager.__exit__(exc_type, exc_value, tb)
  70. if exc_type is not None:
  71. # let unexpected exceptions pass through
  72. return
  73. try:
  74. exc_name = self.expected.__name__
  75. except AttributeError:
  76. exc_name = str(self.expected)
  77. first_matching = None
  78. for m in self.warnings:
  79. w = m.message
  80. if not isinstance(w, self.expected):
  81. continue
  82. if first_matching is None:
  83. first_matching = w
  84. if (self.expected_regex is not None and
  85. not self.expected_regex.search(str(w))):
  86. continue
  87. # store warning for later retrieval
  88. self.warning = w
  89. self.filename = m.filename
  90. self.lineno = m.lineno
  91. return
  92. # Now we simply try to choose a helpful failure message
  93. if first_matching is not None:
  94. raise self.failureException('%r does not match %r' %
  95. (self.expected_regex.pattern, str(first_matching)))
  96. if self.obj_name:
  97. raise self.failureException("%s not triggered by %s"
  98. % (exc_name, self.obj_name))
  99. else:
  100. raise self.failureException("%s not triggered"
  101. % exc_name)
  102. class Case(unittest.TestCase):
  103. def assertWarns(self, expected_warning):
  104. return _AssertWarnsContext(expected_warning, self, None)
  105. def assertWarnsRegex(self, expected_warning, expected_regex):
  106. return _AssertWarnsContext(expected_warning, self,
  107. None, expected_regex)
  108. def assertDictContainsSubset(self, expected, actual, msg=None):
  109. missing, mismatched = [], []
  110. for key, value in expected.iteritems():
  111. if key not in actual:
  112. missing.append(key)
  113. elif value != actual[key]:
  114. mismatched.append("%s, expected: %s, actual: %s" % (
  115. safe_repr(key), safe_repr(value),
  116. safe_repr(actual[key])))
  117. if not (missing or mismatched):
  118. return
  119. standard_msg = ""
  120. if missing:
  121. standard_msg = "Missing: %s" % ','.join(map(safe_repr, missing))
  122. if mismatched:
  123. if standard_msg:
  124. standard_msg += "; "
  125. standard_msg += "Mismatched values: %s" % (
  126. ','.join(mismatched))
  127. self.fail(self._formatMessage(msg, standard_msg))
  128. def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  129. try:
  130. expected = sorted(expected_seq)
  131. actual = sorted(actual_seq)
  132. except TypeError:
  133. # Unsortable items (example: set(), complex(), ...)
  134. expected = list(expected_seq)
  135. actual = list(actual_seq)
  136. missing, unexpected = unorderable_list_difference(
  137. expected, actual)
  138. else:
  139. return self.assertSequenceEqual(expected, actual, msg=msg)
  140. errors = []
  141. if missing:
  142. errors.append('Expected, but missing:\n %s' % (
  143. safe_repr(missing)))
  144. if unexpected:
  145. errors.append('Unexpected, but present:\n %s' % (
  146. safe_repr(unexpected)))
  147. if errors:
  148. standardMsg = '\n'.join(errors)
  149. self.fail(self._formatMessage(msg, standardMsg))
  150. class AppCase(Case):
  151. def setUp(self):
  152. from ..app import current_app
  153. from ..backends.cache import CacheBackend, DummyClient
  154. app = self.app = self._current_app = current_app()
  155. if isinstance(app.backend, CacheBackend):
  156. if isinstance(app.backend.client, DummyClient):
  157. app.backend.client.cache.clear()
  158. app.backend._cache.clear()
  159. self.setup()
  160. def tearDown(self):
  161. self.teardown()
  162. self._current_app.set_current()
  163. def setup(self):
  164. pass
  165. def teardown(self):
  166. pass
  167. def get_handlers(logger):
  168. if isinstance(logger, LoggerAdapter):
  169. return logger.logger.handlers
  170. return logger.handlers
  171. def set_handlers(logger, new_handlers):
  172. if isinstance(logger, LoggerAdapter):
  173. logger.logger.handlers = new_handlers
  174. logger.handlers = new_handlers
  175. @contextmanager
  176. def wrap_logger(logger, loglevel=logging.ERROR):
  177. old_handlers = get_handlers(logger)
  178. sio = WhateverIO()
  179. siohandler = logging.StreamHandler(sio)
  180. set_handlers(logger, [siohandler])
  181. yield sio
  182. set_handlers(logger, old_handlers)
  183. @contextmanager
  184. def eager_tasks():
  185. app = app_or_default()
  186. prev = app.conf.CELERY_ALWAYS_EAGER
  187. app.conf.CELERY_ALWAYS_EAGER = True
  188. yield True
  189. app.conf.CELERY_ALWAYS_EAGER = prev
  190. def with_eager_tasks(fun):
  191. @wraps(fun)
  192. def _inner(*args, **kwargs):
  193. app = app_or_default()
  194. prev = app.conf.CELERY_ALWAYS_EAGER
  195. app.conf.CELERY_ALWAYS_EAGER = True
  196. try:
  197. return fun(*args, **kwargs)
  198. finally:
  199. app.conf.CELERY_ALWAYS_EAGER = prev
  200. def with_environ(env_name, env_value):
  201. def _envpatched(fun):
  202. @wraps(fun)
  203. def _patch_environ(*args, **kwargs):
  204. prev_val = os.environ.get(env_name)
  205. os.environ[env_name] = env_value
  206. try:
  207. return fun(*args, **kwargs)
  208. finally:
  209. if prev_val is not None:
  210. os.environ[env_name] = prev_val
  211. return _patch_environ
  212. return _envpatched
  213. def sleepdeprived(module=time):
  214. def _sleepdeprived(fun):
  215. @wraps(fun)
  216. def __sleepdeprived(*args, **kwargs):
  217. old_sleep = module.sleep
  218. module.sleep = noop
  219. try:
  220. return fun(*args, **kwargs)
  221. finally:
  222. module.sleep = old_sleep
  223. return __sleepdeprived
  224. return _sleepdeprived
  225. def skip_if_environ(env_var_name):
  226. def _wrap_test(fun):
  227. @wraps(fun)
  228. def _skips_if_environ(*args, **kwargs):
  229. if os.environ.get(env_var_name):
  230. raise SkipTest("SKIP %s: %s set\n" % (
  231. fun.__name__, env_var_name))
  232. return fun(*args, **kwargs)
  233. return _skips_if_environ
  234. return _wrap_test
  235. def skip_if_quick(fun):
  236. return skip_if_environ("QUICKTEST")(fun)
  237. def _skip_test(reason, sign):
  238. def _wrap_test(fun):
  239. @wraps(fun)
  240. def _skipped_test(*args, **kwargs):
  241. raise SkipTest("%s: %s" % (sign, reason))
  242. return _skipped_test
  243. return _wrap_test
  244. def todo(reason):
  245. """TODO test decorator."""
  246. return _skip_test(reason, "TODO")
  247. def skip(reason):
  248. """Skip test decorator."""
  249. return _skip_test(reason, "SKIP")
  250. def skip_if(predicate, reason):
  251. """Skip test if predicate is :const:`True`."""
  252. def _inner(fun):
  253. return predicate and skip(reason)(fun) or fun
  254. return _inner
  255. def skip_unless(predicate, reason):
  256. """Skip test if predicate is :const:`False`."""
  257. return skip_if(not predicate, reason)
  258. # Taken from
  259. # http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py
  260. @contextmanager
  261. def mask_modules(*modnames):
  262. """Ban some modules from being importable inside the context
  263. For example:
  264. >>> with missing_modules("sys"):
  265. ... try:
  266. ... import sys
  267. ... except ImportError:
  268. ... print "sys not found"
  269. sys not found
  270. >>> import sys
  271. >>> sys.version
  272. (2, 5, 2, 'final', 0)
  273. """
  274. realimport = builtins.__import__
  275. def myimp(name, *args, **kwargs):
  276. if name in modnames:
  277. raise ImportError("No module named %s" % name)
  278. else:
  279. return realimport(name, *args, **kwargs)
  280. builtins.__import__ = myimp
  281. yield True
  282. builtins.__import__ = realimport
  283. @contextmanager
  284. def override_stdouts():
  285. """Override `sys.stdout` and `sys.stderr` with `WhateverIO`."""
  286. prev_out, prev_err = sys.stdout, sys.stderr
  287. mystdout, mystderr = WhateverIO(), WhateverIO()
  288. sys.stdout = sys.__stdout__ = mystdout
  289. sys.stderr = sys.__stderr__ = mystderr
  290. yield mystdout, mystderr
  291. sys.stdout = sys.__stdout__ = prev_out
  292. sys.stderr = sys.__stderr__ = prev_err
  293. def patch(module, name, mocked):
  294. module = importlib.import_module(module)
  295. def _patch(fun):
  296. @wraps(fun)
  297. def __patched(*args, **kwargs):
  298. prev = getattr(module, name)
  299. setattr(module, name, mocked)
  300. try:
  301. return fun(*args, **kwargs)
  302. finally:
  303. setattr(module, name, prev)
  304. return __patched
  305. return _patch
  306. @contextmanager
  307. def platform_pyimp(replace=None):
  308. import platform
  309. has_prev = hasattr(platform, "python_implementation")
  310. prev = getattr(platform, "python_implementation", None)
  311. if replace:
  312. platform.python_implementation = replace
  313. else:
  314. try:
  315. delattr(platform, "python_implementation")
  316. except AttributeError:
  317. pass
  318. yield
  319. if prev is not None:
  320. platform.python_implementation = prev
  321. if not has_prev:
  322. try:
  323. delattr(platform, "python_implementation")
  324. except AttributeError:
  325. pass
  326. @contextmanager
  327. def sys_platform(value):
  328. prev, sys.platform = sys.platform, value
  329. yield
  330. sys.platform = prev
  331. @contextmanager
  332. def pypy_version(value=None):
  333. has_prev = hasattr(sys, "pypy_version_info")
  334. prev = getattr(sys, "pypy_version_info", None)
  335. if value:
  336. sys.pypy_version_info = value
  337. else:
  338. try:
  339. delattr(sys, "pypy_version_info")
  340. except AttributeError:
  341. pass
  342. yield
  343. if prev is not None:
  344. sys.pypy_version_info = prev
  345. if not has_prev:
  346. try:
  347. delattr(sys, "pypy_version_info")
  348. except AttributeError:
  349. pass
  350. @contextmanager
  351. def reset_modules(*modules):
  352. prev = dict((k, sys.modules.pop(k)) for k in modules if k in sys.modules)
  353. yield
  354. sys.modules.update(prev)
  355. @contextmanager
  356. def patch_modules(*modules):
  357. from types import ModuleType
  358. prev = {}
  359. for mod in modules:
  360. prev[mod], sys.modules[mod] = sys.modules[mod], ModuleType(mod)
  361. yield
  362. for name, mod in prev.iteritems():
  363. if mod is None:
  364. sys.modules.pop(name, None)
  365. else:
  366. sys.modules[name] = mod