PageRenderTime 51ms CodeModel.GetById 23ms RepoModel.GetById 1ms app.codeStats 0ms

/ramp/tests/test_metrics.py

https://github.com/psattige/ramp
Python | 91 lines | 74 code | 16 blank | 1 comment | 3 complexity | b12b32c707264a335a5ff4d7ab201fe5 MD5 | raw file
  1. import sys
  2. sys.path.append('../..')
  3. import unittest
  4. import numpy as np
  5. import pandas as pd
  6. from pandas import DataFrame, Series, Index
  7. from pandas.util.testing import assert_almost_equal
  8. from ramp.builders import *
  9. from ramp.features.base import F, Map
  10. from ramp.metrics import *
  11. from ramp.model_definition import ModelDefinition
  12. from ramp import modeling
  13. from ramp.reporters import *
  14. from ramp.result import Result
  15. from ramp.tests.test_features import make_data
  16. class TestMetrics(unittest.TestCase):
  17. def setUp(self):
  18. self.data = make_data(100)
  19. self.result = Result(self.data, self.data,
  20. self.data.y, self.data.y,
  21. self.data.y, model_def=None,
  22. fitted_model=None, original_data=self.data)
  23. def test_recall(self):
  24. self.data['y'] = [0]*20 + [1]*80
  25. self.data['preds'] = [0]*10 + [.5]*60 + [1]*30
  26. self.result.y_test = self.data.y
  27. self.result.y_preds = self.data.preds
  28. m = Recall()
  29. thresholds = np.arange(0,1,.1)
  30. expected = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.375, 0.375, 0.375, 0.375]
  31. assert_almost_equal(expected, [m.score(self.result, t) for t in thresholds])
  32. def test_weighted_recall(self):
  33. self.data['y'] = [0]*20 + [1]*80
  34. self.data['weights'] = [0]*50 + [10]*50
  35. self.data['preds'] = [0]*10 + [.5]*60 + [.8]*30
  36. self.result.y_test = self.data.y
  37. self.result.y_preds = self.data.preds
  38. self.result.original_data = self.data
  39. m = WeightedRecall(weight_column='weights')
  40. thresholds = np.arange(0,1,.1)
  41. # [ 0. 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
  42. expected = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, .6, .6, .6, 0]
  43. actuals = [m.score(self.result, t) for t in thresholds]
  44. assert_almost_equal(expected, actuals)
  45. class TestMetricReporter(unittest.TestCase):
  46. def setUp(self):
  47. self.data = make_data(100)
  48. self.result = Result(self.data, self.data,
  49. self.data.y, self.data.y,
  50. self.data.y, model_def=None,
  51. fitted_model=None, original_data=self.data)
  52. def test_metric_reporter(self):
  53. self.data['y'] = [0]*20 + [1]*80
  54. self.data['preds'] = [0]*10 + [.5]*60 + [.8]*30
  55. self.result.y_test = self.data.y
  56. self.result.y_preds = self.data.preds
  57. r = MetricReporter(Recall(.7))
  58. r.update(self.result)
  59. summary = r.summary_df()
  60. n_thresh = 1
  61. self.assertEqual(len(summary), n_thresh)
  62. r.plot()
  63. def test_dual_threshold_reporter(self):
  64. self.data['y'] = [0]*20 + [1]*80
  65. self.data['preds'] = [0]*10 + [.5]*60 + [.8]*30
  66. self.result.y_test = self.data.y
  67. self.result.y_preds = self.data.preds
  68. r = DualThresholdMetricReporter(Recall(), PositiveRate())
  69. r.update(self.result)
  70. summary = r.summary_df()
  71. n_thresh = 3
  72. self.assertEqual(len(summary), n_thresh)
  73. r.plot()
  74. if __name__ == '__main__':
  75. unittest.main()