PageRenderTime 43ms CodeModel.GetById 12ms RepoModel.GetById 0ms app.codeStats 0ms

/django/test/utils.py

https://github.com/insane/django
Python | 419 lines | 329 code | 38 blank | 52 comment | 20 complexity | d55851d80046a502f97d438bddc6dc20 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. from contextlib import contextmanager
  2. import logging
  3. import re
  4. import sys
  5. import warnings
  6. from functools import wraps
  7. from xml.dom.minidom import parseString, Node
  8. from django.conf import settings, UserSettingsHolder
  9. from django.core import mail
  10. from django.core.signals import request_started
  11. from django.db import reset_queries
  12. from django.http import request
  13. from django.template import Template, loader, TemplateDoesNotExist
  14. from django.template.loaders import cached
  15. from django.test.signals import template_rendered, setting_changed
  16. from django.utils.encoding import force_str
  17. from django.utils import six
  18. from django.utils.translation import deactivate
  19. __all__ = (
  20. 'Approximate', 'ContextList', 'get_runner', 'override_settings',
  21. 'setup_test_environment', 'teardown_test_environment',
  22. )
  23. RESTORE_LOADERS_ATTR = '_original_template_source_loaders'
  24. class Approximate(object):
  25. def __init__(self, val, places=7):
  26. self.val = val
  27. self.places = places
  28. def __repr__(self):
  29. return repr(self.val)
  30. def __eq__(self, other):
  31. if self.val == other:
  32. return True
  33. return round(abs(self.val - other), self.places) == 0
  34. class ContextList(list):
  35. """A wrapper that provides direct key access to context items contained
  36. in a list of context objects.
  37. """
  38. def __getitem__(self, key):
  39. if isinstance(key, six.string_types):
  40. for subcontext in self:
  41. if key in subcontext:
  42. return subcontext[key]
  43. raise KeyError(key)
  44. else:
  45. return super(ContextList, self).__getitem__(key)
  46. def __contains__(self, key):
  47. try:
  48. self[key]
  49. except KeyError:
  50. return False
  51. return True
  52. def keys(self):
  53. """
  54. Flattened keys of subcontexts.
  55. """
  56. keys = set()
  57. for subcontext in self:
  58. for dict in subcontext:
  59. keys |= set(dict.keys())
  60. return keys
  61. def instrumented_test_render(self, context):
  62. """
  63. An instrumented Template render method, providing a signal
  64. that can be intercepted by the test system Client
  65. """
  66. template_rendered.send(sender=self, template=self, context=context)
  67. return self.nodelist.render(context)
  68. def setup_test_environment():
  69. """Perform any global pre-test setup. This involves:
  70. - Installing the instrumented test renderer
  71. - Set the email backend to the locmem email backend.
  72. - Setting the active locale to match the LANGUAGE_CODE setting.
  73. """
  74. Template._original_render = Template._render
  75. Template._render = instrumented_test_render
  76. # Storing previous values in the settings module itself is problematic.
  77. # Store them in arbitrary (but related) modules instead. See #20636.
  78. mail._original_email_backend = settings.EMAIL_BACKEND
  79. settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
  80. request._original_allowed_hosts = settings.ALLOWED_HOSTS
  81. settings.ALLOWED_HOSTS = ['*']
  82. mail.outbox = []
  83. deactivate()
  84. def teardown_test_environment():
  85. """Perform any global post-test teardown. This involves:
  86. - Restoring the original test renderer
  87. - Restoring the email sending functions
  88. """
  89. Template._render = Template._original_render
  90. del Template._original_render
  91. settings.EMAIL_BACKEND = mail._original_email_backend
  92. del mail._original_email_backend
  93. settings.ALLOWED_HOSTS = request._original_allowed_hosts
  94. del request._original_allowed_hosts
  95. del mail.outbox
  96. def get_runner(settings, test_runner_class=None):
  97. if not test_runner_class:
  98. test_runner_class = settings.TEST_RUNNER
  99. test_path = test_runner_class.split('.')
  100. # Allow for Python 2.5 relative paths
  101. if len(test_path) > 1:
  102. test_module_name = '.'.join(test_path[:-1])
  103. else:
  104. test_module_name = '.'
  105. test_module = __import__(test_module_name, {}, {}, force_str(test_path[-1]))
  106. test_runner = getattr(test_module, test_path[-1])
  107. return test_runner
  108. def setup_test_template_loader(templates_dict, use_cached_loader=False):
  109. """
  110. Changes Django to only find templates from within a dictionary (where each
  111. key is the template name and each value is the corresponding template
  112. content to return).
  113. Use meth:`restore_template_loaders` to restore the original loaders.
  114. """
  115. if hasattr(loader, RESTORE_LOADERS_ATTR):
  116. raise Exception("loader.%s already exists" % RESTORE_LOADERS_ATTR)
  117. def test_template_loader(template_name, template_dirs=None):
  118. "A custom template loader that loads templates from a dictionary."
  119. try:
  120. return (templates_dict[template_name], "test:%s" % template_name)
  121. except KeyError:
  122. raise TemplateDoesNotExist(template_name)
  123. if use_cached_loader:
  124. template_loader = cached.Loader(('test_template_loader',))
  125. template_loader._cached_loaders = (test_template_loader,)
  126. else:
  127. template_loader = test_template_loader
  128. setattr(loader, RESTORE_LOADERS_ATTR, loader.template_source_loaders)
  129. loader.template_source_loaders = (template_loader,)
  130. return template_loader
  131. def restore_template_loaders():
  132. """
  133. Restores the original template loaders after
  134. :meth:`setup_test_template_loader` has been run.
  135. """
  136. loader.template_source_loaders = getattr(loader, RESTORE_LOADERS_ATTR)
  137. delattr(loader, RESTORE_LOADERS_ATTR)
  138. class override_settings(object):
  139. """
  140. Acts as either a decorator, or a context manager. If it's a decorator it
  141. takes a function and returns a wrapped function. If it's a contextmanager
  142. it's used with the ``with`` statement. In either event entering/exiting
  143. are called before and after, respectively, the function/block is executed.
  144. """
  145. def __init__(self, **kwargs):
  146. self.options = kwargs
  147. def __enter__(self):
  148. self.enable()
  149. def __exit__(self, exc_type, exc_value, traceback):
  150. self.disable()
  151. def __call__(self, test_func):
  152. from django.test import SimpleTestCase
  153. if isinstance(test_func, type):
  154. if not issubclass(test_func, SimpleTestCase):
  155. raise Exception(
  156. "Only subclasses of Django SimpleTestCase can be decorated "
  157. "with override_settings")
  158. original_pre_setup = test_func._pre_setup
  159. original_post_teardown = test_func._post_teardown
  160. def _pre_setup(innerself):
  161. self.enable()
  162. original_pre_setup(innerself)
  163. def _post_teardown(innerself):
  164. original_post_teardown(innerself)
  165. self.disable()
  166. test_func._pre_setup = _pre_setup
  167. test_func._post_teardown = _post_teardown
  168. return test_func
  169. else:
  170. @wraps(test_func)
  171. def inner(*args, **kwargs):
  172. with self:
  173. return test_func(*args, **kwargs)
  174. return inner
  175. def enable(self):
  176. override = UserSettingsHolder(settings._wrapped)
  177. for key, new_value in self.options.items():
  178. setattr(override, key, new_value)
  179. self.wrapped = settings._wrapped
  180. settings._wrapped = override
  181. for key, new_value in self.options.items():
  182. setting_changed.send(sender=settings._wrapped.__class__,
  183. setting=key, value=new_value, enter=True)
  184. def disable(self):
  185. settings._wrapped = self.wrapped
  186. del self.wrapped
  187. for key in self.options:
  188. new_value = getattr(settings, key, None)
  189. setting_changed.send(sender=settings._wrapped.__class__,
  190. setting=key, value=new_value, enter=False)
  191. def compare_xml(want, got):
  192. """Tries to do a 'xml-comparison' of want and got. Plain string
  193. comparison doesn't always work because, for example, attribute
  194. ordering should not be important. Comment nodes are not considered in the
  195. comparison.
  196. Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
  197. """
  198. _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
  199. def norm_whitespace(v):
  200. return _norm_whitespace_re.sub(' ', v)
  201. def child_text(element):
  202. return ''.join([c.data for c in element.childNodes
  203. if c.nodeType == Node.TEXT_NODE])
  204. def children(element):
  205. return [c for c in element.childNodes
  206. if c.nodeType == Node.ELEMENT_NODE]
  207. def norm_child_text(element):
  208. return norm_whitespace(child_text(element))
  209. def attrs_dict(element):
  210. return dict(element.attributes.items())
  211. def check_element(want_element, got_element):
  212. if want_element.tagName != got_element.tagName:
  213. return False
  214. if norm_child_text(want_element) != norm_child_text(got_element):
  215. return False
  216. if attrs_dict(want_element) != attrs_dict(got_element):
  217. return False
  218. want_children = children(want_element)
  219. got_children = children(got_element)
  220. if len(want_children) != len(got_children):
  221. return False
  222. for want, got in zip(want_children, got_children):
  223. if not check_element(want, got):
  224. return False
  225. return True
  226. def first_node(document):
  227. for node in document.childNodes:
  228. if node.nodeType != Node.COMMENT_NODE:
  229. return node
  230. want, got = strip_quotes(want, got)
  231. want = want.replace('\\n','\n')
  232. got = got.replace('\\n','\n')
  233. # If the string is not a complete xml document, we may need to add a
  234. # root element. This allow us to compare fragments, like "<foo/><bar/>"
  235. if not want.startswith('<?xml'):
  236. wrapper = '<root>%s</root>'
  237. want = wrapper % want
  238. got = wrapper % got
  239. # Parse the want and got strings, and compare the parsings.
  240. want_root = first_node(parseString(want))
  241. got_root = first_node(parseString(got))
  242. return check_element(want_root, got_root)
  243. def strip_quotes(want, got):
  244. """
  245. Strip quotes of doctests output values:
  246. >>> strip_quotes("'foo'")
  247. "foo"
  248. >>> strip_quotes('"foo"')
  249. "foo"
  250. """
  251. def is_quoted_string(s):
  252. s = s.strip()
  253. return (len(s) >= 2
  254. and s[0] == s[-1]
  255. and s[0] in ('"', "'"))
  256. def is_quoted_unicode(s):
  257. s = s.strip()
  258. return (len(s) >= 3
  259. and s[0] == 'u'
  260. and s[1] == s[-1]
  261. and s[1] in ('"', "'"))
  262. if is_quoted_string(want) and is_quoted_string(got):
  263. want = want.strip()[1:-1]
  264. got = got.strip()[1:-1]
  265. elif is_quoted_unicode(want) and is_quoted_unicode(got):
  266. want = want.strip()[2:-1]
  267. got = got.strip()[2:-1]
  268. return want, got
  269. def str_prefix(s):
  270. return s % {'_': '' if six.PY3 else 'u'}
  271. class CaptureQueriesContext(object):
  272. """
  273. Context manager that captures queries executed by the specified connection.
  274. """
  275. def __init__(self, connection):
  276. self.connection = connection
  277. def __iter__(self):
  278. return iter(self.captured_queries)
  279. def __getitem__(self, index):
  280. return self.captured_queries[index]
  281. def __len__(self):
  282. return len(self.captured_queries)
  283. @property
  284. def captured_queries(self):
  285. return self.connection.queries[self.initial_queries:self.final_queries]
  286. def __enter__(self):
  287. self.use_debug_cursor = self.connection.use_debug_cursor
  288. self.connection.use_debug_cursor = True
  289. self.initial_queries = len(self.connection.queries)
  290. self.final_queries = None
  291. request_started.disconnect(reset_queries)
  292. return self
  293. def __exit__(self, exc_type, exc_value, traceback):
  294. self.connection.use_debug_cursor = self.use_debug_cursor
  295. request_started.connect(reset_queries)
  296. if exc_type is not None:
  297. return
  298. self.final_queries = len(self.connection.queries)
  299. class IgnoreDeprecationWarningsMixin(object):
  300. warning_classes = [DeprecationWarning]
  301. def setUp(self):
  302. super(IgnoreDeprecationWarningsMixin, self).setUp()
  303. self.catch_warnings = warnings.catch_warnings()
  304. self.catch_warnings.__enter__()
  305. for warning_class in self.warning_classes:
  306. warnings.filterwarnings("ignore", category=warning_class)
  307. def tearDown(self):
  308. self.catch_warnings.__exit__(*sys.exc_info())
  309. super(IgnoreDeprecationWarningsMixin, self).tearDown()
  310. class IgnorePendingDeprecationWarningsMixin(IgnoreDeprecationWarningsMixin):
  311. warning_classes = [PendingDeprecationWarning]
  312. class IgnoreAllDeprecationWarningsMixin(IgnoreDeprecationWarningsMixin):
  313. warning_classes = [PendingDeprecationWarning, DeprecationWarning]
  314. @contextmanager
  315. def patch_logger(logger_name, log_level):
  316. """
  317. Context manager that takes a named logger and the logging level
  318. and provides a simple mock-like list of messages received
  319. """
  320. calls = []
  321. def replacement(msg):
  322. calls.append(msg)
  323. logger = logging.getLogger(logger_name)
  324. orig = getattr(logger, log_level)
  325. setattr(logger, log_level, replacement)
  326. try:
  327. yield calls
  328. finally:
  329. setattr(logger, log_level, orig)