/pandas/util/testing.py
Python | 2951 lines | 2771 code | 46 blank | 134 comment | 59 complexity | 5035c8a995786d15bedccd1987a65a0f MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0
Large files are truncated, but you can click here to view the full file
- from __future__ import division
- from collections import Counter
- from contextlib import contextmanager
- from datetime import datetime
- from functools import wraps
- import locale
- import os
- import re
- from shutil import rmtree
- import string
- import subprocess
- import sys
- import tempfile
- import traceback
- import warnings
- import numpy as np
- from numpy.random import rand, randn
- from pandas._libs import testing as _testing
- import pandas.compat as compat
- from pandas.compat import (
- PY2, PY3, filter, httplib, lmap, lrange, lzip, map, raise_with_traceback,
- range, string_types, u, unichr, zip)
- from pandas.core.dtypes.common import (
- is_bool, is_categorical_dtype, is_datetime64_dtype, is_datetime64tz_dtype,
- is_datetimelike_v_numeric, is_datetimelike_v_object,
- is_extension_array_dtype, is_interval_dtype, is_list_like, is_number,
- is_period_dtype, is_sequence, is_timedelta64_dtype, needs_i8_conversion)
- from pandas.core.dtypes.missing import array_equivalent
- import pandas as pd
- from pandas import (
- Categorical, CategoricalIndex, DataFrame, DatetimeIndex, Index,
- IntervalIndex, MultiIndex, RangeIndex, Series, bdate_range)
- from pandas.core.algorithms import take_1d
- from pandas.core.arrays import (
- DatetimeArray, ExtensionArray, IntervalArray, PeriodArray, TimedeltaArray,
- period_array)
- import pandas.core.common as com
- from pandas.io.common import urlopen
- from pandas.io.formats.printing import pprint_thing
# Default sizes; presumably consumed by the data-generating helpers further
# down this module -- TODO confirm.
N, K = 30, 4

# NOTE(review): not referenced in the visible part of this module; presumably
# the default "raise on network error" setting for network-test helpers.
_RAISE_NETWORK_ERROR_DEFAULT = False

# set testing_mode
# Warning categories whose filters set_testing_mode / reset_testing_mode toggle.
_testing_mode_warnings = (DeprecationWarning, compat.ResourceWarning)
def set_testing_mode():
    """Install "always" warning filters when running in deprecate mode.

    The filter for the categories in ``_testing_mode_warnings`` is only
    installed when the ``PANDAS_TESTING_MODE`` environment variable contains
    the substring ``'deprecate'``; otherwise nothing happens.
    """
    mode = os.environ.get('PANDAS_TESTING_MODE', 'None')
    if 'deprecate' not in mode:
        return
    warnings.simplefilter('always', _testing_mode_warnings)
def reset_testing_mode():
    """Install "ignore" warning filters when running in deprecate mode.

    Counterpart of ``set_testing_mode``: when ``PANDAS_TESTING_MODE``
    contains ``'deprecate'``, the categories in ``_testing_mode_warnings``
    are set to be ignored; otherwise nothing happens.
    """
    mode = os.environ.get('PANDAS_TESTING_MODE', 'None')
    if 'deprecate' not in mode:
        return
    warnings.simplefilter('ignore', _testing_mode_warnings)
- set_testing_mode()
def reset_display_options():
    """
    Restore every ``display.*`` option to its default value.

    The options are reset silently via ``pd.reset_option``.
    """
    display_pattern = '^display.'
    pd.reset_option(display_pattern, silent=True)
def round_trip_pickle(obj, path=None):
    """
    Write `obj` to a temporary pickle file and load it back.

    Parameters
    ----------
    obj : pandas object
        The object to pickle and then re-read.
    path : str, default None
        Name (suffix) of the temporary pickle file. When None, a random
        ``__<10 random chars>__.pickle`` name is generated.

    Returns
    -------
    round_trip_pickled_object : pandas object
        The object obtained from ``pd.read_pickle`` after ``pd.to_pickle``.
    """
    if path is None:
        path = u('__{random_bytes}__.pickle'.format(random_bytes=rands(10)))

    with ensure_clean(path) as temp_path:
        pd.to_pickle(obj, temp_path)
        return pd.read_pickle(temp_path)
def round_trip_pathlib(writer, reader, path=None):
    """
    Write an object to a ``pathlib.Path`` and read it back.

    The test is skipped (via ``pytest.importorskip``) when ``pathlib`` is
    unavailable.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv ).
    reader : callable
        IO reading function (e.g. pd.read_csv ).
    path : str, default None
        Name of the temporary file; ``'___pathlib___'`` when None.

    Returns
    -------
    round_trip_object : pandas object
        Whatever `reader` returned for the written file.
    """
    import pytest
    path_cls = pytest.importorskip('pathlib').Path

    target = '___pathlib___' if path is None else path
    with ensure_clean(target) as temp_path:
        writer(path_cls(temp_path))
        return reader(path_cls(temp_path))
def round_trip_localpath(writer, reader, path=None):
    """
    Write an object to a ``py.path.local`` path and read it back.

    The test is skipped (via ``pytest.importorskip``) when ``py.path`` is
    unavailable.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv ).
    reader : callable
        IO reading function (e.g. pd.read_csv ).
    path : str, default None
        Name of the temporary file; ``'___localpath___'`` when None.

    Returns
    -------
    round_trip_object : pandas object
        Whatever `reader` returned for the written file.
    """
    import pytest
    local_path_cls = pytest.importorskip('py.path').local

    target = '___localpath___' if path is None else path
    with ensure_clean(target) as temp_path:
        writer(local_path_cls(temp_path))
        return reader(local_path_cls(temp_path))
@contextmanager
def decompress_file(path, compression):
    """
    Open a compressed file and yield a binary file object for reading it.

    Parameters
    ----------
    path : str
        The path where the file is read from.
    compression : {'gzip', 'bz2', 'zip', 'xz', None}
        Name of the decompression to use.

    Yields
    ------
    f : file object
        Closed on exit (together with the enclosing ZipFile for 'zip').

    Raises
    ------
    ValueError
        If `compression` is unrecognized, or a ZIP archive does not contain
        exactly one member.
    """
    zip_file = None
    if compression is None:
        f = open(path, 'rb')
    elif compression == 'gzip':
        import gzip
        f = gzip.open(path, 'rb')
    elif compression == 'bz2':
        import bz2
        f = bz2.BZ2File(path, 'rb')
    elif compression == 'xz':
        lzma = compat.import_lzma()
        f = lzma.LZMAFile(path, 'rb')
    elif compression == 'zip':
        import zipfile
        zip_file = zipfile.ZipFile(path)
        # Close the archive if we bail out before handing a member back;
        # previously the ValueError below leaked the open ZipFile.
        try:
            zip_names = zip_file.namelist()
            if len(zip_names) != 1:
                raise ValueError('ZIP file {} error. Only one file per ZIP.'
                                 .format(path))
            f = zip_file.open(zip_names.pop())
        except Exception:
            zip_file.close()
            raise
    else:
        msg = 'Unrecognized compression type: {}'.format(compression)
        raise ValueError(msg)

    try:
        yield f
    finally:
        f.close()
        if zip_file is not None:
            zip_file.close()
