PageRenderTime 222ms CodeModel.GetById 67ms app.highlight 132ms RepoModel.GetById 1ms app.codeStats 1ms

/statsmodels/tsa/tests/test_arima.py

http://github.com/statsmodels/statsmodels
Python | 2746 lines | 2653 code | 79 blank | 14 comment | 5 complexity | 4e2322760d7505b95873cdee2fe1766c MD5 | raw file

Large files are truncated, but you can click here to view the full file

   1from statsmodels.compat.platform import (PLATFORM_OSX, PLATFORM_WIN,
   2                                         PLATFORM_WIN32)
   3from statsmodels.compat.python import lrange
   4
   5import os
   6import pickle
   7import warnings
   8from io import BytesIO
   9
  10import numpy as np
  11import pandas as pd
  12import pytest
  13from numpy.testing import assert_almost_equal, assert_allclose, assert_raises
  14from pandas import DatetimeIndex, date_range, period_range
  15
  16import statsmodels.sandbox.tsa.fftarma as fa
  17from statsmodels.datasets.macrodata import load_pandas as load_macrodata_pandas
  18from statsmodels.regression.linear_model import OLS
  19from statsmodels.tools.sm_exceptions import (
  20    ValueWarning, HessianInversionWarning, SpecificationWarning,
  21    MissingDataError)
  22from statsmodels.tools.testing import assert_equal
  23from statsmodels.tsa.ar_model import AutoReg
  24from statsmodels.tsa.arima_model import ARMA, ARIMA
  25from statsmodels.tsa.arima_process import arma_generate_sample
  26from statsmodels.tsa.arma_mle import Arma
  27from statsmodels.tsa.tests.results import results_arma, results_arima
  28
  29DECIMAL_4 = 4
  30DECIMAL_3 = 3
  31DECIMAL_2 = 2
  32DECIMAL_1 = 1
  33
  34current_path = os.path.dirname(os.path.abspath(__file__))
  35ydata_path = os.path.join(current_path, 'results', 'y_arma_data.csv')
  36with open(ydata_path, "rb") as fd:
  37    y_arma = np.genfromtxt(fd, delimiter=",", skip_header=1, dtype=float)
  38
  39cpi_dates = period_range(start='1959q1', end='2009q3', freq='Q')
  40sun_dates = period_range(start='1700', end='2008', freq='A')
  41cpi_predict_dates = period_range(start='2009q3', end='2015q4', freq='Q')
  42sun_predict_dates = period_range(start='2008', end='2033', freq='A')
  43
  44
  45def test_compare_arma():
  46    # this is a preliminary test to compare arma_kf, arma_cond_ls
  47    # and arma_cond_mle
  48    # the results returned by the fit methods are incomplete
  49    # for now without random.seed
  50
  51    np.random.seed(9876565)
  52    x = fa.ArmaFft([1, -0.5], [1., 0.4], 40).generate_sample(nsample=200,
  53                                                             burnin=1000)
  54
  55    modkf = ARMA(x, (1, 1))
  56    reskf = modkf.fit(trend='nc', disp=-1)
  57    dres = reskf
  58
  59    modc = Arma(x)
  60    resls = modc.fit(order=(1, 1))
  61    rescm = modc.fit_mle(order=(1, 1), start_params=[0.4, 0.4, 1.], disp=0)
  62
  63    # decimal 1 corresponds to threshold of 5% difference
  64    # still different sign  corrcted
  65    assert_almost_equal(resls[0] / dres.params, np.ones(dres.params.shape),
  66                        decimal=1)
  67
  68    # rescm also contains variance estimate as last element of params
  69    assert_almost_equal(rescm.params[:-1] / dres.params,
  70                        np.ones(dres.params.shape), decimal=1)
  71
  72
  73class CheckArmaResultsMixin(object):
  74    """
  75    res2 are the results from gretl.  They are in results/results_arma.
  76    res1 are from statsmodels
  77    """
  78    decimal_params = DECIMAL_4
  79
  80    def test_params(self):
  81        assert_almost_equal(self.res1.params, self.res2.params,
  82                            self.decimal_params)
  83
  84    decimal_aic = DECIMAL_4
  85
  86    def test_aic(self):
  87        assert_almost_equal(self.res1.aic, self.res2.aic, self.decimal_aic)
  88
  89    decimal_bic = DECIMAL_4
  90
  91    def test_bic(self):
  92        assert_almost_equal(self.res1.bic, self.res2.bic, self.decimal_bic)
  93
  94    decimal_arroots = DECIMAL_4
  95
  96    def test_arroots(self):
  97        assert_almost_equal(self.res1.arroots, self.res2.arroots,
  98                            self.decimal_arroots)
  99
 100    decimal_maroots = DECIMAL_4
 101
 102    def test_maroots(self):
 103        assert_almost_equal(self.res1.maroots, self.res2.maroots,
 104                            self.decimal_maroots)
 105
 106    decimal_bse = DECIMAL_2
 107
 108    def test_bse(self):
 109        assert_almost_equal(self.res1.bse, self.res2.bse, self.decimal_bse)
 110
 111    decimal_cov_params = DECIMAL_4
 112
 113    def test_covparams(self):
 114        assert_almost_equal(self.res1.cov_params(), self.res2.cov_params,
 115                            self.decimal_cov_params)
 116
 117    decimal_hqic = DECIMAL_4
 118
 119    def test_hqic(self):
 120        assert_almost_equal(self.res1.hqic, self.res2.hqic, self.decimal_hqic)
 121
 122    decimal_llf = DECIMAL_4
 123
 124    def test_llf(self):
 125        assert_almost_equal(self.res1.llf, self.res2.llf, self.decimal_llf)
 126
 127    decimal_resid = DECIMAL_4
 128
 129    def test_resid(self):
 130        assert_almost_equal(self.res1.resid, self.res2.resid,
 131                            self.decimal_resid)
 132
 133    decimal_fittedvalues = DECIMAL_4
 134
 135    def test_fittedvalues(self):
 136        assert_almost_equal(self.res1.fittedvalues, self.res2.fittedvalues,
 137                            self.decimal_fittedvalues)
 138
 139    decimal_pvalues = DECIMAL_2
 140
 141    def test_pvalues(self):
 142        assert_almost_equal(self.res1.pvalues, self.res2.pvalues,
 143                            self.decimal_pvalues)
 144
 145    decimal_t = DECIMAL_2  # only 2 decimal places in gretl output
 146
 147    def test_tvalues(self):
 148        assert_almost_equal(self.res1.tvalues, self.res2.tvalues,
 149                            self.decimal_t)
 150
 151    decimal_sigma2 = DECIMAL_4
 152
 153    def test_sigma2(self):
 154        assert_almost_equal(self.res1.sigma2, self.res2.sigma2,
 155                            self.decimal_sigma2)
 156
 157    @pytest.mark.smoke
 158    def test_summary(self):
 159        self.res1.summary()
 160
 161    @pytest.mark.smoke
 162    def test_summary2(self):
 163        self.res1.summary2()
 164
 165
 166class CheckForecastMixin(object):
 167    decimal_forecast = DECIMAL_4
 168
 169    def test_forecast(self):
 170        assert_almost_equal(self.res1.forecast_res, self.res2.forecast,
 171                            self.decimal_forecast)
 172
 173    decimal_forecasterr = DECIMAL_4
 174
 175    def test_forecasterr(self):
 176        assert_almost_equal(self.res1.forecast_err, self.res2.forecasterr,
 177                            self.decimal_forecasterr)
 178
 179
 180class CheckDynamicForecastMixin(object):
 181    decimal_forecast_dyn = 4
 182
 183    def test_dynamic_forecast(self):
 184        assert_almost_equal(self.res1.forecast_res_dyn, self.res2.forecast_dyn,
 185                            self.decimal_forecast_dyn)
 186
 187    def test_forecasterr(self):
 188        assert_almost_equal(self.res1.forecast_err_dyn,
 189                            self.res2.forecasterr_dyn,
 190                            DECIMAL_4)
 191
 192
 193class CheckArimaResultsMixin(CheckArmaResultsMixin):
 194    def test_order(self):
 195        assert self.res1.k_diff == self.res2.k_diff
 196        assert self.res1.k_ar == self.res2.k_ar
 197        assert self.res1.k_ma == self.res2.k_ma
 198
 199    decimal_predict_levels = DECIMAL_4
 200
 201    def test_predict_levels(self):
 202        assert_almost_equal(self.res1.predict(typ='levels'), self.res2.linear,
 203                            self.decimal_predict_levels)
 204
 205
 206class Test_Y_ARMA11_NoConst(CheckArmaResultsMixin, CheckForecastMixin):
 207    @classmethod
 208    def setup_class(cls):
 209        endog = y_arma[:, 0]
 210        cls.res1 = ARMA(endog, order=(1, 1)).fit(trend='nc', disp=-1)
 211        (cls.res1.forecast_res, cls.res1.forecast_err,
 212         confint) = cls.res1.forecast(10)
 213        cls.res2 = results_arma.Y_arma11()
 214
 215    def test_pickle(self):
 216        fh = BytesIO()
 217        # test wrapped results load save pickle
 218        self.res1.save(fh)
 219        fh.seek(0, 0)
 220        res_unpickled = self.res1.__class__.load(fh)
 221        assert type(res_unpickled) is type(self.res1)  # noqa: E721
 222
 223
 224class Test_Y_ARMA14_NoConst(CheckArmaResultsMixin):
 225    @classmethod
 226    def setup_class(cls):
 227        endog = y_arma[:, 1]
 228        cls.res1 = ARMA(endog, order=(1, 4)).fit(trend='nc', disp=-1)
 229        cls.res2 = results_arma.Y_arma14()
 230
 231
 232@pytest.mark.slow
 233class Test_Y_ARMA41_NoConst(CheckArmaResultsMixin, CheckForecastMixin):
 234    @classmethod
 235    def setup_class(cls):
 236        endog = y_arma[:, 2]
 237        cls.res1 = ARMA(endog, order=(4, 1)).fit(trend='nc', disp=-1)
 238        (cls.res1.forecast_res, cls.res1.forecast_err,
 239         confint) = cls.res1.forecast(10)
 240        cls.res2 = results_arma.Y_arma41()
 241        cls.decimal_maroots = DECIMAL_3
 242
 243
 244class Test_Y_ARMA22_NoConst(CheckArmaResultsMixin):
 245    @classmethod
 246    def setup_class(cls):
 247        endog = y_arma[:, 3]
 248        cls.res1 = ARMA(endog, order=(2, 2)).fit(trend='nc', disp=-1)
 249        cls.res2 = results_arma.Y_arma22()
 250
 251
 252class Test_Y_ARMA50_NoConst(CheckArmaResultsMixin, CheckForecastMixin):
 253    @classmethod
 254    def setup_class(cls):
 255        endog = y_arma[:, 4]
 256        cls.res1 = ARMA(endog, order=(5, 0)).fit(trend='nc', disp=-1)
 257        (cls.res1.forecast_res, cls.res1.forecast_err,
 258         confint) = cls.res1.forecast(10)
 259        cls.res2 = results_arma.Y_arma50()
 260
 261
 262class Test_Y_ARMA02_NoConst(CheckArmaResultsMixin):
 263    @classmethod
 264    def setup_class(cls):
 265        endog = y_arma[:, 5]
 266        cls.res1 = ARMA(endog, order=(0, 2)).fit(trend='nc', disp=-1)
 267        cls.res2 = results_arma.Y_arma02()
 268
 269
 270class Test_Y_ARMA11_Const(CheckArmaResultsMixin, CheckForecastMixin):
 271    @classmethod
 272    def setup_class(cls):
 273        endog = y_arma[:, 6]
 274        cls.res1 = ARMA(endog, order=(1, 1)).fit(trend="c", disp=-1)
 275        (cls.res1.forecast_res, cls.res1.forecast_err,
 276         confint) = cls.res1.forecast(10)
 277        cls.res2 = results_arma.Y_arma11c()
 278
 279
 280class Test_Y_ARMA14_Const(CheckArmaResultsMixin):
 281    @classmethod
 282    def setup_class(cls):
 283        endog = y_arma[:, 7]
 284        cls.res1 = ARMA(endog, order=(1, 4)).fit(trend="c", disp=-1)
 285        cls.res2 = results_arma.Y_arma14c()
 286
 287
 288class Test_Y_ARMA41_Const(CheckArmaResultsMixin, CheckForecastMixin):
 289    @classmethod
 290    def setup_class(cls):
 291        endog = y_arma[:, 8]
 292        cls.res2 = results_arma.Y_arma41c()
 293        cls.res1 = ARMA(endog, order=(4, 1)).fit(trend="c", disp=-1,
 294                                                 start_params=cls.res2.params)
 295        (cls.res1.forecast_res, cls.res1.forecast_err,
 296         confint) = cls.res1.forecast(10)
 297        cls.decimal_cov_params = DECIMAL_3
 298        cls.decimal_fittedvalues = DECIMAL_3
 299        cls.decimal_resid = DECIMAL_3
 300        cls.decimal_params = DECIMAL_3
 301
 302
 303class Test_Y_ARMA22_Const(CheckArmaResultsMixin):
 304    @classmethod
 305    def setup_class(cls):
 306        endog = y_arma[:, 9]
 307        cls.res1 = ARMA(endog, order=(2, 2)).fit(trend="c", disp=-1)
 308        cls.res2 = results_arma.Y_arma22c()
 309
 310    def test_summary(self):
 311        # regression test for html of roots table #4434
 312        # we ignore whitespace in the assert
 313        summ = self.res1.summary()
 314        summ_roots = """\
 315        <tableclass="simpletable">
 316        <caption>Roots</caption>
 317        <tr>
 318        <td></td><th>Real</th><th>Imaginary</th><th>Modulus</th><th>Frequency</th>
 319        </tr>
 320        <tr>
 321        <th>AR.1</th><td>1.0991</td><td>-1.2571j</td><td>1.6698</td><td>-0.1357</td>
 322        </tr>
 323        <tr>
 324        <th>AR.2</th><td>1.0991</td><td>+1.2571j</td><td>1.6698</td><td>0.1357</td>
 325        </tr>
 326        <tr>
 327        <th>MA.1</th><td>-1.1702</td><td>+0.0000j</td><td>1.1702</td><td>0.5000</td>
 328        </tr>
 329        <tr>
 330        <th>MA.2</th><td>1.2215</td><td>+0.0000j</td><td>1.2215</td><td>0.0000</td>
 331        </tr>
 332        </table>"""
 333        assert_equal(summ.tables[2]._repr_html_().replace(' ', ''),
 334                     summ_roots.replace(' ', ''))
 335
 336
 337class Test_Y_ARMA50_Const(CheckArmaResultsMixin, CheckForecastMixin):
 338    @classmethod
 339    def setup_class(cls):
 340        endog = y_arma[:, 10]
 341        cls.res1 = ARMA(endog, order=(5, 0)).fit(trend="c", disp=-1)
 342        (cls.res1.forecast_res, cls.res1.forecast_err,
 343         confint) = cls.res1.forecast(10)
 344        cls.res2 = results_arma.Y_arma50c()
 345
 346
 347class Test_Y_ARMA02_Const(CheckArmaResultsMixin):
 348    @classmethod
 349    def setup_class(cls):
 350        endog = y_arma[:, 11]
 351        cls.res1 = ARMA(endog, order=(0, 2)).fit(trend="c", disp=-1)
 352        cls.res2 = results_arma.Y_arma02c()
 353
 354
 355# cov_params and tvalues are off still but not as much vs. R
 356class Test_Y_ARMA11_NoConst_CSS(CheckArmaResultsMixin):
 357    @classmethod
 358    def setup_class(cls):
 359        endog = y_arma[:, 0]
 360        cls.res1 = ARMA(endog, order=(1, 1)).fit(method="css", trend='nc',
 361                                                 disp=-1)
 362        cls.res2 = results_arma.Y_arma11("css")
 363        cls.decimal_t = DECIMAL_1
 364
 365
 366# better vs. R
 367class Test_Y_ARMA14_NoConst_CSS(CheckArmaResultsMixin):
 368    @classmethod
 369    def setup_class(cls):
 370        endog = y_arma[:, 1]
 371        cls.res1 = ARMA(endog, order=(1, 4)).fit(method="css", trend='nc',
 372                                                 disp=-1)
 373        cls.res2 = results_arma.Y_arma14("css")
 374        cls.decimal_fittedvalues = DECIMAL_3
 375        cls.decimal_resid = DECIMAL_3
 376        cls.decimal_t = DECIMAL_1
 377
 378
 379# bse, etc. better vs. R
 380# maroot is off because maparams is off a bit (adjust tolerance?)
 381class Test_Y_ARMA41_NoConst_CSS(CheckArmaResultsMixin):
 382    @classmethod
 383    def setup_class(cls):
 384        endog = y_arma[:, 2]
 385        cls.res1 = ARMA(endog, order=(4, 1)).fit(method="css", trend='nc',
 386                                                 disp=-1)
 387        cls.res2 = results_arma.Y_arma41("css")
 388        cls.decimal_t = DECIMAL_1
 389        cls.decimal_pvalues = 0
 390        cls.decimal_cov_params = DECIMAL_3
 391        cls.decimal_maroots = DECIMAL_1
 392
 393
 394# same notes as above
 395class Test_Y_ARMA22_NoConst_CSS(CheckArmaResultsMixin):
 396    @classmethod
 397    def setup_class(cls):
 398        endog = y_arma[:, 3]
 399        cls.res1 = ARMA(endog, order=(2, 2)).fit(method="css", trend='nc',
 400                                                 disp=-1)
 401        cls.res2 = results_arma.Y_arma22("css")
 402        cls.decimal_t = DECIMAL_1
 403        cls.decimal_resid = DECIMAL_3
 404        cls.decimal_pvalues = DECIMAL_1
 405        cls.decimal_fittedvalues = DECIMAL_3
 406
 407
