PageRenderTime 44ms CodeModel.GetById 11ms RepoModel.GetById 1ms app.codeStats 0ms

/statsmodels/tsa/tests/test_arima.py

http://github.com/statsmodels/statsmodels
Python | 2746 lines | 2653 code | 79 blank | 14 comment | 4 complexity | 4e2322760d7505b95873cdee2fe1766c MD5 | raw file
Possible License(s): BSD-3-Clause

Large files are truncated, but you can click here to view the full file

  1. from statsmodels.compat.platform import (PLATFORM_OSX, PLATFORM_WIN,
  2. PLATFORM_WIN32)
  3. from statsmodels.compat.python import lrange
  4. import os
  5. import pickle
  6. import warnings
  7. from io import BytesIO
  8. import numpy as np
  9. import pandas as pd
  10. import pytest
  11. from numpy.testing import assert_almost_equal, assert_allclose, assert_raises
  12. from pandas import DatetimeIndex, date_range, period_range
  13. import statsmodels.sandbox.tsa.fftarma as fa
  14. from statsmodels.datasets.macrodata import load_pandas as load_macrodata_pandas
  15. from statsmodels.regression.linear_model import OLS
  16. from statsmodels.tools.sm_exceptions import (
  17. ValueWarning, HessianInversionWarning, SpecificationWarning,
  18. MissingDataError)
  19. from statsmodels.tools.testing import assert_equal
  20. from statsmodels.tsa.ar_model import AutoReg
  21. from statsmodels.tsa.arima_model import ARMA, ARIMA
  22. from statsmodels.tsa.arima_process import arma_generate_sample
  23. from statsmodels.tsa.arma_mle import Arma
  24. from statsmodels.tsa.tests.results import results_arma, results_arima
  25. DECIMAL_4 = 4
  26. DECIMAL_3 = 3
  27. DECIMAL_2 = 2
  28. DECIMAL_1 = 1
  29. current_path = os.path.dirname(os.path.abspath(__file__))
  30. ydata_path = os.path.join(current_path, 'results', 'y_arma_data.csv')
  31. with open(ydata_path, "rb") as fd:
  32. y_arma = np.genfromtxt(fd, delimiter=",", skip_header=1, dtype=float)
  33. cpi_dates = period_range(start='1959q1', end='2009q3', freq='Q')
  34. sun_dates = period_range(start='1700', end='2008', freq='A')
  35. cpi_predict_dates = period_range(start='2009q3', end='2015q4', freq='Q')
  36. sun_predict_dates = period_range(start='2008', end='2033', freq='A')
  37. def test_compare_arma():
  38. # this is a preliminary test to compare arma_kf, arma_cond_ls
  39. # and arma_cond_mle
  40. # the results returned by the fit methods are incomplete
  41. # for now without random.seed
  42. np.random.seed(9876565)
  43. x = fa.ArmaFft([1, -0.5], [1., 0.4], 40).generate_sample(nsample=200,
  44. burnin=1000)
  45. modkf = ARMA(x, (1, 1))
  46. reskf = modkf.fit(trend='nc', disp=-1)
  47. dres = reskf
  48. modc = Arma(x)
  49. resls = modc.fit(order=(1, 1))
  50. rescm = modc.fit_mle(order=(1, 1), start_params=[0.4, 0.4, 1.], disp=0)
  51. # decimal 1 corresponds to threshold of 5% difference
  52. # still different sign corrcted
  53. assert_almost_equal(resls[0] / dres.params, np.ones(dres.params.shape),
  54. decimal=1)
  55. # rescm also contains variance estimate as last element of params
  56. assert_almost_equal(rescm.params[:-1] / dres.params,
  57. np.ones(dres.params.shape), decimal=1)
  58. class CheckArmaResultsMixin(object):
  59. """
  60. res2 are the results from gretl. They are in results/results_arma.
  61. res1 are from statsmodels
  62. """
  63. decimal_params = DECIMAL_4
  64. def test_params(self):
  65. assert_almost_equal(self.res1.params, self.res2.params,
  66. self.decimal_params)
  67. decimal_aic = DECIMAL_4
  68. def test_aic(self):
  69. assert_almost_equal(self.res1.aic, self.res2.aic, self.decimal_aic)
  70. decimal_bic = DECIMAL_4
  71. def test_bic(self):
  72. assert_almost_equal(self.res1.bic, self.res2.bic, self.decimal_bic)
  73. decimal_arroots = DECIMAL_4
  74. def test_arroots(self):
  75. assert_almost_equal(self.res1.arroots, self.res2.arroots,
  76. self.decimal_arroots)
  77. decimal_maroots = DECIMAL_4
  78. def test_maroots(self):
  79. assert_almost_equal(self.res1.maroots, self.res2.maroots,
  80. self.decimal_maroots)
  81. decimal_bse = DECIMAL_2
  82. def test_bse(self):
  83. assert_almost_equal(self.res1.bse, self.res2.bse, self.decimal_bse)
  84. decimal_cov_params = DECIMAL_4
  85. def test_covparams(self):
  86. assert_almost_equal(self.res1.cov_params(), self.res2.cov_params,
  87. self.decimal_cov_params)
  88. decimal_hqic = DECIMAL_4
  89. def test_hqic(self):
  90. assert_almost_equal(self.res1.hqic, self.res2.hqic, self.decimal_hqic)
  91. decimal_llf = DECIMAL_4
  92. def test_llf(self):
  93. assert_almost_equal(self.res1.llf, self.res2.llf, self.decimal_llf)
  94. decimal_resid = DECIMAL_4
  95. def test_resid(self):
  96. assert_almost_equal(self.res1.resid, self.res2.resid,
  97. self.decimal_resid)
  98. decimal_fittedvalues = DECIMAL_4
  99. def test_fittedvalues(self):
  100. assert_almost_equal(self.res1.fittedvalues, self.res2.fittedvalues,
  101. self.decimal_fittedvalues)
  102. decimal_pvalues = DECIMAL_2
  103. def test_pvalues(self):
  104. assert_almost_equal(self.res1.pvalues, self.res2.pvalues,
  105. self.decimal_pvalues)
  106. decimal_t = DECIMAL_2 # only 2 decimal places in gretl output
  107. def test_tvalues(self):
  108. assert_almost_equal(self.res1.tvalues, self.res2.tvalues,
  109. self.decimal_t)
  110. decimal_sigma2 = DECIMAL_4
  111. def test_sigma2(self):
  112. assert_almost_equal(self.res1.sigma2, self.res2.sigma2,
  113. self.decimal_sigma2)
  114. @pytest.mark.smoke
  115. def test_summary(self):
  116. self.res1.summary()
  117. @pytest.mark.smoke
  118. def test_summary2(self):
  119. self.res1.summary2()
  120. class CheckForecastMixin(object):
  121. decimal_forecast = DECIMAL_4
  122. def test_forecast(self):
  123. assert_almost_equal(self.res1.forecast_res, self.res2.forecast,
  124. self.decimal_forecast)
  125. decimal_forecasterr = DECIMAL_4
  126. def test_forecasterr(self):
  127. assert_almost_equal(self.res1.forecast_err, self.res2.forecasterr,
  128. self.decimal_forecasterr)
  129. class CheckDynamicForecastMixin(object):
  130. decimal_forecast_dyn = 4
  131. def test_dynamic_forecast(self):
  132. assert_almost_equal(self.res1.forecast_res_dyn, self.res2.forecast_dyn,
  133. self.decimal_forecast_dyn)
  134. def test_forecasterr(self):
  135. assert_almost_equal(self.res1.forecast_err_dyn,
  136. self.res2.forecasterr_dyn,
  137. DECIMAL_4)
  138. class CheckArimaResultsMixin(CheckArmaResultsMixin):
  139. def test_order(self):
  140. assert self.res1.k_diff == self.res2.k_diff
  141. assert self.res1.k_ar == self.res2.k_ar
  142. assert self.res1.k_ma == self.res2.k_ma
  143. decimal_predict_levels = DECIMAL_4
  144. def test_predict_levels(self):
  145. assert_almost_equal(self.res1.predict(typ='levels'), self.res2.linear,
  146. self.decimal_predict_levels)
  147. class Test_Y_ARMA11_NoConst(CheckArmaResultsMixin, CheckForecastMixin):
  148. @classmethod
  149. def setup_class(cls):
  150. endog = y_arma[:, 0]
  151. cls.res1 = ARMA(endog, order=(1, 1)).fit(trend='nc', disp=-1)
  152. (cls.res1.forecast_res, cls.res1.forecast_err,
  153. confint) = cls.res1.forecast(10)
  154. cls.res2 = results_arma.Y_arma11()
  155. def test_pickle(self):
  156. fh = BytesIO()
  157. # test wrapped results load save pickle
  158. self.res1.save(fh)
  159. fh.seek(0, 0)
  160. res_unpickled = self.res1.__class__.load(fh)
  161. assert type(res_unpickled) is type(self.res1) # noqa: E721
  162. class Test_Y_ARMA14_NoConst(CheckArmaResultsMixin):
  163. @classmethod
  164. def setup_class(cls):
  165. endog = y_arma[:, 1]
  166. cls.res1 = ARMA(endog, order=(1, 4)).fit(trend='nc', disp=-1)
  167. cls.res2 = results_arma.Y_arma14()
  168. @pytest.mark.slow
  169. class Test_Y_ARMA41_NoConst(CheckArmaResultsMixin, CheckForecastMixin):
  170. @classmethod
  171. def setup_class(cls):
  172. endog = y_arma[:, 2]
  173. cls.res1 = ARMA(endog, order=(4, 1)).fit(trend='nc', disp=-1)
  174. (cls.res1.forecast_res, cls.res1.forecast_err,
  175. confint) = cls.res1.forecast(10)
  176. cls.res2 = results_arma.Y_arma41()
  177. cls.decimal_maroots = DECIMAL_3
  178. class Test_Y_ARMA22_NoConst(CheckArmaResultsMixin):
  179. @classmethod
  180. def setup_class(cls):
  181. endog = y_arma[:, 3]
  182. cls.res1 = ARMA(endog, order=(2, 2)).fit(trend='nc', disp=-1)
  183. cls.res2 = results_arma.Y_arma22()
  184. class Test_Y_ARMA50_NoConst(CheckArmaResultsMixin, CheckForecastMixin):
  185. @classmethod
  186. def setup_class(cls):
  187. endog = y_arma[:, 4]
  188. cls.res1 = ARMA(endog, order=(5, 0)).fit(trend='nc', disp=-1)
  189. (cls.res1.forecast_res, cls.res1.forecast_err,
  190. confint) = cls.res1.forecast(10)
  191. cls.res2 = results_arma.Y_arma50()
  192. class Test_Y_ARMA02_NoConst(CheckArmaResultsMixin):
  193. @classmethod
  194. def setup_class(cls):
  195. endog = y_arma[:, 5]
  196. cls.res1 = ARMA(endog, order=(0, 2)).fit(trend='nc', disp=-1)
  197. cls.res2 = results_arma.Y_arma02()
  198. class Test_Y_ARMA11_Const(CheckArmaResultsMixin, CheckForecastMixin):
  199. @classmethod
  200. def setup_class(cls):
  201. endog = y_arma[:, 6]
  202. cls.res1 = ARMA(endog, order=(1, 1)).fit(trend="c", disp=-1)
  203. (cls.res1.forecast_res, cls.res1.forecast_err,
  204. confint) = cls.res1.forecast(10)
  205. cls.res2 = results_arma.Y_arma11c()
  206. class Test_Y_ARMA14_Const(CheckArmaResultsMixin):
  207. @classmethod
  208. def setup_class(cls):
  209. endog = y_arma[:, 7]
  210. cls.res1 = ARMA(endog, order=(1, 4)).fit(trend="c", disp=-1)
  211. cls.res2 = results_arma.Y_arma14c()
  212. class Test_Y_ARMA41_Const(CheckArmaResultsMixin, CheckForecastMixin):
  213. @classmethod
  214. def setup_class(cls):
  215. endog = y_arma[:, 8]
  216. cls.res2 = results_arma.Y_arma41c()
  217. cls.res1 = ARMA(endog, order=(4, 1)).fit(trend="c", disp=-1,
  218. start_params=cls.res2.params)
  219. (cls.res1.forecast_res, cls.res1.forecast_err,
  220. confint) = cls.res1.forecast(10)
  221. cls.decimal_cov_params = DECIMAL_3
  222. cls.decimal_fittedvalues = DECIMAL_3
  223. cls.decimal_resid = DECIMAL_3
  224. cls.decimal_params = DECIMAL_3
  225. class Test_Y_ARMA22_Const(CheckArmaResultsMixin):
  226. @classmethod
  227. def setup_class(cls):
  228. endog = y_arma[:, 9]
  229. cls.res1 = ARMA(endog, order=(2, 2)).fit(trend="c", disp=-1)
  230. cls.res2 = results_arma.Y_arma22c()
  231. def test_summary(self):
  232. # regression test for html of roots table #4434
  233. # we ignore whitespace in the assert
  234. summ = self.res1.summary()
  235. summ_roots = """\
  236. <tableclass="simpletable">
  237. <caption>Roots</caption>
  238. <tr>
  239. <td></td><th>Real</th><th>Imaginary</th><th>Modulus</th><th>Frequency</th>
  240. </tr>
  241. <tr>
  242. <th>AR.1</th><td>1.0991</td><td>-1.2571j</td><td>1.6698</td><td>-0.1357</td>
  243. </tr>
  244. <tr>
  245. <th>AR.2</th><td>1.0991</td><td>+1.2571j</td><td>1.6698</td><td>0.1357</td>
  246. </tr>
  247. <tr>
  248. <th>MA.1</th><td>-1.1702</td><td>+0.0000j</td><td>1.1702</td><td>0.5000</td>
  249. </tr>
  250. <tr>
  251. <th>MA.2</th><td>1.2215</td><td>+0.0000j</td><td>1.2215</td><td>0.0000</td>
  252. </tr>
  253. </table>"""
  254. assert_equal(summ.tables[2]._repr_html_().replace(' ', ''),
  255. summ_roots.replace(' ', ''))
  256. class Test_Y_ARMA50_Const(CheckArmaResultsMixin, CheckForecastMixin):
  257. @classmethod
  258. def setup_class(cls):
  259. endog = y_arma[:, 10]
  260. cls.res1 = ARMA(endog, order=(5, 0)).fit(trend="c", disp=-1)
  261. (cls.res1.forecast_res, cls.res1.forecast_err,
  262. confint) = cls.res1.forecast(10)
  263. cls.res2 = results_arma.Y_arma50c()
  264. class Test_Y_ARMA02_Const(CheckArmaResultsMixin):
  265. @classmethod
  266. def setup_class(cls):
  267. endog = y_arma[:, 11]
  268. cls.res1 = ARMA(endog, order=(0, 2)).fit(trend="c", disp=-1)
  269. cls.res2 = results_arma.Y_arma02c()
  270. # cov_params and tvalues are off still but not as much vs. R
  271. class Test_Y_ARMA11_NoConst_CSS(CheckArmaResultsMixin):
  272. @classmethod
  273. def setup_class(cls):
  274. endog = y_arma[:, 0]
  275. cls.res1 = ARMA(endog, order=(1, 1)).fit(method="css", trend='nc',
  276. disp=-1)
  277. cls.res2 = results_arma.Y_arma11("css")
  278. cls.decimal_t = DECIMAL_1
  279. # better vs. R
  280. class Test_Y_ARMA14_NoConst_CSS(CheckArmaResultsMixin):
  281. @classmethod
  282. def setup_class(cls):
  283. endog = y_arma[:, 1]
  284. cls.res1 = ARMA(endog, order=(1, 4)).fit(method="css", trend='nc',
  285. disp=-1)
  286. cls.res2 = results_arma.Y_arma14("css")
  287. cls.decimal_fittedvalues = DECIMAL_3
  288. cls.decimal_resid = DECIMAL_3
  289. cls.decimal_t = DECIMAL_1
  290. # bse, etc. better vs. R
  291. # maroot is off because maparams is off a bit (adjust tolerance?)
  292. class Test_Y_ARMA41_NoConst_CSS(CheckArmaResultsMixin):
  293. @classmethod
  294. def setup_class(cls):
  295. endog = y_arma[:, 2]
  296. cls.res1 = ARMA(endog, order=(4, 1)).fit(method="css", trend='nc',
  297. disp=-1)
  298. cls.res2 = results_arma.Y_arma41("css")
  299. cls.decimal_t = DECIMAL_1
  300. cls.decimal_pvalues = 0
  301. cls.decimal_cov_params = DECIMAL_3
  302. cls.decimal_maroots = DECIMAL_1
  303. # same notes as above
  304. class Test_Y_ARMA22_NoConst_CSS(CheckArmaResultsMixin):
  305. @classmethod
  306. def setup_class(cls):
  307. endog = y_arma[:, 3]
  308. cls.res1 = ARMA(endog, order=(2, 2)).fit(method="css", trend='nc',
  309. disp=-1)
  310. cls.res2 = results_arma.Y_arma22("css")
  311. cls.decimal_t = DECIMAL_1
  312. cls.decimal_resid = DECIMAL_3
  313. cls.decimal_pvalues = DECIMAL_1
  314. cls.decimal_fittedvalues = DECIMAL_3
  315. # NOTE: gretl just uses least squares for AR CSS
  316. # so BIC, etc. is
  317. # -2*res1.llf + np.log(nobs)*(res1.q+res1.p+res1.k)
  318. # with no adjustment for p and no extra sigma estimate
  319. # NOTE: so our tests use x-12 arima results which agree with us and are
  320. # consistent with the rest of the models
  321. class Test_Y_ARMA50_NoConst_CSS(CheckArmaResultsMixin):
  322. @classmethod
  323. def setup_class(cls):
  324. endog = y_arma[:, 4]
  325. cls.res1 = ARMA(endog, order=(5, 0)).fit(method="css", trend='nc',
  326. disp=-1)
  327. cls.res2 = results_arma.Y_arma50("css")
  328. cls.decimal_t = 0
  329. cls.decimal_llf = DECIMAL_1 # looks like rounding error?
  330. class Test_Y_ARMA02_NoConst_CSS(CheckArmaResultsMixin):
  331. @classmethod
  332. def setup_class(cls):
  333. endog = y_arma[:, 5]
  334. cls.res1 = ARMA(endog, order=(0, 2)).fit(method="css", trend='nc',
  335. disp=-1)
  336. cls.res2 = results_arma.Y_arma02("css")
  337. # NOTE: our results are close to --x-12-arima option and R
  338. class Test_Y_ARMA11_Const_CSS(CheckArmaResultsMixin):
  339. @classmethod
  340. def setup_class(cls):
  341. endog = y_arma[:, 6]
  342. cls.res1 = ARMA(endog, order=(1, 1)).fit(trend="c", method="css",
  343. disp=-1)
  344. cls.res2 = results_arma.Y_arma11c("css")
  345. cls.decimal_params = DECIMAL_3
  346. cls.decimal_cov_params = DECIMAL_3
  347. cls.decimal_t = DECIMAL_1
  348. class Test_Y_ARMA14_Const_CSS(CheckArmaResultsMixin):
  349. @classmethod
  350. def setup_class(cls):
  351. endog = y_arma[:, 7]
  352. cls.res1 = ARMA(endog, order=(1, 4)).fit(trend="c", method="css",
  353. disp=-1)
  354. cls.res2 = results_arma.Y_arma14c("css")
  355. cls.decimal_t = DECIMAL_1
  356. cls.decimal_pvalues = DECIMAL_1
  357. class Test_Y_ARMA41_Const_CSS(CheckArmaResultsMixin):
  358. @classmethod
  359. def setup_class(cls):
  360. endog = y_arma[:, 8]
  361. cls.res1 = ARMA(endog, order=(4, 1)).fit(trend="c", method="css",
  362. disp=-1)
  363. cls.res2 = results_arma.Y_arma41c("css")
  364. cls.decimal_t = DECIMAL_1
  365. cls.decimal_cov_params = DECIMAL_1
  366. cls.decimal_maroots = DECIMAL_3
  367. cls.decimal_bse = DECIMAL_1
  368. class Test_Y_ARMA22_Const_CSS(CheckArmaResultsMixin):
  369. @classmethod
  370. def setup_class(cls):
  371. endog = y_arma[:, 9]
  372. cls.res1 = ARMA(endog, order=(2, 2)).fit(trend="c", method="css",
  373. disp=-1)
  374. cls.res2 = results_arma.Y_arma22c("css")
  375. cls.decimal_t = 0
  376. cls.decimal_pvalues = DECIMAL_1
  377. class Test_Y_ARMA50_Const_CSS(CheckArmaResultsMixin):
  378. @classmethod
  379. def setup_class(cls):
  380. endog = y_arma[:, 10]
  381. cls.res1 = ARMA(endog, order=(5, 0)).fit(trend="c", method="css",
  382. disp=-1)
  383. cls.res2 = results_arma.Y_arma50c("css")
  384. cls.decimal_t = DECIMAL_1
  385. cls.decimal_params = DECIMAL_3
  386. cls.decimal_cov_params = DECIMAL_2
  387. class Test_Y_ARMA02_Const_CSS(CheckArmaResultsMixin):
  388. @classmethod
  389. def setup_class(cls):
  390. endog = y_arma[:, 11]
  391. cls.res1 = ARMA(endog, order=(0, 2)).fit(trend="c", method="css",
  392. disp=-1)
  393. cls.res2 = results_arma.Y_arma02c("css")
  394. def test_reset_trend_error():
  395. endog = y_arma[:, 0]
  396. mod = ARMA(endog, order=(1, 1))
  397. mod.fit(trend="c", disp=-1)
  398. with pytest.raises(RuntimeError):
  399. mod.fit(trend="nc", disp=-1)