def write_to_compressed(compression, path, data, dest="test"):
    """
    Write `data` into a new compressed file at `path`.

    Parameters
    ----------
    compression : {'gzip', 'bz2', 'zip', 'xz'}
        The compression type to use.
    path : str
        The file path to write the data.
    data : str
        The data to write.
    dest : str, default "test"
        Name of the archive member to create (ZIP only).

    Raises
    ------
    ValueError : An invalid compression value was passed in.
    """
    if compression == "gzip":
        import gzip
        opener = gzip.GzipFile
    elif compression == "bz2":
        import bz2
        opener = bz2.BZ2File
    elif compression == "xz":
        lzma = compat.import_lzma()
        opener = lzma.LZMAFile
    elif compression == "zip":
        import zipfile
        opener = zipfile.ZipFile
    else:
        raise ValueError(
            "Unrecognized compression type: {}".format(compression))

    # ZipFile takes mode "w" and stores data under a member name; the
    # stream-style compressors take "wb" and a plain write().
    is_zip = compression == "zip"
    with opener(path, mode="w" if is_zip else "wb") as handle:
        if is_zip:
            handle.writestr(dest, data)
        else:
            handle.write(data)
def assert_almost_equal(left, right, check_dtype="equiv",
                        check_less_precise=False, **kwargs):
    """
    Check that the left and right objects are approximately equal.

    Index, Series and DataFrame inputs (decided by the type of `left`) are
    delegated to ``assert_index_equal``, ``assert_series_equal`` and
    ``assert_frame_equal`` with ``check_exact=False``. Everything else is
    compared by ``pandas._libs.testing.assert_almost_equal``.

    Parameters
    ----------
    left : object
    right : object
    check_dtype : bool / string {'equiv'}, default 'equiv'
        Check dtype if both a and b are the same type. If 'equiv' is passed in,
        then `RangeIndex` and `Int64Index` are also considered equivalent
        when doing type checking.
    check_less_precise : bool or int, default False
        Specify comparison precision. 5 digits (False) or 3 digits (True)
        after decimal points are compared. If int, then specify the number
        of digits to compare.

        When comparing two numbers, if the first number has magnitude less
        than 1e-5, we compare the two numbers directly and check whether
        they are equivalent within the specified precision. Otherwise, we
        compare the **ratio** of the second number to the first number and
        check whether it is equivalent to 1 within the specified precision.
    """
    if isinstance(left, pd.Index):
        return assert_index_equal(left, right, check_exact=False,
                                  exact=check_dtype,
                                  check_less_precise=check_less_precise,
                                  **kwargs)
    if isinstance(left, pd.Series):
        return assert_series_equal(left, right, check_exact=False,
                                   check_dtype=check_dtype,
                                   check_less_precise=check_less_precise,
                                   **kwargs)
    if isinstance(left, pd.DataFrame):
        return assert_frame_equal(left, right, check_exact=False,
                                  check_dtype=check_dtype,
                                  check_less_precise=check_less_precise,
                                  **kwargs)

    # Scalars and other sequences. Pairs of numbers (np.float64 vs float) and
    # pairs of bools (np.bool_ vs bool) are deliberately not class-checked.
    if check_dtype and not ((is_number(left) and is_number(right)) or
                            (is_bool(left) and is_bool(right))):
        if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
            label = "numpy array"
        else:
            label = "Input"
        assert_class_equal(left, right, obj=label)

    return _testing.assert_almost_equal(
        left, right,
        check_dtype=check_dtype,
        check_less_precise=check_less_precise,
        **kwargs)
- def _check_isinstance(left, right, cls):
- """
- Helper method for our assert_* methods that ensures that
- the two objects being compared have the right type before
- proceeding with the comparison.
- Parameters
- ----------
- left : The first object being compared.
- right : The second object being compared.
- cls : The class type to check against.
- Raises
- ------
- AssertionError : Either `left` or `right` is not an instance of `cls`.
- """
- err_msg = "{name} Expected type {exp_type}, found {act_type} instead"
- cls_name = cls.__name__
- if not isinstance(left, cls):
- raise AssertionError(err_msg.format(name=cls_name, exp_type=cls,
- act_type=type(left)))
- if not isinstance(right, cls):
- raise AssertionError(err_msg.format(name=cls_name, exp_type=cls,
- act_type=type(right)))
def assert_dict_equal(left, right, compare_keys=True):
    """Assert two dicts are equal via ``pandas._libs.testing``.

    Both arguments must be ``dict`` instances; `compare_keys` is forwarded
    unchanged to the compiled comparison.
    """
    _check_isinstance(left, right, dict)
    outcome = _testing.assert_dict_equal(left, right,
                                         compare_keys=compare_keys)
    return outcome
def randbool(size=(), p=0.5):
    """Return random booleans, each True with probability `p`.

    Draws uniform samples in [0, 1) with ``numpy.random.rand(*size)`` and
    compares them with `p`; with the default ``size=()`` a single scalar
    boolean is returned.
    """
    draws = rand(*size)
    return draws <= p
# Single-character pool for rands/rands_array: ASCII letters and digits.
RANDS_CHARS = np.array(
    list(string.ascii_letters + string.digits),
    dtype=(np.str_, 1),
)
# Single-character pool for randu/randu_array: the 26 code points starting at
# U+05D0 (1488, Hebrew letters) followed by the ASCII digits.
RANDU_CHARS = np.array(
    list(u("").join(unichr(code) for code in lrange(1488, 1488 + 26)) +
         string.digits),
    dtype=(np.unicode_, 1),
)
def rands_array(nchars, size, dtype='O'):
    """Generate an array of random strings drawn from ``RANDS_CHARS``.

    Each element has exactly `nchars` characters; the array has shape
    `size`. It is cast to `dtype` (object by default) unless `dtype` is
    None, in which case the fixed-width string array is returned.
    """
    total_chars = nchars * np.prod(size)
    picked = np.random.choice(RANDS_CHARS, size=total_chars)
    # Reinterpret consecutive runs of `nchars` single chars as one string.
    strings = picked.view((np.str_, nchars)).reshape(size)
    return strings if dtype is None else strings.astype(dtype)
