PageRenderTime 38ms CodeModel.GetById 17ms RepoModel.GetById 0ms app.codeStats 0ms

/tests/test_risk_compare_batch_iterative.py

https://gitlab.com/dandrews/zipline
Python | 131 lines | 92 code | 18 blank | 21 comment | 4 complexity | 54ab47f4153c70d4b2e918a4131f38ae MD5 | raw file
Possible License(s): Apache-2.0
  1. #
  2. # Copyright 2013 Quantopian, Inc.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import unittest
  16. import datetime
  17. import pytz
  18. import numpy as np
  19. import zipline.finance.risk as risk
  20. import zipline.finance.trading as trading
  21. from zipline.protocol import DailyReturn
  22. from test_risk import RETURNS
  23. class RiskCompareIterativeToBatch(unittest.TestCase):
  24. """
  25. Assert that RiskMetricsIterative and RiskMetricsBatch
  26. behave in the same way.
  27. """
  28. def setUp(self):
  29. self.start_date = datetime.datetime(
  30. year=2006,
  31. month=1,
  32. day=1,
  33. hour=0,
  34. minute=0,
  35. tzinfo=pytz.utc)
  36. self.end_date = datetime.datetime(
  37. year=2006, month=12, day=31, tzinfo=pytz.utc)
  38. self.oneday = datetime.timedelta(days=1)
  39. def test_risk_metrics_returns(self):
  40. risk_metrics_refactor = risk.RiskMetricsIterative(self.start_date)
  41. todays_date = self.start_date
  42. cur_returns = []
  43. for i, ret in enumerate(RETURNS):
  44. todays_return_obj = DailyReturn(
  45. todays_date,
  46. ret
  47. )
  48. cur_returns.append(todays_return_obj)
  49. # Move forward day counter to next trading day
  50. todays_date += self.oneday
  51. while not trading.environment.is_trading_day(todays_date):
  52. todays_date += self.oneday
  53. try:
  54. risk_metrics_original = risk.RiskMetricsBatch(
  55. start_date=self.start_date,
  56. end_date=todays_date,
  57. returns=cur_returns
  58. )
  59. except Exception as e:
  60. #assert that when original raises exception, same
  61. #exception is raised by risk_metrics_refactor
  62. np.testing.assert_raises(
  63. type(e), risk_metrics_refactor.update, todays_date, ret)
  64. continue
  65. risk_metrics_refactor.update(todays_date, ret)
  66. self.assertEqual(
  67. risk_metrics_original.start_date,
  68. risk_metrics_refactor.start_date)
  69. self.assertEqual(
  70. risk_metrics_original.end_date,
  71. risk_metrics_refactor.end_date)
  72. self.assertEqual(
  73. risk_metrics_original.treasury_duration,
  74. risk_metrics_refactor.treasury_duration)
  75. self.assertEqual(
  76. risk_metrics_original.treasury_curve,
  77. risk_metrics_refactor.treasury_curve)
  78. self.assertEqual(
  79. risk_metrics_original.treasury_period_return,
  80. risk_metrics_refactor.treasury_period_return)
  81. self.assertEqual(
  82. risk_metrics_original.benchmark_returns,
  83. risk_metrics_refactor.benchmark_returns)
  84. self.assertEqual(
  85. risk_metrics_original.algorithm_returns,
  86. risk_metrics_refactor.algorithm_returns)
  87. risk_original_dict = risk_metrics_original.to_dict()
  88. risk_refactor_dict = risk_metrics_refactor.to_dict()
  89. self.assertEqual(set(risk_original_dict.keys()),
  90. set(risk_refactor_dict.keys()))
  91. err_msg_format = """\
  92. "In update step {iter}: {measure} should be {truth} but is {returned}!"""
  93. for measure in risk_original_dict.iterkeys():
  94. if measure == 'max_drawdown':
  95. np.testing.assert_almost_equal(
  96. risk_refactor_dict[measure],
  97. risk_original_dict[measure],
  98. err_msg=err_msg_format.format(
  99. iter=i,
  100. measure=measure,
  101. truth=risk_original_dict[measure],
  102. returned=risk_refactor_dict[measure]))
  103. else:
  104. np.testing.assert_equal(
  105. risk_original_dict[measure],
  106. risk_refactor_dict[measure],
  107. err_msg_format.format(
  108. iter=i,
  109. measure=measure,
  110. truth=risk_original_dict[measure],
  111. returned=risk_refactor_dict[measure])
  112. )