/pandas/tests/indexing/common.py
Python | 279 lines | 218 code | 40 blank | 21 comment | 27 complexity | 6e60d2d36102f695f580bf03e408ed5c MD5 | raw file
- """ common utilities """
- import itertools
- from warnings import catch_warnings
- import numpy as np
- from pandas.compat import lrange
- from pandas.core.dtypes.common import is_scalar
- from pandas import Series, DataFrame, Panel, date_range, UInt64Index
- from pandas.util import testing as tm
- from pandas.io.formats.printing import pprint_thing
- _verbose = False
- def _mklbl(prefix, n):
- return ["%s%s" % (prefix, i) for i in range(n)]
- def _axify(obj, key, axis):
- # create a tuple accessor
- axes = [slice(None)] * obj.ndim
- axes[axis] = key
- return tuple(axes)
- class Base(object):
- """ indexing comprehensive base class """
- _objs = set(['series', 'frame', 'panel'])
- _typs = set(['ints', 'uints', 'labels', 'mixed',
- 'ts', 'floats', 'empty', 'ts_rev'])
- def setup_method(self, method):
- self.series_ints = Series(np.random.rand(4), index=lrange(0, 8, 2))
- self.frame_ints = DataFrame(np.random.randn(4, 4),
- index=lrange(0, 8, 2),
- columns=lrange(0, 12, 3))
- with catch_warnings(record=True):
- self.panel_ints = Panel(np.random.rand(4, 4, 4),
- items=lrange(0, 8, 2),
- major_axis=lrange(0, 12, 3),
- minor_axis=lrange(0, 16, 4))
- self.series_uints = Series(np.random.rand(4),
- index=UInt64Index(lrange(0, 8, 2)))
- self.frame_uints = DataFrame(np.random.randn(4, 4),
- index=UInt64Index(lrange(0, 8, 2)),
- columns=UInt64Index(lrange(0, 12, 3)))
- with catch_warnings(record=True):
- self.panel_uints = Panel(np.random.rand(4, 4, 4),
- items=UInt64Index(lrange(0, 8, 2)),
- major_axis=UInt64Index(lrange(0, 12, 3)),
- minor_axis=UInt64Index(lrange(0, 16, 4)))
- self.series_labels = Series(np.random.randn(4), index=list('abcd'))
- self.frame_labels = DataFrame(np.random.randn(4, 4),
- index=list('abcd'), columns=list('ABCD'))
- with catch_warnings(record=True):
- self.panel_labels = Panel(np.random.randn(4, 4, 4),
- items=list('abcd'),
- major_axis=list('ABCD'),
- minor_axis=list('ZYXW'))
- self.series_mixed = Series(np.random.randn(4), index=[2, 4, 'null', 8])
- self.frame_mixed = DataFrame(np.random.randn(4, 4),
- index=[2, 4, 'null', 8])
- with catch_warnings(record=True):
- self.panel_mixed = Panel(np.random.randn(4, 4, 4),
- items=[2, 4, 'null', 8])
- self.series_ts = Series(np.random.randn(4),
- index=date_range('20130101', periods=4))
- self.frame_ts = DataFrame(np.random.randn(4, 4),
- index=date_range('20130101', periods=4))
- with catch_warnings(record=True):
- self.panel_ts = Panel(np.random.randn(4, 4, 4),
- items=date_range('20130101', periods=4))
- dates_rev = (date_range('20130101', periods=4)
- .sort_values(ascending=False))
- self.series_ts_rev = Series(np.random.randn(4),
- index=dates_rev)
- self.frame_ts_rev = DataFrame(np.random.randn(4, 4),
- index=dates_rev)
- with catch_warnings(record=True):
- self.panel_ts_rev = Panel(np.random.randn(4, 4, 4),
- items=dates_rev)
- self.frame_empty = DataFrame({})
- self.series_empty = Series({})
- with catch_warnings(record=True):
- self.panel_empty = Panel({})
- # form agglomerates
- for o in self._objs:
- d = dict()
- for t in self._typs:
- d[t] = getattr(self, '%s_%s' % (o, t), None)
- setattr(self, o, d)
- def generate_indices(self, f, values=False):
- """ generate the indicies
- if values is True , use the axis values
- is False, use the range
- """
- axes = f.axes
- if values:
- axes = [lrange(len(a)) for a in axes]
- return itertools.product(*axes)
- def get_result(self, obj, method, key, axis):
- """ return the result for this obj with this key and this axis """
- if isinstance(key, dict):
- key = key[axis]
- # use an artifical conversion to map the key as integers to the labels
- # so ix can work for comparisions
- if method == 'indexer':
- method = 'ix'
- key = obj._get_axis(axis)[key]
- # in case we actually want 0 index slicing
- try:
- with catch_warnings(record=True):
- xp = getattr(obj, method).__getitem__(_axify(obj, key, axis))
- except:
- xp = getattr(obj, method).__getitem__(key)
- return xp
- def get_value(self, f, i, values=False):
- """ return the value for the location i """
- # check agains values
- if values:
- return f.values[i]
- # this is equiv of f[col][row].....
- # v = f
- # for a in reversed(i):
- # v = v.__getitem__(a)
- # return v
- with catch_warnings(record=True):
- return f.ix[i]
- def check_values(self, f, func, values=False):
- if f is None:
- return
- axes = f.axes
- indicies = itertools.product(*axes)
- for i in indicies:
- result = getattr(f, func)[i]
- # check agains values
- if values:
- expected = f.values[i]
- else:
- expected = f
- for a in reversed(i):
- expected = expected.__getitem__(a)
- tm.assert_almost_equal(result, expected)
- def check_result(self, name, method1, key1, method2, key2, typs=None,
- objs=None, axes=None, fails=None):
- def _eq(t, o, a, obj, k1, k2):
- """ compare equal for these 2 keys """
- if a is not None and a > obj.ndim - 1:
- return
- def _print(result, error=None):
- if error is not None:
- error = str(error)
- v = ("%-16.16s [%-16.16s]: [typ->%-8.8s,obj->%-8.8s,"
- "key1->(%-4.4s),key2->(%-4.4s),axis->%s] %s" %
- (name, result, t, o, method1, method2, a, error or ''))
- if _verbose:
- pprint_thing(v)
- try:
- rs = getattr(obj, method1).__getitem__(_axify(obj, k1, a))
- try:
- xp = self.get_result(obj, method2, k2, a)
- except:
- result = 'no comp'
- _print(result)
- return
- detail = None
- try:
- if is_scalar(rs) and is_scalar(xp):
- assert rs == xp
- elif xp.ndim == 1:
- tm.assert_series_equal(rs, xp)
- elif xp.ndim == 2:
- tm.assert_frame_equal(rs, xp)
- elif xp.ndim == 3:
- tm.assert_panel_equal(rs, xp)
- result = 'ok'
- except AssertionError as e:
- detail = str(e)
- result = 'fail'
- # reverse the checks
- if fails is True:
- if result == 'fail':
- result = 'ok (fail)'
- _print(result)
- if not result.startswith('ok'):
- raise AssertionError(detail)
- except AssertionError:
- raise
- except Exception as detail:
- # if we are in fails, the ok, otherwise raise it
- if fails is not None:
- if isinstance(detail, fails):
- result = 'ok (%s)' % type(detail).__name__
- _print(result)
- return
- result = type(detail).__name__
- raise AssertionError(_print(result, error=detail))
- if typs is None:
- typs = self._typs
- if objs is None:
- objs = self._objs
- if axes is not None:
- if not isinstance(axes, (tuple, list)):
- axes = [axes]
- else:
- axes = list(axes)
- else:
- axes = [0, 1, 2]
- # check
- for o in objs:
- if o not in self._objs:
- continue
- d = getattr(self, o)
- for a in axes:
- for t in typs:
- if t not in self._typs:
- continue
- obj = d[t]
- if obj is None:
- continue
- def _call(obj=obj):
- obj = obj.copy()
- k2 = key2
- _eq(t, o, a, obj, key1, k2)
- # Panel deprecations
- if isinstance(obj, Panel):
- with catch_warnings(record=True):
- _call()
- else:
- _call()