PageRenderTime 42ms CodeModel.GetById 16ms RepoModel.GetById 0ms app.codeStats 0ms

/tests/test_risk_compare_batch_iterative.py

https://github.com/aidoom/zipline
Python | 161 lines | 121 code | 18 blank | 22 comment | 7 complexity | 174d3514c4f3f8295b0390f3f5f84969 MD5 | raw file
Possible License(s): Apache-2.0
  1. #
  2. # Copyright 2013 Quantopian, Inc.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import numbers
  16. import unittest
  17. import datetime
  18. import pytz
  19. import numpy as np
  20. import pandas as pd
  21. import zipline.finance.risk as risk
  22. import zipline.finance.trading as trading
  23. from zipline.protocol import DailyReturn
  24. from test_risk import RETURNS
  25. class RiskCompareIterativeToBatch(unittest.TestCase):
  26. """
  27. Assert that RiskMetricsIterative and RiskMetricsBatch
  28. behave in the same way.
  29. """
  30. def setUp(self):
  31. self.start_date = datetime.datetime(
  32. year=2006,
  33. month=1,
  34. day=1,
  35. hour=0,
  36. minute=0,
  37. tzinfo=pytz.utc)
  38. self.end_date = datetime.datetime(
  39. year=2006, month=12, day=31, tzinfo=pytz.utc)
  40. def test_risk_metrics_returns(self):
  41. trading.environment = trading.TradingEnvironment()
  42. # Advance start date to first date in the trading calendar
  43. if trading.environment.is_trading_day(self.start_date):
  44. start_date = self.start_date
  45. else:
  46. start_date = trading.environment.next_trading_day(self.start_date)
  47. self.all_benchmark_returns = pd.Series({
  48. x.date: x.returns
  49. for x in trading.environment.benchmark_returns
  50. if x.date >= self.start_date
  51. })
  52. start_index = trading.environment.trading_days.searchsorted(start_date)
  53. end_date = trading.environment.trading_days[
  54. start_index + len(RETURNS)]
  55. risk_metrics_refactor = risk.RiskMetricsIterative(start_date, end_date)
  56. todays_date = start_date
  57. cur_returns = []
  58. for i, ret in enumerate(RETURNS):
  59. todays_return_obj = DailyReturn(
  60. todays_date,
  61. ret
  62. )
  63. cur_returns.append(todays_return_obj)
  64. # Move forward day counter to next trading day
  65. todays_date = trading.environment.next_trading_day(todays_date)
  66. try:
  67. risk_metrics_original = risk.RiskMetricsBatch(
  68. start_date=start_date,
  69. end_date=todays_date,
  70. returns=cur_returns
  71. )
  72. except Exception as e:
  73. #assert that when original raises exception, same
  74. #exception is raised by risk_metrics_refactor
  75. np.testing.assert_raises(
  76. type(e),
  77. risk_metrics_refactor.update,
  78. todays_date,
  79. self.all_benchmark_returns[todays_return_obj.date]
  80. )
  81. continue
  82. risk_metrics_refactor.update(
  83. todays_date,
  84. ret,
  85. self.all_benchmark_returns[todays_return_obj.date])
  86. self.assertEqual(
  87. risk_metrics_original.start_date,
  88. risk_metrics_refactor.start_date)
  89. self.assertEqual(
  90. risk_metrics_original.end_date,
  91. risk_metrics_refactor.algorithm_returns.index[-1])
  92. self.assertEqual(
  93. risk_metrics_original.treasury_period_return,
  94. risk_metrics_refactor.treasury_period_return)
  95. np.testing.assert_allclose(
  96. risk_metrics_original.benchmark_returns,
  97. risk_metrics_refactor.benchmark_returns,
  98. rtol=0.001
  99. )
  100. np.testing.assert_allclose(
  101. risk_metrics_original.algorithm_returns,
  102. risk_metrics_refactor.algorithm_returns,
  103. rtol=0.001
  104. )
  105. risk_original_dict = risk_metrics_original.to_dict()
  106. risk_refactor_dict = risk_metrics_refactor.to_dict()
  107. self.assertEqual(set(risk_original_dict.keys()),
  108. set(risk_refactor_dict.keys()))
  109. err_msg_format = """\
  110. "In update step {iter}: {measure} should be {truth} but is {returned}!"""
  111. for measure in risk_original_dict.iterkeys():
  112. if measure == 'max_drawdown':
  113. np.testing.assert_almost_equal(
  114. risk_refactor_dict[measure],
  115. risk_original_dict[measure],
  116. err_msg=err_msg_format.format(
  117. iter=i,
  118. measure=measure,
  119. truth=risk_original_dict[measure],
  120. returned=risk_refactor_dict[measure]))
  121. else:
  122. if isinstance(risk_original_dict[measure], numbers.Real):
  123. np.testing.assert_allclose(
  124. risk_original_dict[measure],
  125. risk_refactor_dict[measure],
  126. rtol=0.001,
  127. err_msg=err_msg_format.format(
  128. iter=i,
  129. measure=measure,
  130. truth=risk_original_dict[measure],
  131. returned=risk_refactor_dict[measure])
  132. )
  133. else:
  134. np.testing.assert_equal(
  135. risk_original_dict[measure],
  136. risk_refactor_dict[measure],
  137. err_msg=err_msg_format.format(
  138. iter=i,
  139. measure=measure,
  140. truth=risk_original_dict[measure],
  141. returned=risk_refactor_dict[measure])
  142. )