PageRenderTime 25ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 0ms

/pandas/stats/tests/test_moments.py

https://github.com/MLnick/pandas
Python | 273 lines | 196 code | 69 blank | 8 comment | 13 complexity | abead3da812eccea4e71ff8ef1ff8440 MD5 | raw file
import unittest
from datetime import datetime

import nose
import numpy as np
from numpy.random import randn

import pandas.core.datetools as datetools
import pandas.stats.moments as moments
import pandas.util.testing as tm
from pandas.core.api import Series, DataFrame, DateRange
from pandas.util.testing import assert_almost_equal
  11. N, K = 100, 10
  12. class TestMoments(unittest.TestCase):
  13. _nan_locs = np.arange(20, 40)
  14. _inf_locs = np.array([])
  15. def setUp(self):
  16. arr = randn(N)
  17. arr[self._nan_locs] = np.NaN
  18. self.arr = arr
  19. self.rng = DateRange(datetime(2009, 1, 1), periods=N)
  20. self.series = Series(arr.copy(), index=self.rng)
  21. self.frame = DataFrame(randn(N, K), index=self.rng,
  22. columns=np.arange(K))
  23. def test_rolling_sum(self):
  24. self._check_moment_func(moments.rolling_sum, np.sum)
  25. def test_rolling_count(self):
  26. counter = lambda x: np.isfinite(x).astype(float).sum()
  27. self._check_moment_func(moments.rolling_count, counter,
  28. has_min_periods=False,
  29. preserve_nan=False)
  30. def test_rolling_mean(self):
  31. self._check_moment_func(moments.rolling_mean, np.mean)
  32. def test_rolling_median(self):
  33. self._check_moment_func(moments.rolling_median, np.median)
  34. def test_rolling_min(self):
  35. self._check_moment_func(moments.rolling_min, np.min)
  36. def test_rolling_max(self):
  37. self._check_moment_func(moments.rolling_max, np.max)
  38. def test_rolling_quantile(self):
  39. qs = [.1, .5, .9]
  40. def scoreatpercentile(a, per):
  41. values = np.sort(a,axis=0)
  42. idx = per /1. * (values.shape[0] - 1)
  43. return values[int(idx)]
  44. for q in qs:
  45. def f(x, window, min_periods=None, time_rule=None):
  46. return moments.rolling_quantile(x, window, q,
  47. min_periods=min_periods,
  48. time_rule=time_rule)
  49. def alt(x):
  50. return scoreatpercentile(x, q)
  51. self._check_moment_func(f, alt)
  52. def test_rolling_apply(self):
  53. def roll_mean(x, window, min_periods=None, time_rule=None):
  54. return moments.rolling_apply(x, window,
  55. lambda x: x[np.isfinite(x)].mean(),
  56. min_periods=min_periods,
  57. time_rule=time_rule)
  58. self._check_moment_func(roll_mean, np.mean)
  59. def test_rolling_std(self):
  60. self._check_moment_func(moments.rolling_std,
  61. lambda x: np.std(x, ddof=1))
  62. def test_rolling_var(self):
  63. self._check_moment_func(moments.rolling_var,
  64. lambda x: np.var(x, ddof=1))
  65. def test_rolling_skew(self):
  66. try:
  67. from scipy.stats import skew
  68. except ImportError:
  69. raise nose.SkipTest('no scipy')
  70. self._check_moment_func(moments.rolling_skew,
  71. lambda x: skew(x, bias=False))
  72. def test_rolling_kurt(self):
  73. try:
  74. from scipy.stats import kurtosis
  75. except ImportError:
  76. raise nose.SkipTest('no scipy')
  77. self._check_moment_func(moments.rolling_kurt,
  78. lambda x: kurtosis(x, bias=False))
  79. def _check_moment_func(self, func, static_comp, window=50,
  80. has_min_periods=True,
  81. has_time_rule=True,
  82. preserve_nan=True):
  83. self._check_ndarray(func, static_comp, window=window,
  84. has_min_periods=has_min_periods,
  85. preserve_nan=preserve_nan)
  86. self._check_structures(func, static_comp,
  87. has_min_periods=has_min_periods,
  88. has_time_rule=has_time_rule)
  89. def _check_ndarray(self, func, static_comp, window=50,
  90. has_min_periods=True,
  91. preserve_nan=True):
  92. result = func(self.arr, window)
  93. assert_almost_equal(result[-1],
  94. static_comp(self.arr[-50:]))
  95. if preserve_nan:
  96. assert(np.isnan(result[self._nan_locs]).all())
  97. # excluding NaNs correctly
  98. arr = randn(50)
  99. arr[:10] = np.NaN
  100. arr[-10:] = np.NaN
  101. if has_min_periods:
  102. result = func(arr, 50, min_periods=30)
  103. assert_almost_equal(result[-1], static_comp(arr[10:-10]))
  104. # min_periods is working correctly
  105. result = func(arr, 20, min_periods=15)
  106. self.assert_(np.isnan(result[23]))
  107. self.assert_(not np.isnan(result[24]))
  108. self.assert_(not np.isnan(result[-6]))
  109. self.assert_(np.isnan(result[-5]))
  110. else:
  111. result = func(arr, 50)
  112. assert_almost_equal(result[-1], static_comp(arr[10:-10]))
  113. def _check_structures(self, func, static_comp,
  114. has_min_periods=True, has_time_rule=True):
  115. series_result = func(self.series, 50)
  116. self.assert_(isinstance(series_result, Series))
  117. frame_result = func(self.frame, 50)
  118. self.assertEquals(type(frame_result), DataFrame)
  119. # check time_rule works
  120. if has_time_rule:
  121. win = 25
  122. minp = 10
  123. if has_min_periods:
  124. series_result = func(self.series[::2], win, min_periods=minp,
  125. time_rule='WEEKDAY')
  126. frame_result = func(self.frame[::2], win, min_periods=minp,
  127. time_rule='WEEKDAY')
  128. else:
  129. series_result = func(self.series[::2], win, time_rule='WEEKDAY')
  130. frame_result = func(self.frame[::2], win, time_rule='WEEKDAY')
  131. last_date = series_result.index[-1]
  132. prev_date = last_date - 24 * datetools.bday
  133. trunc_series = self.series[::2].truncate(prev_date, last_date)
  134. trunc_frame = self.frame[::2].truncate(prev_date, last_date)
  135. assert_almost_equal(series_result[-1], static_comp(trunc_series))
  136. assert_almost_equal(frame_result.xs(last_date),
  137. trunc_frame.apply(static_comp))
  138. def test_ewma(self):
  139. self._check_ew(moments.ewma)
  140. def test_ewmvar(self):
  141. self._check_ew(moments.ewmvar)
  142. def test_ewmvol(self):
  143. self._check_ew(moments.ewmvol)
  144. def test_ewma_span_com_args(self):
  145. A = moments.ewma(self.arr, com=9.5)
  146. B = moments.ewma(self.arr, span=20)
  147. assert_almost_equal(A, B)
  148. self.assertRaises(Exception, moments.ewma, self.arr, com=9.5, span=20)
  149. self.assertRaises(Exception, moments.ewma, self.arr)
  150. def _check_ew(self, func):
  151. self._check_ew_ndarray(func)
  152. self._check_ew_structures(func)
  153. def _check_ew_ndarray(self, func, preserve_nan=False):
  154. result = func(self.arr, com=10)
  155. if preserve_nan:
  156. assert(np.isnan(result[self._nan_locs]).all())
  157. # excluding NaNs correctly
  158. arr = randn(50)
  159. arr[:10] = np.NaN
  160. arr[-10:] = np.NaN
  161. # ??? check something
  162. # pass in ints
  163. result2 = func(np.arange(50), span=10)
  164. self.assert_(result.dtype == np.float_)
  165. def _check_ew_structures(self, func):
  166. series_result = func(self.series, com=10)
  167. self.assert_(isinstance(series_result, Series))
  168. frame_result = func(self.frame, com=10)
  169. self.assertEquals(type(frame_result), DataFrame)
  170. # binary moments
  171. def test_rolling_cov(self):
  172. A = self.series
  173. B = A + randn(len(A))
  174. result = moments.rolling_cov(A, B, 50, min_periods=25)
  175. assert_almost_equal(result[-1], np.cov(A[-50:], B[-50:])[0, 1])
  176. def test_rolling_corr(self):
  177. A = self.series
  178. B = A + randn(len(A))
  179. result = moments.rolling_corr(A, B, 50, min_periods=25)
  180. assert_almost_equal(result[-1], np.corrcoef(A[-50:], B[-50:])[0, 1])
  181. # test for correct bias correction
  182. a = tm.makeTimeSeries()
  183. b = tm.makeTimeSeries()
  184. a[:5] = np.nan
  185. b[:10] = np.nan
  186. result = moments.rolling_corr(a, b, len(a), min_periods=1)
  187. assert_almost_equal(result[-1], a.corr(b))
  188. def test_ewmcov(self):
  189. self._check_binary_ew(moments.ewmcov)
  190. def test_ewmcorr(self):
  191. self._check_binary_ew(moments.ewmcorr)
  192. def _check_binary_ew(self, func):
  193. A = Series(randn(50), index=np.arange(50))
  194. B = A[2:] + randn(48)
  195. A[:10] = np.NaN
  196. B[-10:] = np.NaN
  197. result = func(A, B, 20, min_periods=5)
  198. self.assert_(np.isnan(result[:15]).all())
  199. self.assert_(not np.isnan(result[15:]).any())
  200. self.assertRaises(Exception, func, A, randn(50), 20, min_periods=5)
  201. if __name__ == '__main__':
  202. import nose
  203. nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],
  204. exit=False)