# NOTE: gretl just uses least squares for AR CSS, so its BIC (and the other
# information criteria) is
# -2*res1.llf + np.log(nobs)*(res1.q+res1.p+res1.k)
# with no adjustment for p and no extra sigma estimate.
# NOTE: our tests therefore use X-12-ARIMA results, which agree with ours and
# are consistent with the rest of the models.
 414class Test_Y_ARMA50_NoConst_CSS(CheckArmaResultsMixin):
 415    @classmethod
 416    def setup_class(cls):
 417        endog = y_arma[:, 4]
 418        cls.res1 = ARMA(endog, order=(5, 0)).fit(method="css", trend='nc',
 419                                                 disp=-1)
 420        cls.res2 = results_arma.Y_arma50("css")
 421        cls.decimal_t = 0
 422        cls.decimal_llf = DECIMAL_1  # looks like rounding error?
 423
 424
 425class Test_Y_ARMA02_NoConst_CSS(CheckArmaResultsMixin):
 426    @classmethod
 427    def setup_class(cls):
 428        endog = y_arma[:, 5]
 429        cls.res1 = ARMA(endog, order=(0, 2)).fit(method="css", trend='nc',
 430                                                 disp=-1)
 431        cls.res2 = results_arma.Y_arma02("css")
 432
 433
 434# NOTE: our results are close to --x-12-arima option and R
 435class Test_Y_ARMA11_Const_CSS(CheckArmaResultsMixin):
 436    @classmethod
 437    def setup_class(cls):
 438        endog = y_arma[:, 6]
 439        cls.res1 = ARMA(endog, order=(1, 1)).fit(trend="c", method="css",
 440                                                 disp=-1)
 441        cls.res2 = results_arma.Y_arma11c("css")
 442        cls.decimal_params = DECIMAL_3
 443        cls.decimal_cov_params = DECIMAL_3
 444        cls.decimal_t = DECIMAL_1
 445
 446
 447class Test_Y_ARMA14_Const_CSS(CheckArmaResultsMixin):
 448    @classmethod
 449    def setup_class(cls):
 450        endog = y_arma[:, 7]
 451        cls.res1 = ARMA(endog, order=(1, 4)).fit(trend="c", method="css",
 452                                                 disp=-1)
 453        cls.res2 = results_arma.Y_arma14c("css")
 454        cls.decimal_t = DECIMAL_1
 455        cls.decimal_pvalues = DECIMAL_1
 456
 457
 458class Test_Y_ARMA41_Const_CSS(CheckArmaResultsMixin):
 459    @classmethod
 460    def setup_class(cls):
 461        endog = y_arma[:, 8]
 462        cls.res1 = ARMA(endog, order=(4, 1)).fit(trend="c", method="css",
 463                                                 disp=-1)
 464        cls.res2 = results_arma.Y_arma41c("css")
 465        cls.decimal_t = DECIMAL_1
 466        cls.decimal_cov_params = DECIMAL_1
 467        cls.decimal_maroots = DECIMAL_3
 468        cls.decimal_bse = DECIMAL_1
 469
 470
 471class Test_Y_ARMA22_Const_CSS(CheckArmaResultsMixin):
 472    @classmethod
 473    def setup_class(cls):
 474        endog = y_arma[:, 9]
 475        cls.res1 = ARMA(endog, order=(2, 2)).fit(trend="c", method="css",
 476                                                 disp=-1)
 477        cls.res2 = results_arma.Y_arma22c("css")
 478        cls.decimal_t = 0
 479        cls.decimal_pvalues = DECIMAL_1
 480
 481
 482class Test_Y_ARMA50_Const_CSS(CheckArmaResultsMixin):
 483    @classmethod
 484    def setup_class(cls):
 485        endog = y_arma[:, 10]
 486        cls.res1 = ARMA(endog, order=(5, 0)).fit(trend="c", method="css",
 487                                                 disp=-1)
 488        cls.res2 = results_arma.Y_arma50c("css")
 489        cls.decimal_t = DECIMAL_1
 490        cls.decimal_params = DECIMAL_3
 491        cls.decimal_cov_params = DECIMAL_2
 492
 493
 494class Test_Y_ARMA02_Const_CSS(CheckArmaResultsMixin):
 495    @classmethod
 496    def setup_class(cls):
 497        endog = y_arma[:, 11]
 498        cls.res1 = ARMA(endog, order=(0, 2)).fit(trend="c", method="css",
 499                                                 disp=-1)
 500        cls.res2 = results_arma.Y_arma02c("css")
 501
 502
 503def test_reset_trend_error():
 504    endog = y_arma[:, 0]
 505    mod = ARMA(endog, order=(1, 1))
 506    mod.fit(trend="c", disp=-1)
 507    with pytest.raises(RuntimeError):
 508        mod.fit(trend="nc", disp=-1)
 509
 510
 511@pytest.mark.slow
 512def test_start_params_bug():
 513    data = np.array([1368., 1187, 1090, 1439, 2362, 2783, 2869, 2512, 1804,
 514                     1544, 1028, 869, 1737, 2055, 1947, 1618, 1196, 867, 997,
 515                     1862, 2525,
 516                     3250, 4023, 4018, 3585, 3004, 2500, 2441, 2749, 2466,
 517                     2157, 1847, 1463,
 518                     1146, 851, 993, 1448, 1719, 1709, 1455, 1950, 1763, 2075,
 519                     2343, 3570,
 520                     4690, 3700, 2339, 1679, 1466, 998, 853, 835, 922, 851,
 521                     1125, 1299, 1105,
 522                     860, 701, 689, 774, 582, 419, 846, 1132, 902, 1058, 1341,
 523                     1551, 1167,
 524                     975, 786, 759, 751, 649, 876, 720, 498, 553, 459, 543,
 525                     447, 415, 377,
 526                     373, 324, 320, 306, 259, 220, 342, 558, 825, 994, 1267,
 527                     1473, 1601,
 528                     1896, 1890, 2012, 2198, 2393, 2825, 3411, 3406, 2464,
 529                     2891, 3685, 3638,
 530                     3746, 3373, 3190, 2681, 2846, 4129, 5054, 5002, 4801,
 531                     4934, 4903, 4713,
 532                     4745, 4736, 4622, 4642, 4478, 4510, 4758, 4457, 4356,
 533                     4170, 4658, 4546,
 534                     4402, 4183, 3574, 2586, 3326, 3948, 3983, 3997, 4422,
 535                     4496, 4276, 3467,
 536                     2753, 2582, 2921, 2768, 2789, 2824, 2482, 2773, 3005,
 537                     3641, 3699, 3774,
 538                     3698, 3628, 3180, 3306, 2841, 2014, 1910, 2560, 2980,
 539                     3012, 3210, 3457,
 540                     3158, 3344, 3609, 3327, 2913, 2264, 2326, 2596, 2225,
 541                     1767, 1190, 792,
 542                     669, 589, 496, 354, 246, 250, 323, 495, 924, 1536, 2081,
 543                     2660, 2814, 2992,
 544                     3115, 2962, 2272, 2151, 1889, 1481, 955, 631, 288, 103,
 545                     60, 82, 107, 185,
 546                     618, 1526, 2046, 2348, 2584, 2600, 2515, 2345, 2351, 2355,
 547                     2409, 2449,
 548                     2645, 2918, 3187, 2888, 2610, 2740, 2526, 2383, 2936,
 549                     2968, 2635, 2617,
 550                     2790, 3906, 4018, 4797, 4919, 4942, 4656, 4444, 3898,
 551                     3908, 3678, 3605,
 552                     3186, 2139, 2002, 1559, 1235, 1183, 1096, 673, 389, 223,
 553                     352, 308, 365,
 554                     525, 779, 894, 901, 1025, 1047, 981, 902, 759, 569, 519,
 555                     408, 263, 156,
 556                     72, 49, 31, 41, 192, 423, 492, 552, 564, 723, 921, 1525,
 557                     2768, 3531, 3824,
 558                     3835, 4294, 4533, 4173, 4221, 4064, 4641, 4685, 4026,
 559                     4323, 4585, 4836,
 560                     4822, 4631, 4614, 4326, 4790, 4736, 4104, 5099, 5154,
 561                     5121, 5384, 5274,
 562                     5225, 4899, 5382, 5295, 5349, 4977, 4597, 4069, 3733,
 563                     3439, 3052, 2626,
 564                     1939, 1064, 713, 916, 832, 658, 817, 921, 772, 764, 824,
 565                     967, 1127, 1153,
 566                     824, 912, 957, 990, 1218, 1684, 2030, 2119, 2233, 2657,
 567                     2652, 2682, 2498,
 568                     2429, 2346, 2298, 2129, 1829, 1816, 1225, 1010, 748, 627,
 569                     469, 576, 532,
 570                     475, 582, 641, 605, 699, 680, 714, 670, 666, 636, 672,
 571                     679, 446, 248, 134,
 572                     160, 178, 286, 413, 676, 1025, 1159, 952, 1398, 1833,
 573                     2045, 2072, 1798,
 574                     1799, 1358, 727, 353, 347, 844, 1377, 1829, 2118, 2272,
 575                     2745, 4263, 4314,
 576                     4530, 4354, 4645, 4547, 5391, 4855, 4739, 4520, 4573,
 577                     4305, 4196, 3773,
 578                     3368, 2596, 2596, 2305, 2756, 3747, 4078, 3415, 2369,
 579                     2210, 2316, 2263,
 580                     2672, 3571, 4131, 4167, 4077, 3924, 3738, 3712, 3510,
 581                     3182, 3179, 2951,
 582                     2453, 2078, 1999, 2486, 2581, 1891, 1997, 1366, 1294,
 583                     1536, 2794, 3211,
 584                     3242, 3406, 3121, 2425, 2016, 1787, 1508, 1304, 1060,
 585                     1342, 1589, 2361,
 586                     3452, 2659, 2857, 3255, 3322, 2852, 2964, 3132, 3033,
 587                     2931, 2636, 2818,
 588                     3310, 3396, 3179, 3232, 3543, 3759, 3503, 3758, 3658,
 589                     3425, 3053, 2620,
 590                     1837, 923, 712, 1054, 1376, 1556, 1498, 1523, 1088, 728,
 591                     890, 1413, 2524,
 592                     3295, 4097, 3993, 4116, 3874, 4074, 4142, 3975, 3908,
 593                     3907, 3918, 3755,
 594                     3648, 3778, 4293, 4385, 4360, 4352, 4528, 4365, 3846,
 595                     4098, 3860, 3230,
 596                     2820, 2916, 3201, 3721, 3397, 3055, 2141, 1623, 1825,
 597                     1716, 2232, 2939,
 598                     3735, 4838, 4560, 4307, 4975, 5173, 4859, 5268, 4992,
 599                     5100, 5070, 5270,
 600                     4760, 5135, 5059, 4682, 4492, 4933, 4737, 4611, 4634,
 601                     4789, 4811, 4379,
 602                     4689, 4284, 4191, 3313, 2770, 2543, 3105, 2967, 2420,
 603                     1996, 2247, 2564,
 604                     2726, 3021, 3427, 3509, 3759, 3324, 2988, 2849, 2340,
 605                     2443, 2364, 1252,
 606                     623, 742, 867, 684, 488, 348, 241, 187, 279, 355, 423,
 607                     678, 1375, 1497,
 608                     1434, 2116, 2411, 1929, 1628, 1635, 1609, 1757, 2090,
 609                     2085, 1790, 1846,
 610                     2038, 2360, 2342, 2401, 2920, 3030, 3132, 4385, 5483,
 611                     5865, 5595, 5485,
 612                     5727, 5553, 5560, 5233, 5478, 5159, 5155, 5312, 5079,
 613                     4510, 4628, 4535,
 614                     3656, 3698, 3443, 3146, 2562, 2304, 2181, 2293, 1950,
 615                     1930, 2197, 2796,
 616                     3441, 3649, 3815, 2850, 4005, 5305, 5550, 5641, 4717,
 617                     5131, 2831, 3518,
 618                     3354, 3115, 3515, 3552, 3244, 3658, 4407, 4935, 4299,
 619                     3166, 3335, 2728,
 620                     2488, 2573, 2002, 1717, 1645, 1977, 2049, 2125, 2376,
 621                     2551, 2578, 2629,
 622                     2750, 3150, 3699, 4062, 3959, 3264, 2671, 2205, 2128,
 623                     2133, 2095, 1964,
 624                     2006, 2074, 2201, 2506, 2449, 2465, 2064, 1446, 1382, 983,
 625                     898, 489, 319,
 626                     383, 332, 276, 224, 144, 101, 232, 429, 597, 750, 908,
 627                     960, 1076, 951,
 628                     1062, 1183, 1404, 1391, 1419, 1497, 1267, 963, 682, 777,
 629                     906, 1149, 1439,
 630                     1600, 1876, 1885, 1962, 2280, 2711, 2591, 2411])
 631    with warnings.catch_warnings():
 632        warnings.simplefilter("ignore")
 633        ARMA(data, order=(4, 1)).fit(start_ar_lags=5, disp=-1)
 634
 635
 636class Test_ARIMA101(CheckArmaResultsMixin):
 637    @classmethod
 638    def setup_class(cls):
 639        endog = y_arma[:, 6]
 640        cls.res1 = ARIMA(endog, (1, 0, 1)).fit(trend="c", disp=-1)
 641        (cls.res1.forecast_res, cls.res1.forecast_err,
 642         confint) = cls.res1.forecast(10)
 643        cls.res2 = results_arma.Y_arma11c()
 644        cls.res2.k_diff = 0
 645        cls.res2.k_ar = 1
 646        cls.res2.k_ma = 1
 647
 648
 649class Test_ARIMA111(CheckArimaResultsMixin, CheckForecastMixin,
 650                    CheckDynamicForecastMixin):
 651    @classmethod
 652    def setup_class(cls):
 653        cpi = load_macrodata_pandas().data['cpi'].values
 654        cls.res1 = ARIMA(cpi, (1, 1, 1)).fit(disp=-1)
 655        cls.res2 = results_arima.ARIMA111()
 656        # make sure endog names changes to D.cpi
 657        cls.decimal_llf = 3
 658        cls.decimal_aic = 3
 659        cls.decimal_bic = 3
 660        # TODO: why has dec_cov_params changed, used to be better
 661        cls.decimal_cov_params = 2
 662        cls.decimal_t = 0
 663        (cls.res1.forecast_res,
 664         cls.res1.forecast_err,
 665         conf_int) = cls.res1.forecast(25)
 666        # TODO: fix the indexing for the end here, I do not think this is right
 667        # if we're going to treat it like indexing
 668        # the forecast from 2005Q1 through 2009Q4 is indices
 669        # 184 through 227 not 226
 670        # note that the first one counts in the count so 164 + 64 is 65
 671        # predictions
 672        cls.res1.forecast_res_dyn = cls.res1.predict(start=164, end=164 + 63,
 673                                                     typ='levels',
 674                                                     dynamic=True)
 675
 676    def test_freq(self):
 677        assert_almost_equal(self.res1.arfreq, [0.0000], 4)
 678        assert_almost_equal(self.res1.mafreq, [0.0000], 4)
 679
 680
 681class Test_ARIMA111CSS(CheckArimaResultsMixin, CheckForecastMixin,
 682                       CheckDynamicForecastMixin):
 683    @classmethod
 684    def setup_class(cls):
 685        cpi = load_macrodata_pandas().data['cpi'].values
 686        cls.res1 = ARIMA(cpi, (1, 1, 1)).fit(disp=-1, method='css')
 687        cls.res2 = results_arima.ARIMA111(method='css')
 688        cls.res2.fittedvalues = - cpi[1:-1] + cls.res2.linear
 689        # make sure endog names changes to D.cpi
 690        (cls.res1.forecast_res,
 691         cls.res1.forecast_err,
 692         conf_int) = cls.res1.forecast(25)
 693        cls.decimal_forecast = 2
 694        cls.decimal_forecast_dyn = 2
 695        cls.decimal_forecasterr = 3
 696        cls.res1.forecast_res_dyn = cls.res1.predict(start=164, end=164 + 63,
 697                                                     typ='levels',
 698                                                     dynamic=True)
 699
 700        # precisions
 701        cls.decimal_arroots = 3
 702        cls.decimal_cov_params = 3
 703        cls.decimal_hqic = 3
 704        cls.decimal_maroots = 3
 705        cls.decimal_t = 1
 706        cls.decimal_fittedvalues = 2  # because of rounding when copying
 707        cls.decimal_resid = 2
 708        cls.decimal_predict_levels = DECIMAL_2
 709
 710
 711class Test_ARIMA112CSS(CheckArimaResultsMixin):
 712    @classmethod
 713    def setup_class(cls):
 714        cpi = load_macrodata_pandas().data['cpi'].values
 715        cls.res1 = ARIMA(cpi, (1, 1, 2)).fit(disp=-1, method='css',
 716                                             start_params=[.905322, -.692425,
 717                                                           1.07366,
 718                                                           0.172024])
 719        cls.res2 = results_arima.ARIMA112(method='css')
 720        cls.res2.fittedvalues = - cpi[1:-1] + cls.res2.linear
 721        # make sure endog names changes to D.cpi
 722        cls.decimal_llf = 3
 723        cls.decimal_aic = 3
 724        cls.decimal_bic = 3
 725        # TODO: fix the indexing for the end here, I do not think this is right
 726        # if we're going to treat it like indexing
 727        # the forecast from 2005Q1 through 2009Q4 is indices
 728        # 184 through 227 not 226
 729        # note that the first one counts in the count so 164 + 64 is 65
 730        # predictions
 731        # cls.res1.forecast_res_dyn = self.predict(start=164, end=164+63,
 732        #                                         typ='levels', dynamic=True)
 733        # since we got from gretl do not have linear prediction in differences
 734        cls.decimal_arroots = 3
 735        cls.decimal_maroots = 2
 736        cls.decimal_t = 1
 737        cls.decimal_resid = 2
 738        cls.decimal_fittedvalues = 3
 739        cls.decimal_predict_levels = DECIMAL_3
 740
 741    def test_freq(self):
 742        assert_almost_equal(self.res1.arfreq, [0.5000], 4)
 743        assert_almost_equal(self.res1.mafreq, [0.5000, 0.5000], 4)
 744
 745
 746def test_arima_predict_mle_dates():
 747    cpi = load_macrodata_pandas().data['cpi'].values
 748    res1 = ARIMA(cpi, (4, 1, 1), dates=cpi_dates, freq='Q').fit(disp=-1)
 749    file_path = os.path.join(current_path, 'results',
 750                             'results_arima_forecasts_all_mle.csv')
 751    with open(file_path, "rb") as test_data:
 752        arima_forecasts = np.genfromtxt(test_data, delimiter=",",
 753                                        skip_header=1, dtype=float)
 754
 755    fc = arima_forecasts[:, 0]
 756    fcdyn = arima_forecasts[:, 1]
 757    fcdyn2 = arima_forecasts[:, 2]
 758
 759    start, end = 2, 51
 760    fv = res1.predict('1959Q3', '1971Q4', typ='levels')
 761    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 762    assert_equal(res1.data.predict_dates, cpi_dates[start:end + 1])
 763
 764    start, end = 202, 227
 765    fv = res1.predict('2009Q3', '2015Q4', typ='levels')
 766    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 767    assert_equal(res1.data.predict_dates, cpi_predict_dates)
 768
 769    # make sure dynamic works
 770
 771    start, end = '1960q2', '1971q4'
 772    fv = res1.predict(start, end, dynamic=True, typ='levels')
 773    assert_almost_equal(fv, fcdyn[5:51 + 1], DECIMAL_4)
 774
 775    start, end = '1965q1', '2015q4'
 776    fv = res1.predict(start, end, dynamic=True, typ='levels')
 777    assert_almost_equal(fv, fcdyn2[24:227 + 1], DECIMAL_4)
 778
 779
 780def test_arma_predict_mle_dates():
 781    from statsmodels.datasets.sunspots import load_pandas
 782    sunspots = load_pandas().data['SUNACTIVITY'].values
 783    mod = ARMA(sunspots, (9, 0), dates=sun_dates, freq='A')
 784    mod.method = 'mle'
 785
 786    assert_raises(ValueError, mod._get_prediction_index, '1701', '1751', True)
 787
 788    start, end = 2, 51
 789    mod._get_prediction_index('1702', '1751', False)
 790    assert_equal(mod.data.predict_dates, sun_dates[start:end + 1])
 791
 792    start, end = 308, 333
 793    mod._get_prediction_index('2008', '2033', False)
 794    assert_equal(mod.data.predict_dates, sun_predict_dates)
 795
 796
 797def test_arima_predict_css_dates():
 798    cpi = load_macrodata_pandas().data['cpi'].values
 799    res1 = ARIMA(cpi, (4, 1, 1), dates=cpi_dates, freq='Q').fit(disp=-1,
 800                                                                method='css',
 801                                                                trend='nc')
 802
 803    params = np.array([1.231272508473910,
 804                       -0.282516097759915,
 805                       0.170052755782440,
 806                       -0.118203728504945,
 807                       -0.938783134717947])
 808    file_path = os.path.join(current_path, 'results',
 809                             'results_arima_forecasts_all_css.csv')
 810    with open(file_path, "rb") as test_data:
 811        arima_forecasts = np.genfromtxt(test_data, delimiter=",",
 812                                        skip_header=1, dtype=float)
 813
 814    fc = arima_forecasts[:, 0]
 815    fcdyn = arima_forecasts[:, 1]
 816    fcdyn2 = arima_forecasts[:, 2]
 817
 818    start, end = 5, 51
 819    fv = res1.model.predict(params, '1960Q2', '1971Q4', typ='levels')
 820    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 821    assert_equal(res1.data.predict_dates, cpi_dates[start:end + 1])
 822
 823    start, end = 202, 227
 824    fv = res1.model.predict(params, '2009Q3', '2015Q4', typ='levels')
 825    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 826    assert_equal(res1.data.predict_dates, cpi_predict_dates)
 827
 828    # make sure dynamic works
 829    start, end = 5, 51
 830    fv = res1.model.predict(params, '1960Q2', '1971Q4', typ='levels',
 831                            dynamic=True)
 832    assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
 833
 834    start, end = '1965q1', '2015q4'
 835    fv = res1.model.predict(params, start, end, dynamic=True, typ='levels')
 836    assert_almost_equal(fv, fcdyn2[24:227 + 1], DECIMAL_4)
 837
 838
 839def test_arma_predict_css_dates():
 840    from statsmodels.datasets.sunspots import load_pandas
 841    sunspots = load_pandas().data['SUNACTIVITY'].values
 842    mod = ARMA(sunspots, (9, 0), dates=sun_dates, freq='A')
 843    mod.method = 'css'
 844    assert_raises(ValueError, mod._get_prediction_index, '1701', '1751', False)
 845
 846
 847def test_arima_predict_mle():
 848    cpi = load_macrodata_pandas().data['cpi'].values
 849    res1 = ARIMA(cpi, (4, 1, 1)).fit(disp=-1)
 850    # fit the model so that we get correct endog length but use
 851    file_path = os.path.join(current_path, 'results',
 852                             'results_arima_forecasts_all_mle.csv')
 853    with open(file_path, "rb") as test_data:
 854        arima_forecasts = np.genfromtxt(test_data, delimiter=",",
 855                                        skip_header=1, dtype=float)
 856    fc = arima_forecasts[:, 0]
 857    fcdyn = arima_forecasts[:, 1]
 858    fcdyn2 = arima_forecasts[:, 2]
 859    fcdyn3 = arima_forecasts[:, 3]
 860    fcdyn4 = arima_forecasts[:, 4]
 861
 862    # 0 indicates the first sample-observation below
 863    # ie., the index after the pre-sample, these are also differenced once
 864    # so the indices are moved back once from the cpi in levels
 865    # start < p, end <p 1959q2 - 1959q4
 866    start, end = 1, 3
 867    fv = res1.predict(start, end, typ='levels')
 868    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 869    # start < p, end 0 1959q3 - 1960q1
 870    start, end = 2, 4
 871    fv = res1.predict(start, end, typ='levels')
 872    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 873    # start < p, end >0 1959q3 - 1971q4
 874    start, end = 2, 51
 875    fv = res1.predict(start, end, typ='levels')
 876    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 877    # start < p, end nobs 1959q3 - 2009q3
 878    start, end = 2, 202
 879    fv = res1.predict(start, end, typ='levels')
 880    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 881    # start < p, end >nobs 1959q3 - 2015q4
 882    start, end = 2, 227
 883    fv = res1.predict(start, end, typ='levels')
 884    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 885    # start 0, end >0 1960q1 - 1971q4
 886    start, end = 4, 51
 887    fv = res1.predict(start, end, typ='levels')
 888    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 889    # start 0, end nobs 1960q1 - 2009q3
 890    start, end = 4, 202
 891    fv = res1.predict(start, end, typ='levels')
 892    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 893    # start 0, end >nobs 1960q1 - 2015q4
 894    start, end = 4, 227
 895    fv = res1.predict(start, end, typ='levels')
 896    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 897    # start >p, end >0 1965q1 - 1971q4
 898    start, end = 24, 51
 899    fv = res1.predict(start, end, typ='levels')
 900    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 901    # start >p, end nobs 1965q1 - 2009q3
 902    start, end = 24, 202
 903    fv = res1.predict(start, end, typ='levels')
 904    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 905    # start >p, end >nobs 1965q1 - 2015q4
 906    start, end = 24, 227
 907    fv = res1.predict(start, end, typ='levels')
 908    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 909    # start nobs, end nobs 2009q3 - 2009q3
 910    start, end = 202, 202
 911    fv = res1.predict(start, end, typ='levels')
 912    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_3)
 913    # start nobs, end >nobs 2009q3 - 2015q4
 914    start, end = 202, 227
 915    fv = res1.predict(start, end, typ='levels')
 916    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_3)
 917    # start >nobs, end >nobs 2009q4 - 2015q4
 918    start, end = 203, 227
 919    fv = res1.predict(start, end, typ='levels')
 920    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
 921    # defaults
 922    start, end = None, None
 923    fv = res1.predict(start, end, typ='levels')
 924    assert_almost_equal(fv, fc[1:203], DECIMAL_4)
 925
 926    # Dynamic
 927
 928    # start < p, end <p 1959q2 - 1959q4
 929    start, end = 1, 3
 930    with pytest.raises(ValueError, match='Start must be >= k_ar'):
 931        fv = res1.predict(start, end, dynamic=True, typ='levels')
 932
 933    # start < p, end 0 1959q3 - 1960q1
 934    start, end = 2, 4
 935    with pytest.raises(ValueError, match='Start must be >= k_ar'):
 936        res1.predict(start, end, dynamic=True, typ='levels')
 937
 938    # start < p, end >0 1959q3 - 1971q4
 939    start, end = 2, 51
 940    with pytest.raises(ValueError, match='Start must be >= k_ar'):
 941        res1.predict(start, end, dynamic=True, typ='levels')
 942
 943    # start < p, end nobs 1959q3 - 2009q3
 944    start, end = 2, 202
 945    with pytest.raises(ValueError, match='Start must be >= k_ar'):
 946        res1.predict(start, end, dynamic=True, typ='levels')
 947
 948    # start < p, end >nobs 1959q3 - 2015q4
 949    start, end = 2, 227
 950    with pytest.raises(ValueError, match='Start must be >= k_ar'):
 951        res1.predict(start, end, dynamic=True, typ='levels')
 952
 953    # start 0, end >0 1960q1 - 1971q4
 954    start, end = 5, 51
 955    fv = res1.predict(start, end, dynamic=True, typ='levels')
 956    assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
 957    # start 0, end nobs 1960q1 - 2009q3
 958    start, end = 5, 202
 959    fv = res1.predict(start, end, dynamic=True, typ='levels')
 960    assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
 961    # start 0, end >nobs 1960q1 - 2015q4
 962    start, end = 5, 227
 963    fv = res1.predict(start, end, dynamic=True, typ='levels')
 964    assert_almost_equal(fv, fcdyn[start:end + 1], DECIMAL_4)
 965    # start >p, end >0 1965q1 - 1971q4
 966    start, end = 24, 51
 967    fv = res1.predict(start, end, dynamic=True, typ='levels')
 968    assert_almost_equal(fv, fcdyn2[start:end + 1], DECIMAL_4)
 969    # start >p, end nobs 1965q1 - 2009q3
 970    start, end = 24, 202
 971    fv = res1.predict(start, end, dynamic=True, typ='levels')
 972    assert_almost_equal(fv, fcdyn2[start:end + 1], DECIMAL_4)
 973    # start >p, end >nobs 1965q1 - 2015q4
 974    start, end = 24, 227
 975    fv = res1.predict(start, end, dynamic=True, typ='levels')
 976    assert_almost_equal(fv, fcdyn2[start:end + 1], DECIMAL_4)
 977    # start nobs, end nobs 2009q3 - 2009q3
 978    start, end = 202, 202
 979    fv = res1.predict(start, end, dynamic=True, typ='levels')
 980    assert_almost_equal(fv, fcdyn3[start:end + 1], DECIMAL_4)
 981    # start nobs, end >nobs 2009q3 - 2015q4
 982    start, end = 202, 227
 983    fv = res1.predict(start, end, dynamic=True, typ='levels')
 984    assert_almost_equal(fv, fcdyn3[start:end + 1], DECIMAL_4)
 985    # start >nobs, end >nobs 2009q4 - 2015q4
 986    start, end = 203, 227
 987    fv = res1.predict(start, end, dynamic=True, typ='levels')
 988    assert_almost_equal(fv, fcdyn4[start:end + 1], DECIMAL_4)
 989    # defaults
 990    start, end = None, None
 991    fv = res1.predict(start, end, dynamic=True, typ='levels')
 992    assert_almost_equal(fv, fcdyn[5:203], DECIMAL_4)
 993
 994
 995def _check_start(model, given, expected, dynamic):
 996    start, _, _, _ = model._get_prediction_index(given, None, dynamic)
 997    assert_equal(start, expected)
 998
 999
