PageRenderTime 50ms CodeModel.GetById 22ms RepoModel.GetById 1ms app.codeStats 0ms

/sklearn/utils/tests/test_class_weight.py

https://gitlab.com/0072016/0072016
Python | 281 lines | 190 code | 44 blank | 47 comment | 0 complexity | e59718ba9a1deef28278e0b377fb5fc3 MD5 | raw file
  1. import numpy as np
  2. from sklearn.linear_model import LogisticRegression
  3. from sklearn.datasets import make_blobs
  4. from sklearn.utils.class_weight import compute_class_weight
  5. from sklearn.utils.class_weight import compute_sample_weight
  6. from sklearn.utils.testing import assert_array_almost_equal
  7. from sklearn.utils.testing import assert_almost_equal
  8. from sklearn.utils.testing import assert_raises
  9. from sklearn.utils.testing import assert_raise_message
  10. from sklearn.utils.testing import assert_true
  11. from sklearn.utils.testing import assert_equal
  12. from sklearn.utils.testing import assert_warns
  13. def test_compute_class_weight():
  14. # Test (and demo) compute_class_weight.
  15. y = np.asarray([2, 2, 2, 3, 3, 4])
  16. classes = np.unique(y)
  17. cw = assert_warns(DeprecationWarning,
  18. compute_class_weight, "auto", classes, y)
  19. assert_almost_equal(cw.sum(), classes.shape)
  20. assert_true(cw[0] < cw[1] < cw[2])
  21. cw = compute_class_weight("balanced", classes, y)
  22. # total effect of samples is preserved
  23. class_counts = np.bincount(y)[2:]
  24. assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
  25. assert_true(cw[0] < cw[1] < cw[2])
  26. def test_compute_class_weight_not_present():
  27. # Raise error when y does not contain all class labels
  28. classes = np.arange(4)
  29. y = np.asarray([0, 0, 0, 1, 1, 2])
  30. assert_raises(ValueError, compute_class_weight, "auto", classes, y)
  31. assert_raises(ValueError, compute_class_weight, "balanced", classes, y)
  32. # Raise error when y has items not in classes
  33. classes = np.arange(2)
  34. assert_raises(ValueError, compute_class_weight, "auto", classes, y)
  35. assert_raises(ValueError, compute_class_weight, "balanced", classes, y)
  36. assert_raises(ValueError, compute_class_weight, {0: 1., 1: 2.}, classes, y)
  37. def test_compute_class_weight_dict():
  38. classes = np.arange(3)
  39. class_weights = {0: 1.0, 1: 2.0, 2: 3.0}
  40. y = np.asarray([0, 0, 1, 2])
  41. cw = compute_class_weight(class_weights, classes, y)
  42. # When the user specifies class weights, compute_class_weights should just
  43. # return them.
  44. assert_array_almost_equal(np.asarray([1.0, 2.0, 3.0]), cw)
  45. # When a class weight is specified that isn't in classes, a ValueError
  46. # should get raised
  47. msg = 'Class label 4 not present.'
  48. class_weights = {0: 1.0, 1: 2.0, 2: 3.0, 4: 1.5}
  49. assert_raise_message(ValueError, msg, compute_class_weight, class_weights,
  50. classes, y)
  51. msg = 'Class label -1 not present.'
  52. class_weights = {-1: 5.0, 0: 1.0, 1: 2.0, 2: 3.0}
  53. assert_raise_message(ValueError, msg, compute_class_weight, class_weights,
  54. classes, y)
  55. def test_compute_class_weight_invariance():
  56. # Test that results with class_weight="balanced" is invariant wrt
  57. # class imbalance if the number of samples is identical.
  58. # The test uses a balanced two class dataset with 100 datapoints.
  59. # It creates three versions, one where class 1 is duplicated
  60. # resulting in 150 points of class 1 and 50 of class 0,
  61. # one where there are 50 points in class 1 and 150 in class 0,
  62. # and one where there are 100 points of each class (this one is balanced
  63. # again).
  64. # With balancing class weights, all three should give the same model.
  65. X, y = make_blobs(centers=2, random_state=0)
  66. # create dataset where class 1 is duplicated twice
  67. X_1 = np.vstack([X] + [X[y == 1]] * 2)
  68. y_1 = np.hstack([y] + [y[y == 1]] * 2)
  69. # create dataset where class 0 is duplicated twice
  70. X_0 = np.vstack([X] + [X[y == 0]] * 2)
  71. y_0 = np.hstack([y] + [y[y == 0]] * 2)
  72. # duplicate everything
  73. X_ = np.vstack([X] * 2)
  74. y_ = np.hstack([y] * 2)
  75. # results should be identical
  76. logreg1 = LogisticRegression(class_weight="balanced").fit(X_1, y_1)
  77. logreg0 = LogisticRegression(class_weight="balanced").fit(X_0, y_0)
  78. logreg = LogisticRegression(class_weight="balanced").fit(X_, y_)
  79. assert_array_almost_equal(logreg1.coef_, logreg0.coef_)
  80. assert_array_almost_equal(logreg.coef_, logreg0.coef_)
  81. def test_compute_class_weight_auto_negative():
  82. # Test compute_class_weight when labels are negative
  83. # Test with balanced class labels.
  84. classes = np.array([-2, -1, 0])
  85. y = np.asarray([-1, -1, 0, 0, -2, -2])
  86. cw = assert_warns(DeprecationWarning, compute_class_weight, "auto",
  87. classes, y)
  88. assert_almost_equal(cw.sum(), classes.shape)
  89. assert_equal(len(cw), len(classes))
  90. assert_array_almost_equal(cw, np.array([1., 1., 1.]))
  91. cw = compute_class_weight("balanced", classes, y)
  92. assert_equal(len(cw), len(classes))
  93. assert_array_almost_equal(cw, np.array([1., 1., 1.]))
  94. # Test with unbalanced class labels.
  95. y = np.asarray([-1, 0, 0, -2, -2, -2])
  96. cw = assert_warns(DeprecationWarning, compute_class_weight, "auto",
  97. classes, y)
  98. assert_almost_equal(cw.sum(), classes.shape)
  99. assert_equal(len(cw), len(classes))
  100. assert_array_almost_equal(cw, np.array([0.545, 1.636, 0.818]), decimal=3)
  101. cw = compute_class_weight("balanced", classes, y)
  102. assert_equal(len(cw), len(classes))
  103. class_counts = np.bincount(y + 2)
  104. assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
  105. assert_array_almost_equal(cw, [2. / 3, 2., 1.])
  106. def test_compute_class_weight_auto_unordered():
  107. # Test compute_class_weight when classes are unordered
  108. classes = np.array([1, 0, 3])
  109. y = np.asarray([1, 0, 0, 3, 3, 3])
  110. cw = assert_warns(DeprecationWarning, compute_class_weight, "auto",
  111. classes, y)
  112. assert_almost_equal(cw.sum(), classes.shape)
  113. assert_equal(len(cw), len(classes))
  114. assert_array_almost_equal(cw, np.array([1.636, 0.818, 0.545]), decimal=3)
  115. cw = compute_class_weight("balanced", classes, y)
  116. class_counts = np.bincount(y)[classes]
  117. assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
  118. assert_array_almost_equal(cw, [2., 1., 2. / 3])
  119. def test_compute_sample_weight():
  120. # Test (and demo) compute_sample_weight.
  121. # Test with balanced classes
  122. y = np.asarray([1, 1, 1, 2, 2, 2])
  123. sample_weight = assert_warns(DeprecationWarning,
  124. compute_sample_weight, "auto", y)
  125. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  126. sample_weight = compute_sample_weight("balanced", y)
  127. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  128. # Test with user-defined weights
  129. sample_weight = compute_sample_weight({1: 2, 2: 1}, y)
  130. assert_array_almost_equal(sample_weight, [2., 2., 2., 1., 1., 1.])
  131. # Test with column vector of balanced classes
  132. y = np.asarray([[1], [1], [1], [2], [2], [2]])
  133. sample_weight = assert_warns(DeprecationWarning,
  134. compute_sample_weight, "auto", y)
  135. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  136. sample_weight = compute_sample_weight("balanced", y)
  137. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  138. # Test with unbalanced classes
  139. y = np.asarray([1, 1, 1, 2, 2, 2, 3])
  140. sample_weight = assert_warns(DeprecationWarning,
  141. compute_sample_weight, "auto", y)
  142. expected_auto = np.asarray([.6, .6, .6, .6, .6, .6, 1.8])
  143. assert_array_almost_equal(sample_weight, expected_auto)
  144. sample_weight = compute_sample_weight("balanced", y)
  145. expected_balanced = np.array([0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 2.3333])
  146. assert_array_almost_equal(sample_weight, expected_balanced, decimal=4)
  147. # Test with `None` weights
  148. sample_weight = compute_sample_weight(None, y)
  149. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1., 1.])
  150. # Test with multi-output of balanced classes
  151. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  152. sample_weight = assert_warns(DeprecationWarning,
  153. compute_sample_weight, "auto", y)
  154. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  155. sample_weight = compute_sample_weight("balanced", y)
  156. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  157. # Test with multi-output with user-defined weights
  158. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  159. sample_weight = compute_sample_weight([{1: 2, 2: 1}, {0: 1, 1: 2}], y)
  160. assert_array_almost_equal(sample_weight, [2., 2., 2., 2., 2., 2.])
  161. # Test with multi-output of unbalanced classes
  162. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1], [3, -1]])
  163. sample_weight = assert_warns(DeprecationWarning,
  164. compute_sample_weight, "auto", y)
  165. assert_array_almost_equal(sample_weight, expected_auto ** 2)
  166. sample_weight = compute_sample_weight("balanced", y)
  167. assert_array_almost_equal(sample_weight, expected_balanced ** 2, decimal=3)
  168. def test_compute_sample_weight_with_subsample():
  169. # Test compute_sample_weight with subsamples specified.
  170. # Test with balanced classes and all samples present
  171. y = np.asarray([1, 1, 1, 2, 2, 2])
  172. sample_weight = assert_warns(DeprecationWarning,
  173. compute_sample_weight, "auto", y)
  174. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  175. sample_weight = compute_sample_weight("balanced", y, range(6))
  176. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  177. # Test with column vector of balanced classes and all samples present
  178. y = np.asarray([[1], [1], [1], [2], [2], [2]])
  179. sample_weight = assert_warns(DeprecationWarning,
  180. compute_sample_weight, "auto", y)
  181. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  182. sample_weight = compute_sample_weight("balanced", y, range(6))
  183. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1.])
  184. # Test with a subsample
  185. y = np.asarray([1, 1, 1, 2, 2, 2])
  186. sample_weight = assert_warns(DeprecationWarning,
  187. compute_sample_weight, "auto", y, range(4))
  188. assert_array_almost_equal(sample_weight, [.5, .5, .5, 1.5, 1.5, 1.5])
  189. sample_weight = compute_sample_weight("balanced", y, range(4))
  190. assert_array_almost_equal(sample_weight, [2. / 3, 2. / 3,
  191. 2. / 3, 2., 2., 2.])
  192. # Test with a bootstrap subsample
  193. y = np.asarray([1, 1, 1, 2, 2, 2])
  194. sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
  195. "auto", y, [0, 1, 1, 2, 2, 3])
  196. expected_auto = np.asarray([1 / 3., 1 / 3., 1 / 3., 5 / 3., 5 / 3., 5 / 3.])
  197. assert_array_almost_equal(sample_weight, expected_auto)
  198. sample_weight = compute_sample_weight("balanced", y, [0, 1, 1, 2, 2, 3])
  199. expected_balanced = np.asarray([0.6, 0.6, 0.6, 3., 3., 3.])
  200. assert_array_almost_equal(sample_weight, expected_balanced)
  201. # Test with a bootstrap subsample for multi-output
  202. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  203. sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
  204. "auto", y, [0, 1, 1, 2, 2, 3])
  205. assert_array_almost_equal(sample_weight, expected_auto ** 2)
  206. sample_weight = compute_sample_weight("balanced", y, [0, 1, 1, 2, 2, 3])
  207. assert_array_almost_equal(sample_weight, expected_balanced ** 2)
  208. # Test with a missing class
  209. y = np.asarray([1, 1, 1, 2, 2, 2, 3])
  210. sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
  211. "auto", y, range(6))
  212. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1., 0.])
  213. sample_weight = compute_sample_weight("balanced", y, range(6))
  214. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1., 0.])
  215. # Test with a missing class for multi-output
  216. y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1], [2, 2]])
  217. sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
  218. "auto", y, range(6))
  219. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1., 0.])
  220. sample_weight = compute_sample_weight("balanced", y, range(6))
  221. assert_array_almost_equal(sample_weight, [1., 1., 1., 1., 1., 1., 0.])
  222. def test_compute_sample_weight_errors():
  223. # Test compute_sample_weight raises errors expected.
  224. # Invalid preset string
  225. y = np.asarray([1, 1, 1, 2, 2, 2])
  226. y_ = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
  227. assert_raises(ValueError, compute_sample_weight, "ni", y)
  228. assert_raises(ValueError, compute_sample_weight, "ni", y, range(4))
  229. assert_raises(ValueError, compute_sample_weight, "ni", y_)
  230. assert_raises(ValueError, compute_sample_weight, "ni", y_, range(4))
  231. # Not "auto" for subsample
  232. assert_raises(ValueError,
  233. compute_sample_weight, {1: 2, 2: 1}, y, range(4))
  234. # Not a list or preset for multi-output
  235. assert_raises(ValueError, compute_sample_weight, {1: 2, 2: 1}, y_)
  236. # Incorrect length list for multi-output
  237. assert_raises(ValueError, compute_sample_weight, [{1: 2, 2: 1}], y_)