def randu_array(nchars, size, dtype='O'):
    """Generate an array of random unicode strings from ``RANDU_CHARS``.

    Each element has exactly `nchars` characters; the array has shape
    `size`. It is cast to `dtype` (object by default) unless `dtype` is
    None, in which case the fixed-width unicode array is returned.
    """
    total_chars = nchars * np.prod(size)
    picked = np.random.choice(RANDU_CHARS, size=total_chars)
    # Reinterpret consecutive runs of `nchars` single chars as one string.
    strings = picked.view((np.unicode_, nchars)).reshape(size)
    return strings if dtype is None else strings.astype(dtype)
def rands(nchars):
    """
    Return one random string of `nchars` characters from ``RANDS_CHARS``.

    Use `rands_array` to create an array of random strings.
    """
    picks = np.random.choice(RANDS_CHARS, nchars)
    return ''.join(picks)
def randu(nchars):
    """
    Return one random unicode string of `nchars` characters.

    Characters come from ``RANDU_CHARS``. Use `randu_array` to create an
    array of random unicode strings.
    """
    picks = np.random.choice(RANDU_CHARS, nchars)
    return ''.join(picks)
def close(fignum=None):
    """Close matplotlib figure `fignum`, or every open figure when None."""
    from matplotlib.pyplot import get_fignums, close as _close

    figure_numbers = get_fignums() if fignum is None else [fignum]
    for number in figure_numbers:
        _close(number)
- # -----------------------------------------------------------------------------
- # locale utilities
def check_output(*popenargs, **kwargs):
    r"""Run command with arguments and return its output as a byte string.

    Adapted from the Python 2.7 ``subprocess.check_output`` source.

    If the exit code was non-zero it raises a CalledProcessError. The
    CalledProcessError object will have the return code in the returncode
    attribute and output in the output attribute.

    The arguments are the same as for the Popen constructor. The stdout
    argument is not allowed as it is used internally. Standard error is
    captured and discarded unless ``stderr`` is given; to capture standard
    error in the result, use ``stderr=subprocess.STDOUT``.

    Examples
    --------
    >>> check_output(["ls", "-l", "/dev/null"])  # doctest: +SKIP
    b'crw-rw-rw- 1 root root 1, 3 Oct 18 2007 /dev/null\n'
    >>> check_output(["/bin/sh", "-c",
    ...               "ls -l non_existent_file ; exit 0"],
    ...              stderr=subprocess.STDOUT)  # doctest: +SKIP
    b'ls: non_existent_file: No such file or directory\n'
    """
    if 'stdout' in kwargs:
        raise ValueError('stdout argument not allowed, it will be overridden.')
    # stderr used to be passed explicitly as well, so any caller-supplied
    # stderr (e.g. STDOUT, as recommended above) raised TypeError for a
    # duplicated keyword. Only default it when the caller did not set it.
    kwargs.setdefault('stderr', subprocess.PIPE)
    process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
    output, _ = process.communicate()
    retcode = process.poll()
    if retcode:
        cmd = kwargs.get("args")
        if cmd is None:
            cmd = popenargs[0]
        raise subprocess.CalledProcessError(retcode, cmd, output=output)
    return output
def _default_locale_getter():
    """Return the raw output of ``locale -a`` (run through the shell).

    Raises
    ------
    subprocess.CalledProcessError
        If ``locale -a`` exits with a non-zero status, e.g. because the
        command is not available on this system.
    """
    try:
        raw_locales = check_output(['locale -a'], shell=True)
    except subprocess.CalledProcessError as e:
        # CalledProcessError must be built from (returncode, cmd[, output]);
        # the previous type(e)(message) call raised TypeError and hid the
        # real failure. The hint goes into ``cmd`` so it shows in str(err).
        raise subprocess.CalledProcessError(
            e.returncode,
            "{cmd} (the 'locale -a' command cannot be found on your "
            "system)".format(cmd=e.cmd),
            output=e.output)
    return raw_locales
def get_locales(prefix=None, normalize=True,
                locale_getter=_default_locale_getter):
    """Get all the locales that are available on the system.

    Parameters
    ----------
    prefix : str
        If not ``None`` then return only those locales whose name starts
        with `prefix` (interpreted as a regular expression). For example to
        get all English language locales (those that start with ``"en"``),
        pass ``prefix="en"``.
    normalize : bool
        Call ``locale.normalize`` on each available locale name. Regardless
        of this flag, only locales that can be set without raising are
        returned.
    locale_getter : callable
        The function to use to retrieve the current locales. This should
        return a bytes (or text) string with each locale separated by a
        newline character.

    Returns
    -------
    locales : list of strings
        A list of locale strings that can be set with ``locale.setlocale()``.
        For example::

            locale.setlocale(locale.LC_ALL, locale_string)

    On error will return None (no locale available, e.g. Windows)
    """
    try:
        raw_locales = locale_getter()
    except Exception:
        return None

    # Bound up front: a TypeError from ``split`` (e.g. a text result on
    # Python 3) used to leave this name unassigned -> UnboundLocalError.
    out_locales = []
    try:
        if isinstance(raw_locales, bytes):
            # raw_locales is a "\n" separated list of locales that may
            # contain non-decodable parts, so split first, decode each entry
            # on its own and keep whatever can be decoded.
            for x in raw_locales.split(b'\n'):
                if PY3:
                    try:
                        out_locales.append(str(
                            x, encoding=pd.options.display.encoding))
                    except UnicodeDecodeError:
                        continue
                else:
                    out_locales.append(str(x))
        else:
            # Already text (e.g. from a custom locale_getter).
            out_locales = raw_locales.split('\n')
    except TypeError:
        # Best effort: keep whatever was collected before the failure.
        pass

    if prefix is None:
        return _valid_locales(out_locales, normalize)

    # Anchor at the start of each line so only names that *begin* with
    # `prefix` match; unanchored, "US" also matched inside "en_US.UTF-8".
    pattern = re.compile('^{prefix}.*'.format(prefix=prefix), re.MULTILINE)
    found = pattern.findall('\n'.join(out_locales))
    return _valid_locales(found, normalize)
