PageRenderTime 59ms CodeModel.GetById 17ms RepoModel.GetById 1ms app.codeStats 0ms

/tests/risk/test_risk_compare_batch_iterative.py

https://github.com/quincysmiith/zipline
Python | 164 lines | 123 code | 19 blank | 22 comment | 7 complexity | 610d1698c9e77f4cb8255267fc3b4994 MD5 | raw file
Possible License(s): Apache-2.0
  1. #
  2. # Copyright 2013 Quantopian, Inc.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import numbers
  16. import unittest
  17. import datetime
  18. import pytz
  19. import numpy as np
  20. import pandas as pd
  21. import zipline.finance.risk as risk
  22. import zipline.finance.trading as trading
  23. from zipline.finance.trading import SimulationParameters
  24. from zipline.protocol import DailyReturn
  25. from test_risk import RETURNS
  26. class RiskCompareIterativeToBatch(unittest.TestCase):
  27. """
  28. Assert that RiskMetricsIterative and RiskMetricsBatch
  29. behave in the same way.
  30. """
  31. def setUp(self):
  32. self.start_date = datetime.datetime(
  33. year=2006,
  34. month=1,
  35. day=1,
  36. hour=0,
  37. minute=0,
  38. tzinfo=pytz.utc)
  39. self.end_date = datetime.datetime(
  40. year=2006, month=12, day=31, tzinfo=pytz.utc)
  41. def test_risk_metrics_returns(self):
  42. trading.environment = trading.TradingEnvironment()
  43. # Advance start date to first date in the trading calendar
  44. if trading.environment.is_trading_day(self.start_date):
  45. start_date = self.start_date
  46. else:
  47. start_date = trading.environment.next_trading_day(self.start_date)
  48. self.all_benchmark_returns = pd.Series({
  49. x.date: x.returns
  50. for x in trading.environment.benchmark_returns
  51. if x.date >= self.start_date
  52. })
  53. start_index = trading.environment.trading_days.searchsorted(start_date)
  54. end_date = trading.environment.trading_days[
  55. start_index + len(RETURNS)]
  56. sim_params = SimulationParameters(start_date, end_date)
  57. risk_metrics_refactor = risk.RiskMetricsIterative(sim_params)
  58. todays_date = start_date
  59. cur_returns = []
  60. for i, ret in enumerate(RETURNS):
  61. todays_return_obj = DailyReturn(
  62. todays_date,
  63. ret
  64. )
  65. cur_returns.append(todays_return_obj)
  66. try:
  67. risk_metrics_original = risk.RiskMetricsBatch(
  68. start_date=start_date,
  69. end_date=todays_date,
  70. returns=cur_returns
  71. )
  72. except Exception as e:
  73. #assert that when original raises exception, same
  74. #exception is raised by risk_metrics_refactor
  75. np.testing.assert_raises(
  76. type(e),
  77. risk_metrics_refactor.update,
  78. todays_date,
  79. self.all_benchmark_returns[todays_return_obj.date]
  80. )
  81. continue
  82. risk_metrics_refactor.update(
  83. todays_date,
  84. ret,
  85. self.all_benchmark_returns[todays_return_obj.date])
  86. # Move forward day counter to next trading day
  87. todays_date = trading.environment.next_trading_day(todays_date)
  88. self.assertEqual(
  89. risk_metrics_original.start_date,
  90. risk_metrics_refactor.start_date)
  91. self.assertEqual(
  92. risk_metrics_original.end_date,
  93. risk_metrics_refactor.algorithm_returns.index[-1])
  94. self.assertEqual(
  95. risk_metrics_original.treasury_period_return,
  96. risk_metrics_refactor.treasury_period_return)
  97. np.testing.assert_allclose(
  98. risk_metrics_original.benchmark_returns,
  99. risk_metrics_refactor.benchmark_returns,
  100. rtol=0.001
  101. )
  102. np.testing.assert_allclose(
  103. risk_metrics_original.algorithm_returns,
  104. risk_metrics_refactor.algorithm_returns,
  105. rtol=0.001
  106. )
  107. risk_original_dict = risk_metrics_original.to_dict()
  108. risk_refactor_dict = risk_metrics_refactor.to_dict()
  109. self.assertEqual(set(risk_original_dict.keys()),
  110. set(risk_refactor_dict.keys()))
  111. err_msg_format = """\
  112. "In update step {iter}: {measure} should be {truth} but is {returned}!"""
  113. for measure in risk_original_dict.iterkeys():
  114. if measure == 'max_drawdown':
  115. np.testing.assert_almost_equal(
  116. risk_refactor_dict[measure],
  117. risk_original_dict[measure],
  118. err_msg=err_msg_format.format(
  119. iter=i,
  120. measure=measure,
  121. truth=risk_original_dict[measure],
  122. returned=risk_refactor_dict[measure]))
  123. else:
  124. if isinstance(risk_original_dict[measure], numbers.Real):
  125. np.testing.assert_allclose(
  126. risk_original_dict[measure],
  127. risk_refactor_dict[measure],
  128. rtol=0.001,
  129. err_msg=err_msg_format.format(
  130. iter=i,
  131. measure=measure,
  132. truth=risk_original_dict[measure],
  133. returned=risk_refactor_dict[measure])
  134. )
  135. else:
  136. np.testing.assert_equal(
  137. risk_original_dict[measure],
  138. risk_refactor_dict[measure],
  139. err_msg=err_msg_format.format(
  140. iter=i,
  141. measure=measure,
  142. truth=risk_original_dict[measure],
  143. returned=risk_refactor_dict[measure])
  144. )