1000def _check_end(model, given, end_expect, out_of_sample_expect):
1001    _, end, out_of_sample, _ = model._get_prediction_index(None, given, False)
1002    assert_equal((end, out_of_sample), (end_expect, out_of_sample_expect))
1003
1004
def test_arma_predict_indices():
    """Check how ARMA(9, 0) on annual sunspots (1700-2008) resolves integer
    and date-string start/end arguments into prediction indices."""
    from statsmodels.datasets.sunspots import load_pandas
    sunspots = load_pandas().data['SUNACTIVITY'].values
    model = ARMA(sunspots, (9, 0), dates=sun_dates, freq='A')
    model.method = 'mle'

    # raises - dynamic start inside the first k_ar = 9 observations
    for start in (0, 8, '1700', '1708'):
        with pytest.raises(ValueError):
            model._get_prediction_index(start, None, True)

    # (given, expected, dynamic). There is no differencing (k_diff = 0), so
    # the requested start is returned unchanged. A start equal to nobs (309,
    # i.e. 2009) is the first out-of-sample forecast.
    start_test_cases = [
        (None, 9, True),
        (9, 9, True),
        (10, 10, True),
        (309, 309, True),
        (308, 308, True),
        (0, 0, False),
        (1, 1, False),
        (4, 4, False),

        ('1709', 9, True),
        ('1710', 10, True),
        ('2008', 308, True),
        ('2009', 309, True),
        ('1700', 0, False),
        ('1708', 8, False),
        ('1709', 9, False),
    ]

    for case in start_test_cases:
        _check_start(model, *case)

    # (given, expected end, expected out-of-sample steps). sunspots has 309
    # observations, so the last in-sample index is 308; later ends are
    # clipped to 308 and the excess is reported as out-of-sample steps.
    end_test_cases = [(None, 308, 0),
                      (307, 307, 0),
                      (308, 308, 0),
                      (309, 308, 1),
                      (312, 308, 4),
                      (51, 51, 0),
                      (333, 308, 25),

                      ('2007', 307, 0),
                      ('2008', 308, 0),
                      ('2009', 308, 1),
                      ('2012', 308, 4),
                      ('1815', 115, 0),
                      ('2033', 308, 25),
                      ]

    for case in end_test_cases:
        _check_end(model, *case)
