PageRenderTime 54ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 1ms

/astropy/tests/helper.py

https://gitlab.com/Rockyspade/astropy
Python | 668 lines | 613 code | 16 blank | 39 comment | 19 complexity | 08c14eaab5af3e604c6a8cdef69ca07a MD5 | raw file
  1. # Licensed under a 3-clause BSD style license - see LICENSE.rst
  2. """
  3. This module provides the tools used to internally run the astropy test suite
  4. from the installed astropy. It makes use of the `pytest` testing framework.
  5. """
  6. from __future__ import (absolute_import, division, print_function,
  7. unicode_literals)
  8. from ..extern import six
  9. from ..extern.six.moves import cPickle as pickle
  10. import errno
  11. import shlex
  12. import sys
  13. import base64
  14. import zlib
  15. import functools
  16. import multiprocessing
  17. import os
  18. import shutil
  19. import tempfile
  20. import types
  21. import warnings
  22. try:
  23. # Import pkg_resources to prevent it from issuing warnings upon being
  24. # imported from within py.test. See
  25. # https://github.com/astropy/astropy/pull/537 for a detailed explanation.
  26. import pkg_resources
  27. except ImportError:
  28. pass
  29. from .. import test
  30. from ..utils.exceptions import (AstropyWarning,
  31. AstropyDeprecationWarning,
  32. AstropyPendingDeprecationWarning)
  33. from ..config import configuration
  34. if os.environ.get('ASTROPY_USE_SYSTEM_PYTEST') or '_pytest' in sys.modules:
  35. import pytest
  36. else:
  37. from ..extern import pytest as extern_pytest
  38. if six.PY3:
  39. exec("def do_exec_def(co, loc): exec(co, loc)\n")
  40. extern_pytest.do_exec = do_exec_def
  41. unpacked_sources = extern_pytest.sources.encode("ascii")
  42. unpacked_sources = pickle.loads(
  43. zlib.decompress(base64.decodebytes(unpacked_sources)), encoding='utf-8')
  44. elif six.PY2:
  45. exec("def do_exec_def(co, loc): exec co in loc\n")
  46. extern_pytest.do_exec = do_exec_def
  47. unpacked_sources = pickle.loads(
  48. zlib.decompress(base64.decodestring(extern_pytest.sources)))
  49. importer = extern_pytest.DictImporter(unpacked_sources)
  50. sys.meta_path.insert(0, importer)
  51. pytest = importer.load_module(str('pytest'))
  52. # Monkey-patch py.test to work around issue #811
  53. # https://github.com/astropy/astropy/issues/811
  54. from _pytest.assertion import rewrite as _rewrite
  55. _orig_write_pyc = _rewrite._write_pyc
  56. def _write_pyc_wrapper(*args):
  57. """Wraps the internal _write_pyc method in py.test to recognize
  58. PermissionErrors and just stop trying to cache its generated pyc files if
  59. it can't write them to the __pycache__ directory.
  60. When py.test scans for test modules, it actually rewrites the bytecode
  61. of each test module it discovers--this is how it manages to add extra
  62. instrumentation to the assert builtin. Normally it caches these
  63. rewritten bytecode files--``_write_pyc()`` is just a function that handles
  64. writing the rewritten pyc file to the cache. If it returns ``False`` for
  65. any reason py.test will stop trying to cache the files altogether. The
  66. original function catches some cases, but it has a long-standing bug of
  67. not catching permission errors on the ``__pycache__`` directory in Python
  68. 3. Hence this patch.
  69. """
  70. try:
  71. return _orig_write_pyc(*args)
  72. except IOError as e:
  73. if e.errno == errno.EACCES:
  74. return False
  75. _rewrite._write_pyc = _write_pyc_wrapper
  76. # pytest marker to mark tests which get data from the web
  77. remote_data = pytest.mark.remote_data
  78. class TestRunner(object):
  79. def __init__(self, base_path):
  80. self.base_path = base_path
  81. def run_tests(self, package=None, test_path=None, args=None, plugins=None,
  82. verbose=False, pastebin=None, remote_data=False, pep8=False,
  83. pdb=False, coverage=False, open_files=False, parallel=0,
  84. docs_path=None, skip_docs=False, repeat=None):
  85. """
  86. The docstring for this method lives in astropy/__init__.py:test
  87. """
  88. if coverage:
  89. warnings.warn(
  90. "The coverage option is ignored on run_tests, since it "
  91. "can not be made to work in that context. Use "
  92. "'python setup.py test --coverage' instead.",
  93. AstropyWarning)
  94. all_args = []
  95. if package is None:
  96. package_path = self.base_path
  97. else:
  98. package_path = os.path.join(self.base_path,
  99. package.replace('.', os.path.sep))
  100. if not os.path.isdir(package_path):
  101. raise ValueError('Package not found: {0}'.format(package))
  102. if docs_path is not None and not skip_docs:
  103. if package is not None:
  104. docs_path = os.path.join(
  105. docs_path, package.replace('.', os.path.sep))
  106. if not os.path.exists(docs_path):
  107. warnings.warn(
  108. "Can not test .rst docs, since docs path "
  109. "({0}) does not exist.".format(docs_path))
  110. docs_path = None
  111. if test_path:
  112. base, ext = os.path.splitext(test_path)
  113. if ext in ('.rst', ''):
  114. if docs_path is None:
  115. # This shouldn't happen from "python setup.py test"
  116. raise ValueError(
  117. "Can not test .rst files without a docs_path "
  118. "specified.")
  119. abs_docs_path = os.path.abspath(docs_path)
  120. abs_test_path = os.path.abspath(
  121. os.path.join(abs_docs_path, os.pardir, test_path))
  122. common = os.path.commonprefix((abs_docs_path, abs_test_path))
  123. if os.path.exists(abs_test_path) and common == abs_docs_path:
  124. # Since we aren't testing any Python files within
  125. # the astropy tree, we need to forcibly load the
  126. # astropy py.test plugins, and then turn on the
  127. # doctest_rst plugin.
  128. all_args.extend(['-p', 'astropy.tests.pytest_plugins',
  129. '--doctest-rst'])
  130. test_path = abs_test_path
  131. if not (os.path.isdir(test_path) or ext in ('.py', '.rst')):
  132. raise ValueError("Test path must be a directory or a path to "
  133. "a .py or .rst file")
  134. all_args.append(test_path)
  135. else:
  136. all_args.append(package_path)
  137. if docs_path is not None and not skip_docs:
  138. all_args.extend([docs_path, '--doctest-rst'])
  139. # add any additional args entered by the user
  140. if args is not None:
  141. all_args.extend(
  142. shlex.split(args, posix=not sys.platform.startswith('win')))
  143. # add verbosity flag
  144. if verbose:
  145. all_args.append('-v')
  146. # turn on pastebin output
  147. if pastebin is not None:
  148. if pastebin in ['failed', 'all']:
  149. all_args.append('--pastebin={0}'.format(pastebin))
  150. else:
  151. raise ValueError("pastebin should be 'failed' or 'all'")
  152. # run @remote_data tests
  153. if remote_data:
  154. all_args.append('--remote-data')
  155. if pep8:
  156. try:
  157. import pytest_pep8
  158. except ImportError:
  159. raise ImportError('PEP8 checking requires pytest-pep8 plugin: '
  160. 'http://pypi.python.org/pypi/pytest-pep8')
  161. else:
  162. all_args.extend(['--pep8', '-k', 'pep8'])
  163. # activate post-mortem PDB for failing tests
  164. if pdb:
  165. all_args.append('--pdb')
  166. # check for opened files after each test
  167. if open_files:
  168. if parallel != 0:
  169. raise SystemError(
  170. "open file detection may not be used in conjunction with "
  171. "parallel testing.")
  172. try:
  173. import psutil
  174. except ImportError:
  175. raise SystemError(
  176. "open file detection requested, but psutil package "
  177. "is not installed.")
  178. all_args.append('--open-files')
  179. print("Checking for unclosed files")
  180. if parallel != 0:
  181. try:
  182. import xdist
  183. except ImportError:
  184. raise ImportError(
  185. 'Parallel testing requires the pytest-xdist plugin '
  186. 'https://pypi.python.org/pypi/pytest-xdist')
  187. try:
  188. parallel = int(parallel)
  189. except ValueError:
  190. raise ValueError(
  191. "parallel must be an int, got {0}".format(parallel))
  192. if parallel < 0:
  193. parallel = multiprocessing.cpu_count()
  194. all_args.extend(['-n', six.text_type(parallel)])
  195. if repeat:
  196. all_args.append('--repeat={0}'.format(repeat))
  197. if six.PY2:
  198. all_args = [x.encode('utf-8') for x in all_args]
  199. # override the config locations to not make a new directory nor use
  200. # existing cache or config
  201. xdg_config_home = os.environ.get('XDG_CONFIG_HOME')
  202. xdg_cache_home = os.environ.get('XDG_CACHE_HOME')
  203. astropy_config = tempfile.mkdtemp('astropy_config')
  204. astropy_cache = tempfile.mkdtemp('astropy_cache')
  205. os.environ[str('XDG_CONFIG_HOME')] = str(astropy_config)
  206. os.environ[str('XDG_CACHE_HOME')] = str(astropy_cache)
  207. os.mkdir(os.path.join(os.environ['XDG_CONFIG_HOME'], 'astropy'))
  208. os.mkdir(os.path.join(os.environ['XDG_CACHE_HOME'], 'astropy'))
  209. # To fully force configuration reloading from a different file (in this
  210. # case our default one in a temp directory), clear the config object
  211. # cache.
  212. configuration._cfgobjs.clear()
  213. # This prevents cyclical import problems that make it
  214. # impossible to test packages that define Table types on their
  215. # own.
  216. from ..table import Table
  217. try:
  218. result = pytest.main(args=all_args, plugins=plugins)
  219. finally:
  220. shutil.rmtree(os.environ['XDG_CONFIG_HOME'])
  221. shutil.rmtree(os.environ['XDG_CACHE_HOME'])
  222. if xdg_config_home is not None:
  223. os.environ[str('XDG_CONFIG_HOME')] = xdg_config_home
  224. else:
  225. del os.environ['XDG_CONFIG_HOME']
  226. if xdg_cache_home is not None:
  227. os.environ[str('XDG_CACHE_HOME')] = xdg_cache_home
  228. else:
  229. del os.environ['XDG_CACHE_HOME']
  230. configuration._cfgobjs.clear()
  231. return result
  232. run_tests.__doc__ = test.__doc__
  233. # This is for Python 2.x and 3.x compatibility. distutils expects
  234. # options to all be byte strings on Python 2 and Unicode strings on
  235. # Python 3.
  236. def _fix_user_options(options):
  237. def to_str_or_none(x):
  238. if x is None:
  239. return None
  240. return str(x)
  241. return [tuple(to_str_or_none(x) for x in y) for y in options]
  242. def _save_coverage(cov, result, rootdir, testing_path):
  243. """
  244. This method is called after the tests have been run in coverage mode
  245. to cleanup and then save the coverage data and report.
  246. """
  247. from ..utils.console import color_print
  248. if result != 0:
  249. return
  250. # The coverage report includes the full path to the temporary
  251. # directory, so we replace all the paths with the true source
  252. # path. This means that the coverage line-by-line report will only
  253. # be correct for Python 2 code (since the Python 3 code will be
  254. # different in the build directory from the source directory as
  255. # long as 2to3 is needed). Therefore we only do this fix for
  256. # Python 2.x.
  257. if six.PY2:
  258. d = cov.data
  259. cov._harvest_data()
  260. for key in d.lines.keys():
  261. new_path = os.path.relpath(
  262. os.path.realpath(key),
  263. os.path.realpath(testing_path))
  264. new_path = os.path.abspath(
  265. os.path.join(rootdir, new_path))
  266. d.lines[new_path] = d.lines.pop(key)
  267. color_print('Saving coverage data in .coverage...', 'green')
  268. cov.save()
  269. color_print('Saving HTML coverage report in htmlcov...', 'green')
  270. cov.html_report(directory=os.path.join(rootdir, 'htmlcov'))
  271. class raises(object):
  272. """
  273. A decorator to mark that a test should raise a given exception.
  274. Use as follows::
  275. @raises(ZeroDivisionError)
  276. def test_foo():
  277. x = 1/0
  278. This can also be used a context manager, in which case it is just an alias
  279. for the `pytest.raises` context manager (because the two have the same name
  280. this help avoid confusion by being flexible).
  281. """
  282. # pep-8 naming exception -- this is a decorator class
  283. def __init__(self, exc):
  284. self._exc = exc
  285. self._ctx = None
  286. def __call__(self, func):
  287. @functools.wraps(func)
  288. def run_raises_test(*args, **kwargs):
  289. pytest.raises(self._exc, func, *args, **kwargs)
  290. return run_raises_test
  291. def __enter__(self):
  292. self._ctx = pytest.raises(self._exc)
  293. return self._ctx.__enter__()
  294. def __exit__(self, *exc_info):
  295. return self._ctx.__exit__(*exc_info)
  296. _deprecations_as_exceptions = False
  297. _include_astropy_deprecations = True
  298. def enable_deprecations_as_exceptions(include_astropy_deprecations=True):
  299. """
  300. Turn on the feature that turns deprecations into exceptions.
  301. """
  302. global _deprecations_as_exceptions
  303. _deprecations_as_exceptions = True
  304. global _include_astropy_deprecations
  305. _include_astropy_deprecations = include_astropy_deprecations
  306. def treat_deprecations_as_exceptions():
  307. """
  308. Turn all DeprecationWarnings (which indicate deprecated uses of
  309. Python itself or Numpy, but not within Astropy, where we use our
  310. own deprecation warning class) into exceptions so that we find
  311. out about them early.
  312. This completely resets the warning filters and any "already seen"
  313. warning state.
  314. """
  315. if not _deprecations_as_exceptions:
  316. return
  317. # First, totally reset the warning state
  318. for module in list(six.itervalues(sys.modules)):
  319. # We don't want to deal with six.MovedModules, only "real"
  320. # modules.
  321. if (isinstance(module, types.ModuleType) and
  322. hasattr(module, '__warningregistry__')):
  323. del module.__warningregistry__
  324. warnings.resetwarnings()
  325. # Hide the next couple of DeprecationWarnings
  326. warnings.simplefilter('ignore', DeprecationWarning)
  327. # Here's the wrinkle: a couple of our third-party dependencies
  328. # (py.test and scipy) are still using deprecated features
  329. # themselves, and we'd like to ignore those. Fortunately, those
  330. # show up only at import time, so if we import those things *now*,
  331. # before we turn the warnings into exceptions, we're golden.
  332. try:
  333. # A deprecated stdlib module used by py.test
  334. import compiler
  335. except ImportError:
  336. pass
  337. try:
  338. import scipy
  339. except ImportError:
  340. pass
  341. # Now, start over again with the warning filters
  342. warnings.resetwarnings()
  343. # Now, turn DeprecationWarnings into exceptions
  344. warnings.filterwarnings("error", ".*", DeprecationWarning)
  345. # Only turn astropy deprecation warnings into exceptions if requested
  346. if _include_astropy_deprecations:
  347. warnings.filterwarnings("error", ".*", AstropyDeprecationWarning)
  348. warnings.filterwarnings("error", ".*", AstropyPendingDeprecationWarning)
  349. if sys.version_info[:2] == (2, 6):
  350. # py.test's warning.showwarning does not include the line argument
  351. # on Python 2.6, so we need to explicitly ignore this warning.
  352. warnings.filterwarnings(
  353. "always",
  354. r"functions overriding warnings\.showwarning\(\) must support "
  355. r"the 'line' argument",
  356. DeprecationWarning)
  357. if sys.version_info[:2] >= (3, 4):
  358. # py.test reads files with the 'U' flag, which is now
  359. # deprecated in Python 3.4.
  360. warnings.filterwarnings(
  361. "always",
  362. r"'U' mode is deprecated",
  363. DeprecationWarning)
  364. # BeautifulSoup4 triggers a DeprecationWarning in stdlib's
  365. # html module.x
  366. warnings.filterwarnings(
  367. "always",
  368. "The strict argument and mode are deprecated.",
  369. DeprecationWarning)
  370. warnings.filterwarnings(
  371. "always",
  372. "The value of convert_charrefs will become True in 3.5. "
  373. "You are encouraged to set the value explicitly.",
  374. DeprecationWarning)
  375. class catch_warnings(warnings.catch_warnings):
  376. """
  377. A high-powered version of warnings.catch_warnings to use for testing
  378. and to make sure that there is no dependence on the order in which
  379. the tests are run.
  380. This completely blitzes any memory of any warnings that have
  381. appeared before so that all warnings will be caught and displayed.
  382. *args is a set of warning classes to collect. If no arguments are
  383. provided, all warnings are collected.
  384. Use as follows::
  385. with catch_warnings(MyCustomWarning) as w:
  386. do.something.bad()
  387. assert len(w) > 0
  388. """
  389. def __init__(self, *classes):
  390. super(catch_warnings, self).__init__(record=True)
  391. self.classes = classes
  392. def __enter__(self):
  393. warning_list = super(catch_warnings, self).__enter__()
  394. treat_deprecations_as_exceptions()
  395. if len(self.classes) == 0:
  396. warnings.simplefilter('always')
  397. else:
  398. warnings.simplefilter('ignore')
  399. for cls in self.classes:
  400. warnings.simplefilter('always', cls)
  401. return warning_list
  402. def __exit__(self, type, value, traceback):
  403. treat_deprecations_as_exceptions()
  404. def assert_follows_unicode_guidelines(
  405. x, roundtrip=None):
  406. """
  407. Test that an object follows our Unicode policy. See
  408. "Unicode guidelines" in the coding guidelines.
  409. Parameters
  410. ----------
  411. x : object
  412. The instance to test
  413. roundtrip : module, optional
  414. When provided, this namespace will be used to evaluate
  415. ``repr(x)`` and ensure that it roundtrips. It will also
  416. ensure that ``__bytes__(x)`` and ``__unicode__(x)`` roundtrip.
  417. If not provided, no roundtrip testing will be performed.
  418. """
  419. from .. import conf
  420. from ..extern import six
  421. with conf.set_temp('unicode_output', False):
  422. bytes_x = bytes(x)
  423. unicode_x = six.text_type(x)
  424. repr_x = repr(x)
  425. assert isinstance(bytes_x, bytes)
  426. bytes_x.decode('ascii')
  427. assert isinstance(unicode_x, six.text_type)
  428. unicode_x.encode('ascii')
  429. assert isinstance(repr_x, six.string_types)
  430. if isinstance(repr_x, bytes):
  431. repr_x.decode('ascii')
  432. else:
  433. repr_x.encode('ascii')
  434. if roundtrip is not None:
  435. assert x.__class__(bytes_x) == x
  436. assert x.__class__(unicode_x) == x
  437. assert eval(repr_x, roundtrip) == x
  438. with conf.set_temp('unicode_output', True):
  439. bytes_x = bytes(x)
  440. unicode_x = six.text_type(x)
  441. repr_x = repr(x)
  442. assert isinstance(bytes_x, bytes)
  443. bytes_x.decode('ascii')
  444. assert isinstance(unicode_x, six.text_type)
  445. assert isinstance(repr_x, six.string_types)
  446. if isinstance(repr_x, bytes):
  447. repr_x.decode('ascii')
  448. else:
  449. repr_x.encode('ascii')
  450. if roundtrip is not None:
  451. assert x.__class__(bytes_x) == x
  452. assert x.__class__(unicode_x) == x
  453. assert eval(repr_x, roundtrip) == x
  454. @pytest.fixture(params=[0, 1, -1])
  455. def pickle_protocol(request):
  456. """
  457. Fixture to run all the tests for protocols 0 and 1, and -1 (most advanced).
  458. (Originally from astropy.table.tests.test_pickle)
  459. """
  460. return request.param
  461. def generic_recursive_equality_test(a, b, class_history):
  462. """
  463. Check if the attributes of a and b are equal. Then,
  464. check if the attributes of the attributes are equal.
  465. """
  466. dict_a = a.__dict__
  467. dict_b = b.__dict__
  468. for key in dict_a:
  469. assert key in dict_b,\
  470. "Did not pickle {0}".format(key)
  471. if hasattr(dict_a[key], '__eq__'):
  472. eq = (dict_a[key] == dict_b[key])
  473. if '__iter__' in dir(eq):
  474. eq = (False not in eq)
  475. assert eq, "Value of {0} changed by pickling".format(key)
  476. if hasattr(dict_a[key], '__dict__'):
  477. if dict_a[key].__class__ in class_history:
  478. #attempt to prevent infinite recursion
  479. pass
  480. else:
  481. new_class_history = [dict_a[key].__class__]
  482. new_class_history.extend(class_history)
  483. generic_recursive_equality_test(dict_a[key],
  484. dict_b[key],
  485. new_class_history)
  486. def check_pickling_recovery(original, protocol):
  487. """
  488. Try to pickle an object. If successful, make sure
  489. the object's attributes survived pickling and unpickling.
  490. """
  491. f = pickle.dumps(original, protocol=protocol)
  492. unpickled = pickle.loads(f)
  493. class_history = [original.__class__]
  494. generic_recursive_equality_test(original, unpickled,
  495. class_history)
  496. def assert_quantity_allclose(actual, desired, rtol=1.e-7, atol=None,
  497. **kwargs):
  498. """
  499. Raise an assertion if two objects are not equal up to desired tolerance.
  500. This is a :class:`~astropy.units.Quantity`-aware version of
  501. :func:`numpy.testing.assert_allclose`.
  502. """
  503. import numpy as np
  504. np.testing.assert_allclose(*_unquantify_allclose_arguments(actual, desired,
  505. rtol, atol),
  506. **kwargs)
  507. def quantity_allclose(a, b, rtol=1.e-5, atol=None, **kwargs):
  508. """
  509. Returns True if two arrays are element-wise equal within a tolerance.
  510. This is a :class:`~astropy.units.Quantity`-aware version of
  511. :func:`numpy.allclose`.
  512. """
  513. import numpy as np
  514. return np.allclose(*_unquantify_allclose_arguments(a, b, rtol, atol),
  515. **kwargs)
  516. def _unquantify_allclose_arguments(actual, desired, rtol, atol):
  517. from .. import units as u
  518. actual = u.Quantity(actual, subok=True, copy=False)
  519. desired = u.Quantity(desired, subok=True, copy=False)
  520. try:
  521. desired = desired.to(actual.unit)
  522. except u.UnitsError:
  523. raise u.UnitsError("Units for 'desired' ({0}) and 'actual' ({1}) "
  524. "are not convertible"
  525. .format(desired.unit, actual.unit))
  526. if atol is None:
  527. # by default, we assume an absolute tolerance of 0
  528. atol = u.Quantity(0)
  529. else:
  530. atol = u.Quantity(atol, subok=True, copy=False)
  531. try:
  532. atol = atol.to(actual.unit)
  533. except u.UnitsError:
  534. raise u.UnitsError("Units for 'atol' ({0}) and 'actual' ({1}) "
  535. "are not convertible"
  536. .format(atol.unit, actual.unit))
  537. rtol = u.Quantity(rtol, subok=True, copy=False)
  538. try:
  539. rtol = rtol.to(u.dimensionless_unscaled)
  540. except:
  541. raise u.UnitsError("`rtol` should be dimensionless")
  542. return actual.value, desired.value, rtol.value, atol.value