@pytest.mark.slow
def test_start_params_bug():
    """Smoke test: fit an ARMA(4, 1) with ``start_ar_lags=5`` on a fixed series.

    No assertions are made; the test passes when the fit completes without
    raising. All warnings are suppressed while fitting.
    """
    # Hard-coded input series for the fit below.
    data = np.array([1368., 1187, 1090, 1439, 2362, 2783, 2869, 2512, 1804,
                     1544, 1028, 869, 1737, 2055, 1947, 1618, 1196, 867, 997,
                     1862, 2525,
                     3250, 4023, 4018, 3585, 3004, 2500, 2441, 2749, 2466,
                     2157, 1847, 1463,
                     1146, 851, 993, 1448, 1719, 1709, 1455, 1950, 1763, 2075,
                     2343, 3570,
                     4690, 3700, 2339, 1679, 1466, 998, 853, 835, 922, 851,
                     1125, 1299, 1105,
                     860, 701, 689, 774, 582, 419, 846, 1132, 902, 1058, 1341,
                     1551, 1167,
                     975, 786, 759, 751, 649, 876, 720, 498, 553, 459, 543,
                     447, 415, 377,
                     373, 324, 320, 306, 259, 220, 342, 558, 825, 994, 1267,
                     1473, 1601,
                     1896, 1890, 2012, 2198, 2393, 2825, 3411, 3406, 2464,
                     2891, 3685, 3638,
                     3746, 3373, 3190, 2681, 2846, 4129, 5054, 5002, 4801,
                     4934, 4903, 4713,
                     4745, 4736, 4622, 4642, 4478, 4510, 4758, 4457, 4356,
                     4170, 4658, 4546,
                     4402, 4183, 3574, 2586, 3326, 3948, 3983, 3997, 4422,
                     4496, 4276, 3467,
                     2753, 2582, 2921, 2768, 2789, 2824, 2482, 2773, 3005,
                     3641, 3699, 3774,
                     3698, 3628, 3180, 3306, 2841, 2014, 1910, 2560, 2980,
                     3012, 3210, 3457,
                     3158, 3344, 3609, 3327, 2913, 2264, 2326, 2596, 2225,
                     1767, 1190, 792,
                     669, 589, 496, 354, 246, 250, 323, 495, 924, 1536, 2081,
                     2660, 2814, 2992,
                     3115, 2962, 2272, 2151, 1889, 1481, 955, 631, 288, 103,
                     60, 82, 107, 185,
                     618, 1526, 2046, 2348, 2584, 2600, 2515, 2345, 2351, 2355,
                     2409, 2449,
                     2645, 2918, 3187, 2888, 2610, 2740, 2526, 2383, 2936,
                     2968, 2635, 2617,
                     2790, 3906, 4018, 4797, 4919, 4942, 4656, 4444, 3898,
                     3908, 3678, 3605,
                     3186, 2139, 2002, 1559, 1235, 1183, 1096, 673, 389, 223,
                     352, 308, 365,
                     525, 779, 894, 901, 1025, 1047, 981, 902, 759, 569, 519,
                     408, 263, 156,
                     72, 49, 31, 41, 192, 423, 492, 552, 564, 723, 921, 1525,
                     2768, 3531, 3824,
                     3835, 4294, 4533, 4173, 4221, 4064, 4641, 4685, 4026,
                     4323, 4585, 4836,
                     4822, 4631, 4614, 4326, 4790, 4736, 4104, 5099, 5154,
                     5121, 5384, 5274,
                     5225, 4899, 5382, 5295, 5349, 4977, 4597, 4069, 3733,
                     3439, 3052, 2626,
                     1939, 1064, 713, 916, 832, 658, 817, 921, 772, 764, 824,
                     967, 1127, 1153,
                     824, 912, 957, 990, 1218, 1684, 2030, 2119, 2233, 2657,
                     2652, 2682, 2498,
                     2429, 2346, 2298, 2129, 1829, 1816, 1225, 1010, 748, 627,
                     469, 576, 532,
                     475, 582, 641, 605, 699, 680, 714, 670, 666, 636, 672,
                     679, 446, 248, 134,
                     160, 178, 286, 413, 676, 1025, 1159, 952, 1398, 1833,
                     2045, 2072, 1798,
                     1799, 1358, 727, 353, 347, 844, 1377, 1829, 2118, 2272,
                     2745, 4263, 4314,
                     4530, 4354, 4645, 4547, 5391, 4855, 4739, 4520, 4573,
                     4305, 4196, 3773,
                     3368, 2596, 2596, 2305, 2756, 3747, 4078, 3415, 2369,
                     2210, 2316, 2263,
                     2672, 3571, 4131, 4167, 4077, 3924, 3738, 3712, 3510,
                     3182, 3179, 2951,
                     2453, 2078, 1999, 2486, 2581, 1891, 1997, 1366, 1294,
                     1536, 2794, 3211,
                     3242, 3406, 3121, 2425, 2016, 1787, 1508, 1304, 1060,
                     1342, 1589, 2361,
                     3452, 2659, 2857, 3255, 3322, 2852, 2964, 3132, 3033,
                     2931, 2636, 2818,
                     3310, 3396, 3179, 3232, 3543, 3759, 3503, 3758, 3658,
                     3425, 3053, 2620,
                     1837, 923, 712, 1054, 1376, 1556, 1498, 1523, 1088, 728,
                     890, 1413, 2524,
                     3295, 4097, 3993, 4116, 3874, 4074, 4142, 3975, 3908,
                     3907, 3918, 3755,
                     3648, 3778, 4293, 4385, 4360, 4352, 4528, 4365, 3846,
                     4098, 3860, 3230,
                     2820, 2916, 3201, 3721, 3397, 3055, 2141, 1623, 1825,
                     1716, 2232, 2939,
                     3735, 4838, 4560, 4307, 4975, 5173, 4859, 5268, 4992,
                     5100, 5070, 5270,
                     4760, 5135, 5059, 4682, 4492, 4933, 4737, 4611, 4634,
                     4789, 4811, 4379,
                     4689, 4284, 4191, 3313, 2770, 2543, 3105, 2967, 2420,
                     1996, 2247, 2564,
                     2726, 3021, 3427, 3509, 3759, 3324, 2988, 2849, 2340,
                     2443, 2364, 1252,
                     623, 742, 867, 684, 488, 348, 241, 187, 279, 355, 423,
                     678, 1375, 1497,
                     1434, 2116, 2411, 1929, 1628, 1635, 1609, 1757, 2090,
                     2085, 1790, 1846,
                     2038, 2360, 2342, 2401, 2920, 3030, 3132, 4385, 5483,
                     5865, 5595, 5485,
                     5727, 5553, 5560, 5233, 5478, 5159, 5155, 5312, 5079,
                     4510, 4628, 4535,
                     3656, 3698, 3443, 3146, 2562, 2304, 2181, 2293, 1950,
                     1930, 2197, 2796,
                     3441, 3649, 3815, 2850, 4005, 5305, 5550, 5641, 4717,
                     5131, 2831, 3518,
                     3354, 3115, 3515, 3552, 3244, 3658, 4407, 4935, 4299,
                     3166, 3335, 2728,
                     2488, 2573, 2002, 1717, 1645, 1977, 2049, 2125, 2376,
                     2551, 2578, 2629,
                     2750, 3150, 3699, 4062, 3959, 3264, 2671, 2205, 2128,
                     2133, 2095, 1964,
                     2006, 2074, 2201, 2506, 2449, 2465, 2064, 1446, 1382, 983,
                     898, 489, 319,
                     383, 332, 276, 224, 144, 101, 232, 429, 597, 750, 908,
                     960, 1076, 951,
                     1062, 1183, 1404, 1391, 1419, 1497, 1267, 963, 682, 777,
                     906, 1149, 1439,
                     1600, 1876, 1885, 1962, 2280, 2711, 2591, 2411])
    # Every warning category is ignored for the duration of the fit; the
    # previous filters are restored when the context manager exits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        ARMA(data, order=(4, 1)).fit(start_ar_lags=5, disp=-1)
  523. class Test_ARIMA101(CheckArmaResultsMixin):
  524. @classmethod
  525. def setup_class(cls):
  526. endog = y_arma[:, 6]
  527. cls.res1 = ARIMA(endog, (1, 0, 1)).fit(trend="c", disp=-1)
  528. (cls.res1.forecast_res, cls.res1.forecast_err,
  529. confint) = cls.res1.forecast(10)
  530. cls.res2 = results_arma.Y_arma11c()
  531. cls.res2.k_diff = 0
  532. cls.res2.k_ar = 1
  533. cls.res2.k_ma = 1
  534. class Test_ARIMA111(CheckArimaResultsMixin, CheckForecastMixin,
  535. CheckDynamicForecastMixin):
  536. @classmethod
  537. def setup_class(cls):
  538. cpi = load_macrodata_pandas().data['cpi'].values
  539. cls.res1 = ARIMA(cpi, (1, 1, 1)).fit(disp=-1)
  540. cls.res2 = results_arima.ARIMA111()
  541. # make sure endog names changes to D.cpi
  542. cls.decimal_llf = 3
  543. cls.decimal_aic = 3
  544. cls.decimal_bic = 3
  545. # TODO: why has dec_cov_params changed, used to be better
  546. cls.decimal_cov_params = 2
  547. cls.decimal_t = 0
  548. (cls.res1.forecast_res,
  549. cls.res1.forecast_err,
  550. conf_int) = cls.res1.forecast(25)
  551. # TODO: fix the indexing for the end here, I do not think this is right
  552. # if we're going to treat it like indexing
  553. # the forecast from 2005Q1 through 2009Q4 is indices
  554. # 184 through 227 not 226
  555. # note that the first one counts in the count so 164 + 64 is 65
  556. # predictions
  557. cls.res1.forecast_res_dyn = cls.res1.predict(start=164, end=164 + 63,
  558. typ='levels',
  559. dynamic=True)
  560. def test_freq(self):
  561. assert_almost_equal(self.res1.arfreq, [0.0000], 4)
  562. assert_almost_equal(self.res1.mafreq, [0.0000], 4)
  563. class Test_ARIMA111CSS(CheckArimaResultsMixin, CheckForecastMixin,
  564. CheckDynamicForecastMixin):
  565. @classmethod
  566. def setup_class(cls):
  567. cpi = load_macrodata_pandas().data['cpi'].values
  568. cls.res1 = ARIMA(cpi, (1, 1, 1)).fit(disp=-1, method='css')
  569. cls.res2 = results_arima.ARIMA111(method='css')
  570. cls.res2.fittedvalues = - cpi[1:-1] + cls.res2.linear
  571. # make sure endog names changes to D.cpi
  572. (cls.res1.forecast_res,
  573. cls.res1.forecast_err,
  574. conf_int) = cls.res1.forecast(25)
  575. cls.decimal_forecast = 2
  576. cls.decimal_forecast_dyn = 2
  577. cls.decimal_forecasterr = 3
  578. cls.res1.forecast_res_dyn = cls.res1.predict(start=164, end=164 + 63,
  579. typ='levels',
  580. dynamic=True)
  581. # precisions
  582. cls.decimal_arroots = 3
  583. cls.decimal_cov_params = 3
  584. cls.decimal_hqic = 3
  585. cls.decimal_maroots = 3
  586. cls.decimal_t = 1
  587. cls.decimal_fittedvalues = 2 # because of rounding when copying
  588. cls.decimal_resid = 2
  589. cls.decimal_predict_levels = DECIMAL_2
  590. class Test_ARIMA112CSS(CheckArimaResultsMixin):
  591. @classmethod
  592. def setup_class(cls):
  593. cpi = load_macrodata_pandas().data['cpi'].values
  594. cls.res1 = ARIMA(cpi, (1, 1, 2)).fit(disp=-1, method='css',
  595. start_params=[.905322, -.692425,
  596. 1.07366,
  597. 0.172024])
  598. cls.res2 = results_arima.ARIMA112(method='css')
  599. cls.res2.fittedvalues = - cpi[1:-1] + cls.res2.linear
  600. # make sure endog names changes to D.cpi
  601. cls.decimal_llf = 3
  602. cls.decimal_aic = 3
  603. cls.decimal_bic = 3
  604. # TODO: fix the indexing for the end here, I do not think this is right
  605. # if we're going to treat it like indexing
  606. # the forecast from 2005Q1 through 2009Q4 is indices
  607. # 184 through 227 not 226
  608. # note that the first one counts in the count so 164 + 64 is 65
  609. # predictions
  610. # cls.res1.forecast_res_dyn = self.predict(start=164, end=164+63,
  611. # typ='levels', dynamic=True)
  612. # since we got from gretl do not have linear prediction in differences
  613. cls.decimal_arroots = 3
  614. cls.decimal_maroots = 2
  615. cls.decimal_t = 1
  616. cls.decimal_resid = 2
  617. cls.decimal_fittedvalues = 3
  618. cls.decimal_predict_levels = DECIMAL_3
  619. def test_freq(self):
  620. assert_almost_equal(self.res1.arfreq, [0.5000], 4)
  621. assert_almost_equal(self.res1.mafreq, [0.5000, 0.5000], 4)
  622. def test_arima_predict_mle_dates():
  623. cpi = load_macrodata_pandas().data['cpi'].values
  624. res1 = ARIMA(cpi, (4, 1, 1), dates=cpi_dates, freq='Q').fit(disp=-1)
  625. file_path = os.path.join(current_path, 'results',
  626. 'results_arima_forecasts_all_mle.csv')
  627. with open(file_path, "rb") as test_data:
  628. arima_forecasts = np.genfromtxt(test_data, delimiter=",",
  629. skip_header=1, dtype=float)
  630. fc = arima_forecasts[:, 0]
  631. fcdyn = arima_forecasts[:, 1]
  632. fcdyn2 = arima_forecasts[:, 2]
  633. start, end = 2, 51
  634. fv = res1.predict('1959Q3', '1971Q4', typ='levels')
  635. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  636. assert_equal(res1.data.predict_dates, cpi_dates[start:end + 1])
  637. start, end = 202, 227
  638. fv = res1.predict('2009Q3', '2015Q4', typ='levels')
  639. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  640. assert_equal(res1.data.predict_dates, cpi_predict_dates)
  641. # make sure dynamic works
  642. start, end = '1960q2', '1971q4'
  643. fv = res1.predict(start, end, dynamic=True, typ='levels')
  644. assert_almost_equal(fv, fcdyn[5:51 + 1], DECIMAL_4)
  645. start, end = '1965q1', '2015q4'
  646. fv = res1.predict(start, end, dynamic=True, typ='levels')
  647. assert_almost_equal(fv, fcdyn2[24:227 + 1], DECIMAL_4)
  648. def test_arma_predict_mle_dates():
  649. from statsmodels.datasets.sunspots import load_pandas
  650. sunspots = load_pandas().data['SUNACTIVITY'].values
  651. mod = ARMA(sunspots, (9, 0), dates=sun_dates, freq='A')
  652. mod.method = 'mle'
  653. assert_raises(ValueError, mod._get_prediction_index, '1701', '1751', True)
  654. start, end = 2, 51
  655. mod._get_prediction_index('1702', '1751', False)
  656. assert_equal(mod.data.predict_dates, sun_dates[start:end + 1])
  657. start, end = 308, 333
  658. mod._get_prediction_index('2008', '2033', False)
  659. assert_equal(mod.data.predict_dates, sun_predict_dates)
  660. def test_arima_predict_css_dates():
  661. cpi = load_macrodata_pandas().data['cpi'].values
  662. res1 = ARIMA(cpi, (4, 1, 1), dates=cpi_dates, freq='Q').fit(disp=-1,
  663. method='css',
  664. trend='nc')
  665. params = np.array([1.231272508473910,
  666. -0.282516097759915,
  667. 0.170052755782440,
  668. -0.118203728504945,
  669. -0.938783134717947])
  670. file_path = os.path.join(current_path, 'results',
  671. 'results_arima_forecasts_all_css.csv')
  672. with open(file_path, "rb") as test_data:
  673. arima_forecasts = np.genfromtxt(test_data, delimiter=",",
  674. skip_header=1, dtype=float)
  675. fc = arima_forecasts[:, 0]
  676. fcdyn = arima_forecasts[:, 1]
  677. fcdyn2 = arima_forecasts[:, 2]
  678. start, end = 5, 51
  679. fv = res1.model.predict(params, '1960Q2', '1971Q4', typ='levels')
  680. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  681. assert_equal(res1.data.predict_dates, cpi_dates[start:end + 1])
  682. start, end = 202, 227
  683. fv = res1.model.predict(params, '2009Q3', '2015Q4', typ='levels')
  684. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  685. assert_equal(res1.data.predict_dates, cpi_predict_dates)
  686. # make sure dynamic works
  687. start, end = 5, 51
  688. fv = res1.model.predict(params, '1960Q2', '1971Q4', typ='levels',
  689. dynamic=True)
  690. assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
  691. start, end = '1965q1', '2015q4'
  692. fv = res1.model.predict(params, start, end, dynamic=True, typ='levels')
  693. assert_almost_equal(fv, fcdyn2[24:227 + 1], DECIMAL_4)
  694. def test_arma_predict_css_dates():
  695. from statsmodels.datasets.sunspots import load_pandas
  696. sunspots = load_pandas().data['SUNACTIVITY'].values
  697. mod = ARMA(sunspots, (9, 0), dates=sun_dates, freq='A')
  698. mod.method = 'css'
  699. assert_raises(ValueError, mod._get_prediction_index, '1701', '1751', False)
  700. def test_arima_predict_mle():
  701. cpi = load_macrodata_pandas().data['cpi'].values
  702. res1 = ARIMA(cpi, (4, 1, 1)).fit(disp=-1)
  703. # fit the model so that we get correct endog length but use
  704. file_path = os.path.join(current_path, 'results',
  705. 'results_arima_forecasts_all_mle.csv')
  706. with open(file_path, "rb") as test_data:
  707. arima_forecasts = np.genfromtxt(test_data, delimiter=",",
  708. skip_header=1, dtype=float)
  709. fc = arima_forecasts[:, 0]
  710. fcdyn = arima_forecasts[:, 1]
  711. fcdyn2 = arima_forecasts[:, 2]
  712. fcdyn3 = arima_forecasts[:, 3]
  713. fcdyn4 = arima_forecasts[:, 4]
  714. # 0 indicates the first sample-observation below
  715. # ie., the index after the pre-sample, these are also differenced once
  716. # so the indices are moved back once from the cpi in levels
  717. # start < p, end <p 1959q2 - 1959q4
  718. start, end = 1, 3
  719. fv = res1.predict(start, end, typ='levels')
  720. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  721. # start < p, end 0 1959q3 - 1960q1
  722. start, end = 2, 4
  723. fv = res1.predict(start, end, typ='levels')
  724. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  725. # start < p, end >0 1959q3 - 1971q4
  726. start, end = 2, 51
  727. fv = res1.predict(start, end, typ='levels')
  728. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  729. # start < p, end nobs 1959q3 - 2009q3
  730. start, end = 2, 202
  731. fv = res1.predict(start, end, typ='levels')
  732. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  733. # start < p, end >nobs 1959q3 - 2015q4
  734. start, end = 2, 227
  735. fv = res1.predict(start, end, typ='levels')
  736. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  737. # start 0, end >0 1960q1 - 1971q4
  738. start, end = 4, 51
  739. fv = res1.predict(start, end, typ='levels')
  740. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  741. # start 0, end nobs 1960q1 - 2009q3
  742. start, end = 4, 202
  743. fv = res1.predict(start, end, typ='levels')
  744. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  745. # start 0, end >nobs 1960q1 - 2015q4
  746. start, end = 4, 227
  747. fv = res1.predict(start, end, typ='levels')
  748. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  749. # start >p, end >0 1965q1 - 1971q4
  750. start, end = 24, 51
  751. fv = res1.predict(start, end, typ='levels')
  752. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  753. # start >p, end nobs 1965q1 - 2009q3
  754. start, end = 24, 202
  755. fv = res1.predict(start, end, typ='levels')
  756. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  757. # start >p, end >nobs 1965q1 - 2015q4
  758. start, end = 24, 227
  759. fv = res1.predict(start, end, typ='levels')
  760. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  761. # start nobs, end nobs 2009q3 - 2009q3
  762. start, end = 202, 202
  763. fv = res1.predict(start, end, typ='levels')
  764. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_3)
  765. # start nobs, end >nobs 2009q3 - 2015q4
  766. start, end = 202, 227
  767. fv = res1.predict(start, end, typ='levels')
  768. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_3)
  769. # start >nobs, end >nobs 2009q4 - 2015q4
  770. start, end = 203, 227
  771. fv = res1.predict(start, end, typ='levels')
  772. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  773. # defaults
  774. start, end = None, None
  775. fv = res1.predict(start, end, typ='levels')
  776. assert_almost_equal(fv, fc[1:203], DECIMAL_4)
  777. # Dynamic
  778. # start < p, end <p 1959q2 - 1959q4
  779. start, end = 1, 3
  780. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  781. fv = res1.predict(start, end, dynamic=True, typ='levels')
  782. # start < p, end 0 1959q3 - 1960q1
  783. start, end = 2, 4
  784. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  785. res1.predict(start, end, dynamic=True, typ='levels')
  786. # start < p, end >0 1959q3 - 1971q4
  787. start, end = 2, 51
  788. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  789. res1.predict(start, end, dynamic=True, typ='levels')
  790. # start < p, end nobs 1959q3 - 2009q3
  791. start, end = 2, 202
  792. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  793. res1.predict(start, end, dynamic=True, typ='levels')
  794. # start < p, end >nobs 1959q3 - 2015q4
  795. start, end = 2, 227
  796. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  797. res1.predict(start, end, dynamic=True, typ='levels')
  798. # start 0, end >0 1960q1 - 1971q4
  799. start, end = 5, 51
  800. fv = res1.predict(start, end, dynamic=True, typ='levels')
  801. assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
  802. # start 0, end nobs 1960q1 - 2009q3
  803. start, end = 5, 202
  804. fv = res1.predict(start, end, dynamic=True, typ='levels')
  805. assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
  806. # start 0, end >nobs 1960q1 - 2015q4
  807. start, end = 5, 227
  808. fv = res1.predict(start, end, dynamic=True, typ='levels')
  809. assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
  810. # start >p, end >0 1965q1 - 1971q4
  811. start, end = 24, 51
  812. fv = res1.predict(start, end, dynamic=True, typ='levels')
  813. assert_almost_equal(fv, fcdyn2[start:end + 1], DECIMAL_4)
  814. # start >p, end nobs 1965q1 - 2009q3
  815. start, end = 24, 202
  816. fv = res1.predict(start, end, dynamic=True, typ='levels')
  817. assert_almost_equal(fv, fcdyn2[start:end + 1], DECIMAL_4)
  818. # start >p, end >nobs 1965q1 - 2015q4
  819. start, end = 24, 227
  820. fv = res1.predict(start, end, dynamic=True, typ='levels')
  821. assert_almost_equal(fv, fcdyn2[start:end + 1], DECIMAL_4)
  822. # start nobs, end nobs 2009q3 - 2009q3
  823. start, end = 202, 202
  824. fv = res1.predict(start, end, dynamic=True, typ='levels')
  825. assert_almost_equal(fv, fcdyn3[start:end + 1], DECIMAL_4)
  826. # start nobs, end >nobs 2009q3 - 2015q4
  827. start, end = 202, 227
  828. fv = res1.predict(start, end, dynamic=True, typ='levels')
  829. assert_almost_equal(fv, fcdyn3[start:end + 1], DECIMAL_4)
  830. # start >nobs, end >nobs 2009q4 - 2015q4
  831. start, end = 203, 227
  832. fv = res1.predict(start, end, dynamic=True, typ='levels')
  833. assert_almost_equal(fv, fcdyn4[start:end + 1], DECIMAL_4)
  834. # defaults
  835. start, end = None, None
  836. fv = res1.predict(start, end, dynamic=True, typ='levels')
  837. assert_almost_equal(fv, fcdyn[5:203], DECIMAL_4)
  838. def _check_start(model, given, expected, dynamic):
  839. start, _, _, _ = model._get_prediction_index(given, None, dynamic)
  840. assert_equal(start, expected)
  841. def _check_end(model, given, end_expect, out_of_sample_expect):
  842. _, end, out_of_sample, _ = model._get_prediction_index(None, given, False)
  843. assert_equal((end, out_of_sample), (end_expect, out_of_sample_expect))
  844. def test_arma_predict_indices():
  845. from statsmodels.datasets.sunspots import load_pandas
  846. sunspots = load_pandas().data['SUNACTIVITY'].values
  847. model = ARMA(sunspots, (9, 0), dates=sun_dates, freq='A')
  848. model.method = 'mle'
  849. # raises - pre-sample + dynamic
  850. with pytest.raises(ValueError):
  851. model._get_prediction_index(0, None, True)
  852. with pytest.raises(ValueError):
  853. model._get_prediction_index(8, None, True)
  854. with pytest.raises(ValueError):
  855. model._get_prediction_index('1700', None, True)
  856. with pytest.raises(ValueError):
  857. model._get_prediction_index('1708', None, True)
  858. # raises - start out of sample
  859. # works - in-sample
  860. # None
  861. start_test_cases = [
  862. # given, expected, dynamic
  863. (None, 9, True),
  864. # all start get moved back by k_diff
  865. (9, 9, True),
  866. (10, 10, True),
  867. # what about end of sample start - last value is first
  868. # forecast
  869. (309, 309, True),
  870. (308, 308, True),
  871. (0, 0, False),
  872. (1, 1, False),
  873. (4, 4, False),
  874. # all start get moved back by k_diff
  875. ('1709', 9, True),
  876. ('1710', 10, True),
  877. # what about end of sample start - last value is first
  878. # forecast
  879. ('2008', 308, True),
  880. ('2009', 309, True),
  881. ('1700', 0, False),
  882. ('1708', 8, False),
  883. ('1709', 9, False),
  884. ]
  885. for case in start_test_cases:
  886. _check_start(*((model,) + case))
  887. # the length of sunspot is 309, so last index is 208
  888. end_test_cases = [(None, 308, 0),
  889. (307, 307, 0),
  890. (308, 308, 0),
  891. (309, 308, 1),
  892. (312, 308, 4),
  893. (51, 51, 0),
  894. (333, 308, 25),
  895. ('2007', 307, 0),
  896. ('2008', 308, 0),
  897. ('2009', 308, 1),
  898. ('2012', 308, 4),
  899. ('1815', 115, 0),
  900. ('2033', 308, 25),
  901. ]
  902. for case in end_test_cases:
  903. _check_end(*((model,) + case))
  904. def test_arima_predict_indices():
  905. cpi = load_macrodata_pandas().data['cpi'].values
  906. model = ARIMA(cpi, (4, 1, 1), dates=cpi_dates, freq='Q')
  907. model.method = 'mle'
  908. # starting indices
  909. # raises - pre-sample + dynamic
  910. with pytest.raises(ValueError):
  911. model._get_prediction_index(0, None, True)
  912. with pytest.raises(ValueError):
  913. model._get_prediction_index(4, None, True)
  914. with pytest.raises(KeyError):
  915. model._get_prediction_index('1959Q1', None, True)
  916. with pytest.raises(ValueError):
  917. model._get_prediction_index('1960Q1', None, True)
  918. # raises - index differenced away
  919. with pytest.raises(ValueError):
  920. model._get_prediction_index(0, None, False)
  921. with pytest.raises(KeyError):
  922. model._get_prediction_index('1959Q1', None, False)
  923. # raises - start out of sample
  924. # works - in-sample
  925. # None
  926. start_test_cases = [
  927. # given, expected, dynamic
  928. (None, 4, True),
  929. # all start get moved back by k_diff
  930. (5, 4, True),
  931. (6, 5, True),
  932. # what about end of sample start - last value is first
  933. # forecast
  934. (203, 202, True),
  935. (1, 0, False),
  936. (4, 3, False),
  937. (5, 4, False),
  938. # all start get moved back by k_diff
  939. ('1960Q2', 4, True),
  940. ('1960Q3', 5, True),
  941. # what about end of sample start - last value is first
  942. # forecast
  943. ('2009Q4', 202, True),
  944. ('1959Q2', 0, False),
  945. ('1960Q1', 3, False),
  946. ('1960Q2', 4, False),
  947. ]
  948. for case in start_test_cases:
  949. _check_start(*((model,) + case))
  950. # TODO: make sure dates are passing through unmolested
  951. # the length of diff(cpi) is 202, so last index is 201
  952. end_test_cases = [(None, 201, 0),
  953. (201, 200, 0),
  954. (202, 201, 0),
  955. (203, 201, 1),
  956. (204, 201, 2),
  957. (51, 50, 0),
  958. (164 + 63, 201, 25),
  959. ('2009Q2', 200, 0),
  960. ('2009Q3', 201, 0),
  961. ('2009Q4', 201, 1),
  962. ('2010Q1', 201, 2),
  963. ('1971Q4', 50, 0),
  964. ('2015Q4', 201, 25),
  965. ]
  966. for case in end_test_cases:
  967. _check_end(*((model,) + case))
  968. # check higher k_diff
  969. # model.k_diff = 2
  970. model = ARIMA(cpi, (4, 2, 1), dates=cpi_dates, freq='Q')
  971. model.method = 'mle'
  972. # raises - pre-sample + dynamic
  973. assert_raises(ValueError, model._get_prediction_index, 0, None, True)
  974. assert_raises(ValueError, model._get_prediction_index, 5, None, True)
  975. assert_raises(KeyError, model._get_prediction_index,
  976. '1959Q1', None, True)
  977. assert_raises(ValueError, model._get_prediction_index,
  978. '1960Q1', None, True)
  979. # raises - index differenced away
  980. assert_raises(ValueError, model._get_prediction_index, 1, None, False)
  981. assert_raises(KeyError, model._get_prediction_index,
  982. '1959Q2', None, False)
  983. start_test_cases = [(None, 4, True),
  984. # all start get moved back by k_diff
  985. (6, 4, True),
  986. # what about end of sample start - last value is first
  987. # forecast
  988. (203, 201, True),
  989. (2, 0, False),
  990. (4, 2, False),
  991. (5, 3, False),
  992. ('1960Q3', 4, True),
  993. # what about end of sample start - last value is first
  994. # forecast
  995. ('2009Q4', 201, True),
  996. ('2009Q4', 201, True),
  997. ('1959Q3', 0, False),
  998. ('1960Q1', 2, False),
  999. ('1960Q2', 3, False),
  1000. ]
  1001. for case in start_test_cases:
  1002. _check_start(*((model,) + case))
  1003. end_test_cases = [(None, 200, 0),
  1004. (201, 199, 0),
  1005. (202, 200, 0),
  1006. (203, 200, 1),
  1007. (204, 200, 2),
  1008. (51, 49, 0),
  1009. (164 + 63, 200, 25),
  1010. ('2009Q2', 199, 0),
  1011. ('2009Q3', 200, 0),
  1012. ('2009Q4', 200, 1),
  1013. ('2010Q1', 200, 2),
  1014. ('1971Q4', 49, 0),
  1015. ('2015Q4', 200, 25),
  1016. ]
  1017. for case in end_test_cases:
  1018. _check_end(*((model,) + case))
  1019. def test_arima_predict_indices_css():
  1020. cpi = load_macrodata_pandas().data['cpi'].values
  1021. # NOTE: Doing no-constant for now to kick the conditional exogenous
  1022. # issue 274 down the road
  1023. # go ahead and git the model to set up necessary variables
  1024. model = ARIMA(cpi, (4, 1, 1))
  1025. model.method = 'css'
  1026. assert_raises(ValueError, model._get_prediction_index, 0, None, False)
  1027. assert_raises(ValueError, model._get_prediction_index, 0, None, True)
  1028. assert_raises(ValueError, model._get_prediction_index, 2, None, False)
  1029. assert_raises(ValueError, model._get_prediction_index, 2, None, True)
  1030. def test_arima_predict_css():
  1031. cpi = load_macrodata_pandas().data['cpi'].values
  1032. # NOTE: Doing no-constant for now to kick the conditional exogenous
  1033. # issue 274 down the road
  1034. # go ahead and git the model to set up necessary variables
  1035. res1 = ARIMA(cpi, (4, 1, 1)).fit(disp=-1, method="css",
  1036. trend="nc")
  1037. # but use gretl parameters to predict to avoid precision problems
  1038. params = np.array([1.231272508473910,
  1039. -0.282516097759915,
  1040. 0.170052755782440,
  1041. -0.118203728504945,
  1042. -0.938783134717947])
  1043. file_path = os.path.join(current_path, 'results',
  1044. 'results_arima_forecasts_all_css.csv')
  1045. with open(file_path, "rb") as test_data:
  1046. arima_forecasts = np.genfromtxt(test_data, delimiter=",",
  1047. skip_header=1, dtype=float)
  1048. fc = arima_forecasts[:, 0]
  1049. fcdyn = arima_forecasts[:, 1]
  1050. fcdyn2 = arima_forecasts[:, 2]
  1051. fcdyn3 = arima_forecasts[:, 3]
  1052. fcdyn4 = arima_forecasts[:, 4]
  1053. start, end = 1, 3
  1054. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  1055. res1.model.predict(params, start, end)
  1056. # start < p, end 0 1959q3 - 1960q1
  1057. start, end = 2, 4
  1058. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  1059. res1.model.predict(params, start, end)
  1060. # start < p, end >0 1959q3 - 1971q4
  1061. start, end = 2, 51
  1062. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  1063. res1.model.predict(params, start, end)
  1064. # start < p, end nobs 1959q3 - 2009q3
  1065. start, end = 2, 202
  1066. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  1067. res1.model.predict(params, start, end)
  1068. # start < p, end >nobs 1959q3 - 2015q4
  1069. start, end = 2, 227
  1070. with pytest.raises(ValueError, match='Start must be >= k_ar'):
  1071. res1.model.predict(params, start, end)
  1072. # start 0, end >0 1960q1 - 1971q4
  1073. start, end = 5, 51
  1074. fv = res1.model.predict(params, start, end, typ='levels')
  1075. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  1076. # start 0, end nobs 1960q1 - 2009q3
  1077. start, end = 5, 202
  1078. fv = res1.model.predict(params, start, end, typ='levels')
  1079. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  1080. # start 0, end >nobs 1960q1 - 2015q4
  1081. # TODO: why detoriating precision?
  1082. fv = res1.model.predict(params, start, end, typ='levels')
  1083. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  1084. # start >p, end >0 1965q1 - 1971q4
  1085. start, end = 24, 51
  1086. fv = res1.model.predict(params, start, end, typ='levels')
  1087. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  1088. # start >p, end nobs 1965q1 - 2009q3
  1089. start, end = 24, 202
  1090. fv = res1.model.predict(params, start, end, typ='levels')
  1091. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  1092. # start >p, end >nobs 1965q1 - 2015q4
  1093. start, end = 24, 227
  1094. fv = res1.model.predict(params, start, end, typ='levels')
  1095. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  1096. # start nobs, end nobs 2009q3 - 2009q3
  1097. start, end = 202, 202
  1098. fv = res1.model.predict(params, start, end, typ='levels')
  1099. assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
  1100. # start nobs, end >nobs 2009q3 - 2015q4

Large files are truncated, but you can click here to view the full file