@contextmanager
def set_locale(new_locale, lc_var=locale.LC_ALL):
    """Context manager for temporarily setting a locale.

    Parameters
    ----------
    new_locale : str or tuple
        A string of the form <language_country>.<encoding>. For example to set
        the current locale to US English with a UTF8 encoding, you would pass
        "en_US.UTF-8".
    lc_var : int, default `locale.LC_ALL`
        The category of the locale being set.

    Yields
    ------
    str
        ``"<language>.<encoding>"`` as reported by ``locale.getlocale()``
        once the locale is set, or `new_locale` itself if either part is
        None.

    Notes
    -----
    This is useful when you want to run a particular block of code under a
    particular locale, without globally setting the locale. This probably isn't
    thread-safe.
    """
    # Query (without changing) the exact current setting of ``lc_var``.
    # ``locale.getlocale()`` only reports LC_CTYPE, so restoring it into
    # another category (or LC_ALL) did not reinstate the previous state.
    current_locale = locale.setlocale(lc_var)

    try:
        locale.setlocale(lc_var, new_locale)
        normalized_locale = locale.getlocale()
        if all(part is not None for part in normalized_locale):
            yield '.'.join(normalized_locale)
        else:
            yield new_locale
    finally:
        locale.setlocale(lc_var, current_locale)
def can_set_locale(lc, lc_var=locale.LC_ALL):
    """
    Report whether locale `lc` can be set (and read back) without error.

    The locale is entered and immediately left through ``set_locale``.

    Parameters
    ----------
    lc : str
        The locale to attempt to set.
    lc_var : int, default `locale.LC_ALL`
        The category of the locale being set.

    Returns
    -------
    is_valid : bool
        False if setting raised ValueError or ``locale.Error``, else True.
    """
    try:
        with set_locale(lc, lc_var=lc_var):
            pass
    except (ValueError, locale.Error):  # locale.Error: oddly named subclass
        return False
    return True
def _valid_locales(locales, normalize):
    """Return the entries of `locales` that can actually be set.

    Parameters
    ----------
    locales : iterable of str
        Candidate locale names, one per element.
    normalize : bool
        Whether to pass each whitespace-stripped name through
        ``locale.normalize`` before testing it.

    Returns
    -------
    valid_locales : list
        The stripped (and optionally normalized) names for which
        ``can_set_locale`` returns True, in input order.
    """
    def _clean(name):
        stripped = name.strip()
        return locale.normalize(stripped) if normalize else stripped

    return [candidate for candidate in (_clean(x) for x in locales)
            if can_set_locale(candidate)]
- # -----------------------------------------------------------------------------
- # Stdout / stderr decorators
@contextmanager
def set_defaultencoding(encoding):
    """
    Temporarily change ``sys.getdefaultencoding()`` to `encoding`.

    The previous default encoding is restored on exit. Only supported on
    Python 2.

    Parameters
    ----------
    encoding : str

    Raises
    ------
    ValueError
        When called on a Python other than 2.
    """
    if not PY2:
        raise ValueError("set_defaultencoding context is only available "
                         "in Python 2.")

    previous_encoding = sys.getdefaultencoding()
    # Presumably needed to re-expose sys.setdefaultencoding, which Python 2's
    # site module removes at startup -- confirm.
    reload(sys)  # noqa:F821
    sys.setdefaultencoding(encoding)

    try:
        yield
    finally:
        sys.setdefaultencoding(previous_encoding)
- # -----------------------------------------------------------------------------
- # contextmanager to ensure the file cleanup
@contextmanager
def ensure_clean(filename=None, return_filelike=False):
    """Yield a temporary path (or file object) that is removed on exit.

    Parameters
    ----------
    filename : str (optional)
        Used as the ending (suffix) of the temporary file's name; must not
        contain a directory component. If None, no suffix is used.
    return_filelike : bool (default False)
        If True, yield an open ``tempfile.TemporaryFile`` that is *always*
        closed instead of a path. Necessary for savefig and other functions
        which want to append extensions.

    Raises
    ------
    ValueError
        If `filename` contains a directory component (path mode only).
    """
    suffix = filename or ''

    if return_filelike:
        handle = tempfile.TemporaryFile(suffix=suffix)
        try:
            yield handle
        finally:
            handle.close()
        return

    # don't generate tempfile if using a path with directory specified
    if os.path.dirname(suffix):
        raise ValueError("Can't pass a qualified name to ensure_clean()")

    fd = None
    path = suffix
    try:
        fd, path = tempfile.mkstemp(suffix=suffix)
    except UnicodeEncodeError:
        import pytest
        pytest.skip('no unicode file names on this system')

    try:
        yield path
    finally:
        # Cleanup is best effort: failures are reported, never raised.
        try:
            os.close(fd)
        except Exception:
            print("Couldn't close file descriptor: {fdesc} (file: {fname})"
                  .format(fdesc=fd, fname=path))
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception as e:
            print("Exception on removing file: {error}".format(error=e))
@contextmanager
def ensure_clean_dir():
    """
    Create a temporary directory and remove it (best effort) on exit.

    Yields
    ------
    Temporary directory path
    """
    tmp_dir = tempfile.mkdtemp(suffix='')
    try:
        yield tmp_dir
    finally:
        # Removal failures are deliberately ignored.
        try:
            rmtree(tmp_dir)
        except Exception:
            pass
@contextmanager
def ensure_safe_environment_variables():
    """
    Snapshot ``os.environ`` and restore it exactly on exit.

    Variables set, changed or removed inside the block are undone on close,
    so they neither persist nor change global state.
    """
    snapshot = os.environ.copy()
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(snapshot)
- # -----------------------------------------------------------------------------
- # Comparators
def equalContents(arr1, arr2):
    """Return True if `arr1` and `arr2` contain the same unique elements.

    Multiplicity and order are ignored; elements must be hashable.
    """
    return set(arr1) == set(arr2)