1073
1074
def test_arima_predict_indices():
    """Check how ARIMA models on quarterly CPI resolve integer and date
    start/end arguments into indices of the differenced series.

    Integer arguments refer to the undifferenced series, so resolved indices
    are shifted back by ``k_diff``. Both d=1 and d=2 are checked.
    """
    cpi = load_macrodata_pandas().data['cpi'].values

    # ---- ARIMA(4, 1, 1) ----
    model = ARIMA(cpi, (4, 1, 1), dates=cpi_dates, freq='Q')
    model.method = 'mle'

    # raises - dynamic start before k_ar + k_diff = 5
    with pytest.raises(ValueError):
        model._get_prediction_index(0, None, True)
    with pytest.raises(ValueError):
        model._get_prediction_index(4, None, True)
    with pytest.raises(KeyError):
        model._get_prediction_index('1959Q1', None, True)
    with pytest.raises(ValueError):
        model._get_prediction_index('1960Q1', None, True)

    # raises - start on the observation lost to differencing
    with pytest.raises(ValueError):
        model._get_prediction_index(0, None, False)
    with pytest.raises(KeyError):
        model._get_prediction_index('1959Q1', None, False)

    # (given, expected, dynamic); every start is moved back by k_diff = 1.
    # A start of 203 (2009Q4) is the first out-of-sample forecast.
    start_test_cases = [
        (None, 4, True),
        (5, 4, True),
        (6, 5, True),
        (203, 202, True),
        (1, 0, False),
        (4, 3, False),
        (5, 4, False),
        ('1960Q2', 4, True),
        ('1960Q3', 5, True),
        ('2009Q4', 202, True),
        ('1959Q2', 0, False),
        ('1960Q1', 3, False),
        ('1960Q2', 4, False),
    ]

    for case in start_test_cases:
        _check_start(model, *case)

    # TODO: make sure dates are passing through unmolested

    # (given, expected end, expected out-of-sample steps). diff(cpi) has 202
    # observations, so the last in-sample index is 201.
    end_test_cases = [(None, 201, 0),
                      (201, 200, 0),
                      (202, 201, 0),
                      (203, 201, 1),
                      (204, 201, 2),
                      (51, 50, 0),
                      (227, 201, 25),  # 2015Q4

                      ('2009Q2', 200, 0),
                      ('2009Q3', 201, 0),
                      ('2009Q4', 201, 1),
                      ('2010Q1', 201, 2),
                      ('1971Q4', 50, 0),
                      ('2015Q4', 201, 25),
                      ]

    for case in end_test_cases:
        _check_end(model, *case)

    # ---- ARIMA(4, 2, 1): higher k_diff ----
    model = ARIMA(cpi, (4, 2, 1), dates=cpi_dates, freq='Q')
    model.method = 'mle'

    # raises - dynamic start before k_ar + k_diff = 6
    with pytest.raises(ValueError):
        model._get_prediction_index(0, None, True)
    with pytest.raises(ValueError):
        model._get_prediction_index(5, None, True)
    with pytest.raises(KeyError):
        model._get_prediction_index('1959Q1', None, True)
    with pytest.raises(ValueError):
        model._get_prediction_index('1960Q1', None, True)

    # raises - start on an observation lost to differencing
    with pytest.raises(ValueError):
        model._get_prediction_index(1, None, False)
    with pytest.raises(KeyError):
        model._get_prediction_index('1959Q2', None, False)

    # (given, expected, dynamic); every start is moved back by k_diff = 2
    start_test_cases = [(None, 4, True),
                        (6, 4, True),
                        (203, 201, True),
                        (2, 0, False),
                        (4, 2, False),
                        (5, 3, False),
                        ('1960Q3', 4, True),
                        ('2009Q4', 201, True),
                        ('1959Q3', 0, False),
                        ('1960Q1', 2, False),
                        ('1960Q2', 3, False),
                        ]

    for case in start_test_cases:
        _check_start(model, *case)

    # diff(cpi, 2) has 201 observations, so the last in-sample index is 200
    end_test_cases = [(None, 200, 0),
                      (201, 199, 0),
                      (202, 200, 0),
                      (203, 200, 1),
                      (204, 200, 2),
                      (51, 49, 0),
                      (227, 200, 25),  # 2015Q4

                      ('2009Q2', 199, 0),
                      ('2009Q3', 200, 0),
                      ('2009Q4', 200, 1),
                      ('2010Q1', 200, 2),
                      ('1971Q4', 49, 0),
                      ('2015Q4', 200, 25),
                      ]

    for case in end_test_cases:
        _check_end(model, *case)
