PageRenderTime 45ms CodeModel.GetById 17ms RepoModel.GetById 1ms app.codeStats 0ms

/pandas/tests/indexing/common.py

https://github.com/neurodebian/pandas
Python | 279 lines | 218 code | 40 blank | 21 comment | 27 complexity | 6e60d2d36102f695f580bf03e408ed5c MD5 | raw file
  1. """ common utilities """
  2. import itertools
  3. from warnings import catch_warnings
  4. import numpy as np
  5. from pandas.compat import lrange
  6. from pandas.core.dtypes.common import is_scalar
  7. from pandas import Series, DataFrame, Panel, date_range, UInt64Index
  8. from pandas.util import testing as tm
  9. from pandas.io.formats.printing import pprint_thing
# When True, Base.check_result prints one formatted status line per
# comparison via pprint_thing; when False the status lines are not printed.
_verbose = False
  11. def _mklbl(prefix, n):
  12. return ["%s%s" % (prefix, i) for i in range(n)]
  13. def _axify(obj, key, axis):
  14. # create a tuple accessor
  15. axes = [slice(None)] * obj.ndim
  16. axes[axis] = key
  17. return tuple(axes)
  18. class Base(object):
  19. """ indexing comprehensive base class """
  20. _objs = set(['series', 'frame', 'panel'])
  21. _typs = set(['ints', 'uints', 'labels', 'mixed',
  22. 'ts', 'floats', 'empty', 'ts_rev'])
  23. def setup_method(self, method):
  24. self.series_ints = Series(np.random.rand(4), index=lrange(0, 8, 2))
  25. self.frame_ints = DataFrame(np.random.randn(4, 4),
  26. index=lrange(0, 8, 2),
  27. columns=lrange(0, 12, 3))
  28. with catch_warnings(record=True):
  29. self.panel_ints = Panel(np.random.rand(4, 4, 4),
  30. items=lrange(0, 8, 2),
  31. major_axis=lrange(0, 12, 3),
  32. minor_axis=lrange(0, 16, 4))
  33. self.series_uints = Series(np.random.rand(4),
  34. index=UInt64Index(lrange(0, 8, 2)))
  35. self.frame_uints = DataFrame(np.random.randn(4, 4),
  36. index=UInt64Index(lrange(0, 8, 2)),
  37. columns=UInt64Index(lrange(0, 12, 3)))
  38. with catch_warnings(record=True):
  39. self.panel_uints = Panel(np.random.rand(4, 4, 4),
  40. items=UInt64Index(lrange(0, 8, 2)),
  41. major_axis=UInt64Index(lrange(0, 12, 3)),
  42. minor_axis=UInt64Index(lrange(0, 16, 4)))
  43. self.series_labels = Series(np.random.randn(4), index=list('abcd'))
  44. self.frame_labels = DataFrame(np.random.randn(4, 4),
  45. index=list('abcd'), columns=list('ABCD'))
  46. with catch_warnings(record=True):
  47. self.panel_labels = Panel(np.random.randn(4, 4, 4),
  48. items=list('abcd'),
  49. major_axis=list('ABCD'),
  50. minor_axis=list('ZYXW'))
  51. self.series_mixed = Series(np.random.randn(4), index=[2, 4, 'null', 8])
  52. self.frame_mixed = DataFrame(np.random.randn(4, 4),
  53. index=[2, 4, 'null', 8])
  54. with catch_warnings(record=True):
  55. self.panel_mixed = Panel(np.random.randn(4, 4, 4),
  56. items=[2, 4, 'null', 8])
  57. self.series_ts = Series(np.random.randn(4),
  58. index=date_range('20130101', periods=4))
  59. self.frame_ts = DataFrame(np.random.randn(4, 4),
  60. index=date_range('20130101', periods=4))
  61. with catch_warnings(record=True):
  62. self.panel_ts = Panel(np.random.randn(4, 4, 4),
  63. items=date_range('20130101', periods=4))
  64. dates_rev = (date_range('20130101', periods=4)
  65. .sort_values(ascending=False))
  66. self.series_ts_rev = Series(np.random.randn(4),
  67. index=dates_rev)
  68. self.frame_ts_rev = DataFrame(np.random.randn(4, 4),
  69. index=dates_rev)
  70. with catch_warnings(record=True):
  71. self.panel_ts_rev = Panel(np.random.randn(4, 4, 4),
  72. items=dates_rev)
  73. self.frame_empty = DataFrame({})
  74. self.series_empty = Series({})
  75. with catch_warnings(record=True):
  76. self.panel_empty = Panel({})
  77. # form agglomerates
  78. for o in self._objs:
  79. d = dict()
  80. for t in self._typs:
  81. d[t] = getattr(self, '%s_%s' % (o, t), None)
  82. setattr(self, o, d)
  83. def generate_indices(self, f, values=False):
  84. """ generate the indicies
  85. if values is True , use the axis values
  86. is False, use the range
  87. """
  88. axes = f.axes
  89. if values:
  90. axes = [lrange(len(a)) for a in axes]
  91. return itertools.product(*axes)
  92. def get_result(self, obj, method, key, axis):
  93. """ return the result for this obj with this key and this axis """
  94. if isinstance(key, dict):
  95. key = key[axis]
  96. # use an artifical conversion to map the key as integers to the labels
  97. # so ix can work for comparisions
  98. if method == 'indexer':
  99. method = 'ix'
  100. key = obj._get_axis(axis)[key]
  101. # in case we actually want 0 index slicing
  102. try:
  103. with catch_warnings(record=True):
  104. xp = getattr(obj, method).__getitem__(_axify(obj, key, axis))
  105. except:
  106. xp = getattr(obj, method).__getitem__(key)
  107. return xp
  108. def get_value(self, f, i, values=False):
  109. """ return the value for the location i """
  110. # check agains values
  111. if values:
  112. return f.values[i]
  113. # this is equiv of f[col][row].....
  114. # v = f
  115. # for a in reversed(i):
  116. # v = v.__getitem__(a)
  117. # return v
  118. with catch_warnings(record=True):
  119. return f.ix[i]
  120. def check_values(self, f, func, values=False):
  121. if f is None:
  122. return
  123. axes = f.axes
  124. indicies = itertools.product(*axes)
  125. for i in indicies:
  126. result = getattr(f, func)[i]
  127. # check agains values
  128. if values:
  129. expected = f.values[i]
  130. else:
  131. expected = f
  132. for a in reversed(i):
  133. expected = expected.__getitem__(a)
  134. tm.assert_almost_equal(result, expected)
  135. def check_result(self, name, method1, key1, method2, key2, typs=None,
  136. objs=None, axes=None, fails=None):
  137. def _eq(t, o, a, obj, k1, k2):
  138. """ compare equal for these 2 keys """
  139. if a is not None and a > obj.ndim - 1:
  140. return
  141. def _print(result, error=None):
  142. if error is not None:
  143. error = str(error)
  144. v = ("%-16.16s [%-16.16s]: [typ->%-8.8s,obj->%-8.8s,"
  145. "key1->(%-4.4s),key2->(%-4.4s),axis->%s] %s" %
  146. (name, result, t, o, method1, method2, a, error or ''))
  147. if _verbose:
  148. pprint_thing(v)
  149. try:
  150. rs = getattr(obj, method1).__getitem__(_axify(obj, k1, a))
  151. try:
  152. xp = self.get_result(obj, method2, k2, a)
  153. except:
  154. result = 'no comp'
  155. _print(result)
  156. return
  157. detail = None
  158. try:
  159. if is_scalar(rs) and is_scalar(xp):
  160. assert rs == xp
  161. elif xp.ndim == 1:
  162. tm.assert_series_equal(rs, xp)
  163. elif xp.ndim == 2:
  164. tm.assert_frame_equal(rs, xp)
  165. elif xp.ndim == 3:
  166. tm.assert_panel_equal(rs, xp)
  167. result = 'ok'
  168. except AssertionError as e:
  169. detail = str(e)
  170. result = 'fail'
  171. # reverse the checks
  172. if fails is True:
  173. if result == 'fail':
  174. result = 'ok (fail)'
  175. _print(result)
  176. if not result.startswith('ok'):
  177. raise AssertionError(detail)
  178. except AssertionError:
  179. raise
  180. except Exception as detail:
  181. # if we are in fails, the ok, otherwise raise it
  182. if fails is not None:
  183. if isinstance(detail, fails):
  184. result = 'ok (%s)' % type(detail).__name__
  185. _print(result)
  186. return
  187. result = type(detail).__name__
  188. raise AssertionError(_print(result, error=detail))
  189. if typs is None:
  190. typs = self._typs
  191. if objs is None:
  192. objs = self._objs
  193. if axes is not None:
  194. if not isinstance(axes, (tuple, list)):
  195. axes = [axes]
  196. else:
  197. axes = list(axes)
  198. else:
  199. axes = [0, 1, 2]
  200. # check
  201. for o in objs:
  202. if o not in self._objs:
  203. continue
  204. d = getattr(self, o)
  205. for a in axes:
  206. for t in typs:
  207. if t not in self._typs:
  208. continue
  209. obj = d[t]
  210. if obj is None:
  211. continue
  212. def _call(obj=obj):
  213. obj = obj.copy()
  214. k2 = key2
  215. _eq(t, o, a, obj, key1, k2)
  216. # Panel deprecations
  217. if isinstance(obj, Panel):
  218. with catch_warnings(record=True):
  219. _call()
  220. else:
  221. _call()