def assert_index_equal(left, right, exact='equiv', check_names=True,
                       check_less_precise=False, check_exact=True,
                       check_categorical=True, obj='Index'):
    """Check that left and right Index are equal.

    Checks, in order: both are Index instances; class/dtype/inferred_type
    (subject to `exact`); number of levels; length; per-level values for
    MultiIndex; values; names; freq (PeriodIndex); interval arrays
    (IntervalIndex); and categorical values (when categorical).

    Parameters
    ----------
    left : Index
    right : Index
    exact : bool / string {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Int64Index as well.
    check_names : bool, default True
        Whether to check the names attribute.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare
    check_exact : bool, default True
        Whether to compare number exactly.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    obj : str, default 'Index'
        Specify object name being compared, internally used to show appropriate
        assertion message

    Raises
    ------
    AssertionError
        On the first mismatch found.
    """
    # Hide this frame from pytest tracebacks.
    __tracebackhide__ = True

    def _check_types(l, r, obj='Index'):
        # Class, dtype and inferred_type checks; all skipped when `exact`
        # is falsy.
        if exact:
            assert_class_equal(l, r, exact=exact, obj=obj)

            # Skip exact dtype checking when `check_categorical` is False
            if check_categorical:
                assert_attr_equal('dtype', l, r, obj=obj)

            # allow string-like to have different inferred_types
            if l.inferred_type in ('string', 'unicode'):
                assert r.inferred_type in ('string', 'unicode')
            else:
                assert_attr_equal('inferred_type', l, r, obj=obj)

    def _get_ilevel_values(index, level):
        # Materialize one MultiIndex level from its levels/codes, filling
        # missing codes with the level's NA value.
        # accept level number only
        unique = index.levels[level]
        labels = index.codes[level]
        filled = take_1d(unique.values, labels, fill_value=unique._na_value)
        values = unique._shallow_copy(filled, name=index.names[level])
        return values

    # instance validation
    _check_isinstance(left, right, Index)

    # class / dtype comparison
    _check_types(left, right, obj=obj)

    # level comparison
    if left.nlevels != right.nlevels:
        msg1 = '{obj} levels are different'.format(obj=obj)
        msg2 = '{nlevels}, {left}'.format(nlevels=left.nlevels, left=left)
        msg3 = '{nlevels}, {right}'.format(nlevels=right.nlevels, right=right)
        raise_assert_detail(obj, msg1, msg2, msg3)

    # length comparison
    if len(left) != len(right):
        msg1 = '{obj} length are different'.format(obj=obj)
        msg2 = '{length}, {left}'.format(length=len(left), left=left)
        msg3 = '{length}, {right}'.format(length=len(right), right=right)
        raise_assert_detail(obj, msg1, msg2, msg3)

    # MultiIndex special comparison for little-friendly error messages
    if left.nlevels > 1:
        for level in range(left.nlevels):
            # cannot use get_level_values here because it can change dtype
            llevel = _get_ilevel_values(left, level)
            rlevel = _get_ilevel_values(right, level)

            lobj = 'MultiIndex level [{level}]'.format(level=level)
            # NOTE(review): check_categorical is not forwarded here, so the
            # per-level comparison always uses its default (True).
            assert_index_equal(llevel, rlevel,
                               exact=exact, check_names=check_names,
                               check_less_precise=check_less_precise,
                               check_exact=check_exact, obj=lobj)
            # get_level_values may change dtype
            _check_types(left.levels[level], right.levels[level], obj=obj)

    # skip exact index checking when `check_categorical` is False
    if check_exact and check_categorical:
        if not left.equals(right):
            # Percentage of positions whose values differ.
            diff = np.sum((left.values != right.values)
                          .astype(int)) * 100.0 / len(left)
            msg = '{obj} values are different ({pct} %)'.format(
                obj=obj, pct=np.round(diff, 5))
            raise_assert_detail(obj, msg, left, right)
    else:
        # Approximate comparison; `exact` doubles as the dtype check mode.
        _testing.assert_almost_equal(left.values, right.values,
                                     check_less_precise=check_less_precise,
                                     check_dtype=exact,
                                     obj=obj, lobj=left, robj=right)

    # metadata comparison
    if check_names:
        assert_attr_equal('names', left, right, obj=obj)
    if isinstance(left, pd.PeriodIndex) or isinstance(right, pd.PeriodIndex):
        assert_attr_equal('freq', left, right, obj=obj)
    if (isinstance(left, pd.IntervalIndex) or
            isinstance(right, pd.IntervalIndex)):
        assert_interval_array_equal(left.values, right.values)

    if check_categorical:
        if is_categorical_dtype(left) or is_categorical_dtype(right):
            assert_categorical_equal(left.values, right.values,
                                     obj='{obj} category'.format(obj=obj))
def assert_class_equal(left, right, exact=True, obj='Input'):
    """Assert that `left` and `right` have the same class.

    Parameters
    ----------
    left, right : object
    exact : bool or 'equiv', default True
        'equiv' accepts Int64Index and RangeIndex as interchangeable; any
        other truthy value requires identical types; a falsy value disables
        the check.
    obj : str, default 'Input'
        Object name used in the assertion message.
    """
    __tracebackhide__ = True

    def _describe(value):
        # Index instances are passed through so their values appear in the
        # failure message; anything else is reduced to its class name.
        if isinstance(value, Index):
            return value
        try:
            return value.__class__.__name__
        except AttributeError:
            return repr(type(value))

    if exact == 'equiv':
        if type(left) != type(right):
            names = {type(left).__name__, type(right).__name__}
            # allow equivalence of Int64Index/RangeIndex
            if names - {'Int64Index', 'RangeIndex'}:
                raise_assert_detail(
                    obj, '{obj} classes are not equivalent'.format(obj=obj),
                    _describe(left), _describe(right))
    elif exact and type(left) != type(right):
        raise_assert_detail(
            obj, '{obj} classes are different'.format(obj=obj),
            _describe(left), _describe(right))
def assert_attr_equal(attr, left, right, obj='Attributes'):
    """Assert that attribute `attr` is equal on both objects.

    Both objects must have the attribute. Identical objects and a pair of
    numeric NaNs count as equal; an ``==`` that raises TypeError counts as
    unequal, and array-like comparison results must be all-True.

    Parameters
    ----------
    attr : str
        Attribute name being compared.
    left : object
    right : object
    obj : str, default 'Attributes'
        Specify object name being compared, internally used to show appropriate
        assertion message

    Returns
    -------
    True when the attributes match; otherwise raises AssertionError.
    """
    __tracebackhide__ = True

    left_attr = getattr(left, attr)
    right_attr = getattr(right, attr)

    if left_attr is right_attr:
        return True
    if (is_number(left_attr) and np.isnan(left_attr) and
            is_number(right_attr) and np.isnan(right_attr)):
        # np.nan
        return True

    try:
        same = left_attr == right_attr
    except TypeError:
        # datetimetz on rhs may raise TypeError
        same = False

    if not isinstance(same, bool):
        same = same.all()

    if not same:
        raise_assert_detail(
            obj, 'Attribute "{attr}" are different'.format(attr=attr),
            left_attr, right_attr)
    return True
def assert_is_valid_plot_return_object(objs):
    """Assert that `objs` looks like something a pandas plot call returns.

    A Series/ndarray must contain only matplotlib Axes (or dicts); anything
    else must itself be a matplotlib Artist, a tuple or a dict.
    """
    import matplotlib.pyplot as plt

    if not isinstance(objs, (pd.Series, np.ndarray)):
        assert isinstance(objs, (plt.Artist, tuple, dict)), (
            'objs is neither an ndarray of Artist instances nor a '
            'single Artist instance, tuple, or dict, "objs" is a {name!r}'
            .format(name=objs.__class__.__name__))
        return

    for element in objs.ravel():
        msg = ("one of 'objs' is not a matplotlib Axes instance, type "
               "encountered {name!r}").format(name=element.__class__.__name__)
        assert isinstance(element, (plt.Axes, dict)), msg