1208
1209
def test_arima_predict_indices_css():
    """With method='css', ARIMA(4, 1, 1) on CPI rejects start indices inside
    the pre-sample for both static and dynamic prediction."""
    cpi = load_macrodata_pandas().data['cpi'].values
    # The model is not fit: setting ``method`` is all this check needs before
    # calling ``_get_prediction_index``.
    model = ARIMA(cpi, (4, 1, 1))
    model.method = 'css'

    # Unlike MLE, CSS rejects these starts even when dynamic=False.
    for start in (0, 2):
        for dynamic in (False, True):
            with pytest.raises(ValueError):
                model._get_prediction_index(start, None, dynamic)
1222
1223
1224def test_arima_predict_css():
1225    cpi = load_macrodata_pandas().data['cpi'].values
1226    # NOTE: Doing no-constant for now to kick the conditional exogenous
1227    # issue 274 down the road
1228    # go ahead and git the model to set up necessary variables
1229    res1 = ARIMA(cpi, (4, 1, 1)).fit(disp=-1, method="css",
1230                                     trend="nc")
1231    # but use gretl parameters to predict to avoid precision problems
1232    params = np.array([1.231272508473910,
1233                       -0.282516097759915,
1234                       0.170052755782440,
1235                       -0.118203728504945,
1236                       -0.938783134717947])
1237    file_path = os.path.join(current_path, 'results',
1238                             'results_arima_forecasts_all_css.csv')
1239    with open(file_path, "rb") as test_data:
1240        arima_forecasts = np.genfromtxt(test_data, delimiter=",",
1241                                        skip_header=1, dtype=float)
1242    fc = arima_forecasts[:, 0]
1243    fcdyn = arima_forecasts[:, 1]
1244    fcdyn2 = arima_forecasts[:, 2]
1245    fcdyn3 = arima_forecasts[:, 3]
1246    fcdyn4 = arima_forecasts[:, 4]
1247
1248    start, end = 1, 3
1249    with pytest.raises(ValueError, match='Start must be >= k_ar'):
1250        res1.model.predict(params, start, end)
1251    #  start < p, end 0 1959q3 - 1960q1
1252    start, end = 2, 4
1253    with pytest.raises(ValueError, match='Start must be >= k_ar'):
1254        res1.model.predict(params, start, end)
1255    #  start < p, end >0 1959q3 - 1971q4
1256    start, end = 2, 51
1257    with pytest.raises(ValueError, match='Start must be >= k_ar'):
1258        res1.model.predict(params, start, end)
1259    #  start < p, end nobs 1959q3 - 2009q3
1260    start, end = 2, 202
1261    with pytest.raises(ValueError, match='Start must be >= k_ar'):
1262        res1.model.predict(params, start, end)
1263    #  start < p, end >nobs 1959q3 - 2015q4
1264    start, end = 2, 227
1265    with pytest.raises(ValueError, match='Start must be >= k_ar'):
1266        res1.model.predict(params, start, end)
1267    # start 0, end >0 1960q1 - 1971q4
1268    start, end = 5, 51
1269    fv = res1.model.predict(params, start, end, typ='levels')
1270    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
1271    # start 0, end nobs 1960q1 - 2009q3
1272    start, end = 5, 202
1273    fv = res1.model.predict(params, start, end, typ='levels')
1274    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
1275    # start 0, end >nobs 1960q1 - 2015q4
1276    # TODO: why detoriating precision?
1277    fv = res1.model.predict(params, start, end, typ='levels')
1278    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
1279    # start >p, end >0 1965q1 - 1971q4
1280    start, end = 24, 51
1281    fv = res1.model.predict(params, start, end, typ='levels')
1282    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
1283    # start >p, end nobs 1965q1 - 2009q3
1284    start, end = 24, 202
1285    fv = res1.model.predict(params, start, end, typ='levels')
1286    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
1287    # start >p, end >nobs 1965q1 - 2015q4
1288    start, end = 24, 227
1289    fv = res1.model.predict(params, start, end, typ='levels')
1290    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
1291    # start nobs, end nobs 2009q3 - 2009q3
1292    start, end = 202, 202
1293    fv = res1.model.predict(params, start, end, typ='levels')
1294    assert_almost_equal(fv, fc[start:end + 1], DECIMAL_4)
1295    # start nobs, end >nobs 2009q3 - 2015q4
1296    

Large files files are truncated, but you can click here to view the full file