def isiterable(obj):
    """Return True if `obj` defines an ``__iter__`` attribute."""
    has_iter = hasattr(obj, '__iter__')
    return has_iter
def is_sorted(seq):
    """Assert that `seq` is in ascending order.

    Index/Series inputs are reduced to ``.values`` first. `seq` is compared
    with ``np.sort`` of itself via ``assert_numpy_array_equal``, whose
    result (True) is returned on success.
    """
    values = seq.values if isinstance(seq, (Index, Series)) else seq
    # sorting does not change precisions
    expected = np.sort(np.array(values))
    return assert_numpy_array_equal(values, expected)
def assert_categorical_equal(left, right, check_dtype=True,
                             check_category_order=True, obj='Categorical'):
    """Test that Categoricals are equivalent.

    Parameters
    ----------
    left : Categorical
    right : Categorical
    check_dtype : bool, default True
        Check that integer dtype of the codes are the same
    check_category_order : bool, default True
        Whether the order of the categories should be compared, which
        implies identical integer codes.  If False, only the resulting
        values (and the positions of missing values) are compared. The
        ordered attribute is checked regardless.
    obj : str, default 'Categorical'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    _check_isinstance(left, right, Categorical)

    if check_category_order:
        assert_index_equal(left.categories, right.categories,
                           obj='{obj}.categories'.format(obj=obj))
        assert_numpy_array_equal(left.codes, right.codes,
                                 check_dtype=check_dtype,
                                 obj='{obj}.codes'.format(obj=obj))
    else:
        assert_index_equal(left.categories.sort_values(),
                           right.categories.sort_values(),
                           obj='{obj}.categories'.format(obj=obj))
        assert_index_equal(left.categories.take(left.codes),
                           right.categories.take(right.codes),
                           obj='{obj}.values'.format(obj=obj))
        # Index.take without a fill_value treats the missing-value code -1
        # as "last element", so a NaN in one Categorical would otherwise
        # compare equal to the last category in the other.
        assert_numpy_array_equal(left.codes == -1, right.codes == -1,
                                 obj='{obj}.isna'.format(obj=obj))

    assert_attr_equal('ordered', left, right, obj=obj)
def assert_interval_array_equal(left, right, exact='equiv',
                                obj='IntervalArray'):
    """Test that two IntervalArrays are equivalent.

    Compares the left endpoints, the right endpoints and ``closed``.

    Parameters
    ----------
    left, right : IntervalArray
        The IntervalArrays to compare.
    exact : bool / string {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical. If 'equiv', then RangeIndex can be substituted for
        Int64Index as well.
    obj : str, default 'IntervalArray'
        Specify object name being compared, internally used to show appropriate
        assertion message
    """
    __tracebackhide__ = True
    _check_isinstance(left, right, IntervalArray)

    assert_index_equal(left.left, right.left, exact=exact,
                       obj='{obj}.left'.format(obj=obj))
    # Label the right-endpoint comparison as such (was mislabelled '.left').
    assert_index_equal(left.right, right.right, exact=exact,
                       obj='{obj}.right'.format(obj=obj))
    assert_attr_equal('closed', left, right, obj=obj)
def assert_period_array_equal(left, right, obj='PeriodArray'):
    """Test that two PeriodArrays are equivalent.

    Compares the underlying ``_data`` arrays and the ``freq`` attribute,
    mirroring the DatetimeArray/TimedeltaArray helpers.

    Parameters
    ----------
    left, right : PeriodArray
    obj : str, default 'PeriodArray'
        Object name used in assertion messages.
    """
    __tracebackhide__ = True
    _check_isinstance(left, right, PeriodArray)

    # Label matches the attribute actually compared (was '.values').
    assert_numpy_array_equal(left._data, right._data,
                             obj='{obj}._data'.format(obj=obj))
    assert_attr_equal('freq', left, right, obj=obj)
def assert_datetime_array_equal(left, right, obj='DatetimeArray'):
    """Assert two DatetimeArrays have equal ``_data``, ``freq`` and ``tz``."""
    __tracebackhide__ = True
    _check_isinstance(left, right, DatetimeArray)

    data_label = '{obj}._data'.format(obj=obj)
    assert_numpy_array_equal(left._data, right._data, obj=data_label)
    for attr_name in ('freq', 'tz'):
        assert_attr_equal(attr_name, left, right, obj=obj)
def assert_timedelta_array_equal(left, right, obj='TimedeltaArray'):
    """Assert two TimedeltaArrays have equal ``_data`` and ``freq``."""
    __tracebackhide__ = True
    _check_isinstance(left, right, TimedeltaArray)

    data_label = '{obj}._data'.format(obj=obj)
    assert_numpy_array_equal(left._data, right._data, obj=data_label)
    assert_attr_equal('freq', left, right, obj=obj)
def raise_assert_detail(obj, message, left, right, diff=None):
    """Raise an AssertionError describing a mismatch between two objects.

    ndarrays are pretty-printed and categorical-dtype values repr'd before
    being embedded; on Python 2, text is encoded to UTF-8. `diff`, when
    given, is appended on its own line.
    """
    __tracebackhide__ = True

    def _as_text(value):
        if isinstance(value, np.ndarray):
            value = pprint_thing(value)
        elif is_categorical_dtype(value):
            value = repr(value)
        if PY2 and isinstance(value, string_types):
            # must be printable in the native text type in python2
            value = value.encode('utf-8')
        return value

    left = _as_text(left)
    right = _as_text(right)

    msg = ("{obj} are different\n"
           "{message}\n"
           "[left]: {left}\n"
           "[right]: {right}").format(obj=obj, message=message,
                                      left=left, right=right)
    if diff is not None:
        msg += "\n[diff]: {diff}".format(diff=diff)
    raise AssertionError(msg)
def assert_numpy_array_equal(left, right, strict_nan=False,
                             check_dtype=True, err_msg=None,
                             check_same=None, obj='numpy array'):
    """ Checks that 'np.ndarray' is equivalent

    Parameters
    ----------
    left : np.ndarray or iterable
    right : np.ndarray or iterable
    strict_nan : bool, default False
        If True, consider NaN and None to be different.
    check_dtype: bool, default True
        check dtype if both a and b are np.ndarray
    err_msg : str, default None
        If provided, used as assertion message
    check_same : None|'copy'|'same', default None
        Ensure left and right refer/do not refer to the same memory area
    obj : str, default 'numpy array'
        Specify object name being compared, internally used to show appropriate
        assertion message

    Returns
    -------
    True when the arrays match; otherwise raises AssertionError.
    """
    # Hide this frame from pytest tracebacks.
    __tracebackhide__ = True

    # instance validation
    # Show a detailed error message when classes are different
    assert_class_equal(left, right, obj=obj)
    # both classes must be an np.ndarray
    _check_isinstance(left, right, np.ndarray)

    def _get_base(obj):
        # The array owning the memory: ``.base`` for views, else the array.
        return obj.base if getattr(obj, 'base', None) is not None else obj

    left_base = _get_base(left)
    right_base = _get_base(right)

    if check_same == 'same':
        if left_base is not right_base:
            msg = "{left!r} is not {right!r}".format(
                left=left_base, right=right_base)
            raise AssertionError(msg)
    elif check_same == 'copy':
        if left_base is right_base:
            msg = "{left!r} is {right!r}".format(
                left=left_base, right=right_base)
            raise AssertionError(msg)

    def _raise(left, right, err_msg):
        # Always raises: a detailed shape/value report when err_msg is None,
        # otherwise a plain AssertionError(err_msg).
        if err_msg is None:
            if left.shape != right.shape:
                raise_assert_detail(obj, '{obj} shapes are different'
                                    .format(obj=obj), left.shape, right.shape)

            diff = 0
            for l, r in zip(left, right):
                # count up differences
                if not array_equivalent(l, r, strict_nan=strict_nan):
                    diff += 1

            # NOTE(review): ``diff`` counts mismatching entries along the
            # first axis but is divided by the total element count, so for
            # ndim > 1 the reported percentage can understate the mismatch.
            diff = diff * 100.0 / left.size
            msg = '{obj} values are different ({pct} %)'.format(
                obj=obj, pct=np.round(diff, 5))
            raise_assert_detail(obj, msg, left, right)

        raise AssertionError(err_msg)

    # compare shape and values
    if not array_equivalent(left, right, strict_nan=strict_nan):
        _raise(left, right, err_msg)

    if check_dtype:
        if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
            assert_attr_equal('dtype', left, right, obj=obj)

    return True
def assert_extension_array_equal(left, right, check_dtype=True,
                                 check_less_precise=False,
                                 check_exact=False):
    """
    Assert that two ExtensionArrays hold equal values.

    Parameters
    ----------
    left, right : ExtensionArray
        The arrays being compared.
    check_dtype : bool, default True
        Also require the two arrays' ``dtype`` attributes to match.
    check_less_precise : bool or int, default False
        Comparison precision, used only when ``check_exact`` is False:
        False compares 5 digits after the decimal point, True compares 3,
        and an int gives the number of digits to compare.
    check_exact : bool, default False
        Compare the valid values exactly rather than approximately.

    Notes
    -----
    If both arrays have exactly the same type and ``left`` exposes an
    ``asi8`` attribute, those integer arrays are compared directly, which
    avoids slow object-dtype work. Otherwise the missing-value masks must
    match, and the remaining valid values are cast to object dtype and
    compared.
    """
    assert isinstance(left, ExtensionArray), 'left is not an ExtensionArray'
    assert isinstance(right, ExtensionArray), 'right is not an ExtensionArray'

    if check_dtype:
        assert_attr_equal('dtype', left, right, obj='ExtensionArray')

    use_int_view = hasattr(left, "asi8") and type(right) == type(left)
    if use_int_view:
        # Same concrete type with an integer view: compare those directly.
        assert_numpy_array_equal(left.asi8, right.asi8)
        return

    na_left, na_right = np.asarray(left.isna()), np.asarray(right.isna())
    assert_numpy_array_equal(na_left, na_right,
                             obj='ExtensionArray NA mask')

    # Only non-missing positions remain; compare them as Python objects.
    valid_left = np.asarray(left[~na_left].astype(object))
    valid_right = np.asarray(right[~na_right].astype(object))

    if not check_exact:
        _testing.assert_almost_equal(valid_left, valid_right,
                                     check_dtype=check_dtype,
                                     check_less_precise=check_less_precise,
                                     obj='ExtensionArray')
    else:
        assert_numpy_array_equal(valid_left, valid_right,
                                 obj='ExtensionArray')
- # This could be refactored to use the NDFrame.equals method
def assert_series_equal(left, right, check_dtype=True,
                        check_index_type='equiv',
                        check_series_type=True,
                        check_less_precise=False,
                        check_names=True,
                        check_exact=False,
                        check_datetimelike_compat=False,
                        check_categorical=True,
                        obj='Series'):
    """Check that left and right Series are equal.

    Parameters
    ----------
    left : Series
    right : Series
    check_dtype : bool, default True
        Whether to check the Series dtype is identical. Also forwarded to
        the ExtensionArray value comparison.
    check_index_type : bool / string {'equiv'}, default 'equiv'
        Whether to check the Index class, dtype and inferred_type
        are identical.
    check_series_type : bool, default True
        Whether to check the Series class is identical.
    check_less_precise : bool or int, default False
        Specify comparison precision. Only used when check_exact is False.
        5 digits (False) or 3 digits (True) after decimal points are compared.
        If int, then specify the digits to compare.
    check_names : bool, default True
        Whether to check the Series and Index names attribute.
    check_exact : bool, default False
        Whether to compare number exactly.
    check_datetimelike_compat : bool, default False
        Compare datetime-like which is comparable ignoring dtype.
    check_categorical : bool, default True
        Whether to compare internal Categorical exactly.
    obj : str, default 'Series'
        Specify object name being compared, internally used to show appropriate
        assertion message.

    Raises
    ------
    AssertionError
        If the Series differ in class (when checked), length, index, dtype
        (when checked), values, name (when checked) or categorical
        contents (when checked).
    """
    __tracebackhide__ = True

    # instance validation
    _check_isinstance(left, right, Series)

    if check_series_type:
        # ToDo: There are some tests using rhs is sparse
        # lhs is dense. Should use assert_class_equal in future
        assert isinstance(left, type(right))
        # assert_class_equal(left, right, obj=obj)

    # length comparison
    if len(left) != len(right):
        msg1 = '{len}, {left}'.format(len=len(left), left=left.index)
        msg2 = '{len}, {right}'.format(len=len(right), right=right.index)
        raise_assert_detail(obj, 'Series length are different', msg1, msg2)

    # index comparison
    assert_index_equal(left.index, right.index, exact=check_index_type,
                       check_names=check_names,
                       check_less_precise=check_less_precise,
                       check_exact=check_exact,
                       check_categorical=check_categorical,
                       obj='{obj}.index'.format(obj=obj))

    if check_dtype:
        # We want to skip exact dtype checking when `check_categorical`
        # is False. We'll still raise if only one is a `Categorical`,
        # regardless of `check_categorical`
        if (is_categorical_dtype(left) and is_categorical_dtype(right) and
                not check_categorical):
            pass
        else:
            assert_attr_equal('dtype', left, right)

    # value comparison -- exactly one branch applies
    if check_exact:
        assert_numpy_array_equal(left.get_values(), right.get_values(),
                                 check_dtype=check_dtype,
                                 obj='{obj}'.format(obj=obj),)
    elif check_datetimelike_compat:
        # we want to check only if we have compat dtypes
        # e.g. integer and M|m are NOT compat, but we can simply check
        # the values in that case
        if (is_datetimelike_v_numeric(left, right) or
            is_datetimelike_v_object(left, right) or
            needs_i8_conversion(left) or
                needs_i8_conversion(right)):

            # datetimelike may have different objects (e.g. datetime.datetime
            # vs Timestamp) but will compare equal
            if not Index(left.values).equals(Index(right.values)):
                msg = ('[datetimelike_compat=True] {left} is not equal to '
                       '{right}.').format(left=left.values, right=right.values)
                raise AssertionError(msg)
        else:
            assert_numpy_array_equal(left.get_values(), right.get_values(),
                                     check_dtype=check_dtype)
    elif is_interval_dtype(left) or is_interval_dtype(right):
        assert_interval_array_equal(left.array, right.array)
    elif (is_extension_array_dtype(left.dtype) and
          is_datetime64tz_dtype(left.dtype)):
        # .values is an ndarray, but ._values is the ExtensionArray.
        # TODO: Use .array
        assert is_extension_array_dtype(right.dtype)
        # No early return here: the name / categorical checks below must
        # still run. check_dtype is forwarded so that check_dtype=False is
        # honoured at the array level as well.
        assert_extension_array_equal(left._values, right._values,
                                     check_dtype=check_dtype,
                                     check_less_precise=check_less_precise)
    elif (is_extension_array_dtype(left) and not is_categorical_dtype(left) and
          is_extension_array_dtype(right) and not is_categorical_dtype(right)):
        # As above: fall through to the metadata checks instead of
        # returning, and honour check_dtype / check_less_precise.
        assert_extension_array_equal(left.array, right.array,
                                     check_dtype=check_dtype,
                                     check_less_precise=check_less_precise)
    else:
        _testing.assert_almost_equal(left.get_values(), right.get_values(),
                                     check_less_precise=check_less_precise,
                                     check_dtype=check_dtype,
                                     obj='{obj}'.format(obj=obj))

    # metadata comparison
    if check_names:
        assert_attr_equal('name', left, right, obj=obj)

    if check_categorical:
        if is_categorical_dtype(left) or is_categorical_dtype(right):
            assert_categorical_equal(left.values, right.values,
                                     obj='{obj} category'.format(obj=obj))
- # This could be refactored to use the NDFrame.equals method
- def assert_frame_equal(left, right, check_dtype=True,
- check_index_type='equiv',
- check_column_type='equiv',
- check_frame_type=True,
- check_less_precise=False,
- check_names=True,
- by_blocks=False,
- check_exact=False,
- check_datetimelike_compat=False,
- check_categorical=True,
- check_like=False,
- obj='DataFrame'):
- """
- Check that left and right DataFrame are equal.
- This function is intended to compare two DataFrames and output any
- differences. Is is mostly intended for use in unit tests.
- Additional parameters allow varying the strictness of the
- equality checks performed.
- Parameters
- ----------
- left : DataFrame
- First DataFrame to compare.
- right : DataFrame
- Second DataFrame to compare.
- check_dtype : bool, default True
- Whether to check the DataFrame dtype is identical.
- check_index_type : bool / string {'equiv'}, default 'equiv'
- Whether to check the Index class, dtype and inferred_type
- are identical.
- check_column_type : bool / string {'equiv'}, default 'equiv'
- Whether to check the columns class, dtype and inferred_type
- are identical. Is passed as the ``exact`` argument of
- :func:`assert_index_equal`.
- check_frame_type : bool, default True
- Whether to check the DataFrame class is identical.
- check_less_precise : bool or int, default False
- Specify comparison precision. Only used when check_exact is False.
- 5 digits (False) or 3 digits (True) after decimal points are compared.
- If int, then specify the digits to compare.
- check_names : bool, default True
- Whether to check that the `names` attribute for both the `index`
- and `column` attributes of the DataFrame is identical, i.e.
- * left.index.names == right.index.names
- * left.columns.names == right.columns.names
- by_blocks : bool, default False
- Specify how to compare internal data. If False, compare by columns.
- If True, compare by blocks.
- check_exact : bool, default False
- Whether to compare number exactly.
- check_datetimelike_compat : bool, default False
- Compare datetime-like which is comparable ignoring dtype.
- check_categorical : bool, default True
- Whether to compare internal Categorical exactly.
- check_like : bool, default False
- If True, ignore the order of index & columns.
- Note: index labels must match their respective rows
- (same as in columns) - same labels must be with the same data.
- obj : str, default 'DataFrame'
- Specify object name being compared, internally used to show appropriate
- assertion message.
- See Also
- --------
- assert_series_equal : Equivalent method for asserting Series equality.
- DataFrame.equals : Check DataFrame equality.
- Examples
- --------
- This example shows comparing two DataFrames that are equal
- but with columns of differing dtypes.
- >>> from pandas.util.testing import assert_frame_equal
- >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
- >>> df2 = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]})
- df1 equals itself.
- >>> assert_frame_equal(df1, df1)
- df1 differs from df2 as column 'b' is of a different type.
- >>> assert_frame_equal(df1, df2)
- Traceback (most recent call last):
- AssertionError: Attributes are different
- Attribute "dtype" are different
- [left]: int64
- [right]: float64
- Ignore differing dtypes in columns with check_dtype.
- >>> assert_frame_equal(df1, df2, check_dtype=False)
- """
- __tracebackhide__ = True
- # instance validation
- _check_isinstance(left, right, DataFrame)
- if check_frame_type:
- # ToDo: There are some tests using rhs is SparseDataFrame
- # lhs is DataFrame. Should use assert_class_equal in future
- assert isinstance(left, type(right))
- # assert_class_equal(left, right, obj=obj)
- # shape comparison
- if left.shape != right.shape:
- raise_assert_detail(obj,
- 'DataFrame shape mismatch',
- '{shape!r}'.format(shape=left.shape),
- '{shape!r}'.format(shape=right.shape))
- if check_like:
- left, right = left.reindex_like(right), right
- # index comparison
- assert_index_equal(left.index, right.index, exact=check_index_type,
- check_names=check_names,
- check_less_precise=check_less_precise,
- check_exact=check_exact,
- check_categorical=check_categorical,
- obj='…
Large files files are truncated, but you can click here to view the full file