/sklearn/utils/tests/test_class_weight.py
Python | 281 lines | 190 code | 44 blank | 47 comment | 0 complexity | e59718ba9a1deef28278e0b377fb5fc3 MD5 | raw file
- import numpy as np
- from sklearn.linear_model import LogisticRegression
- from sklearn.datasets import make_blobs
- from sklearn.utils.class_weight import compute_class_weight
- from sklearn.utils.class_weight import compute_sample_weight
- from sklearn.utils.testing import assert_array_almost_equal
- from sklearn.utils.testing import assert_almost_equal
- from sklearn.utils.testing import assert_raises
- from sklearn.utils.testing import assert_raise_message
- from sklearn.utils.testing import assert_true
- from sklearn.utils.testing import assert_equal
- from sklearn.utils.testing import assert_warns
def test_compute_class_weight():
    # Exercise (and demonstrate) compute_class_weight on a small label
    # vector whose classes have 3, 2 and 1 samples respectively.
    labels = np.asarray([2, 2, 2, 3, 3, 4])
    classes = np.unique(labels)

    # The "auto" preset must emit a DeprecationWarning; rarer classes get
    # strictly larger weights.
    auto_weights = assert_warns(DeprecationWarning, compute_class_weight,
                                "auto", classes, labels)
    assert_almost_equal(auto_weights.sum(), classes.shape)
    assert_true(auto_weights[0] < auto_weights[1] < auto_weights[2])

    balanced_weights = compute_class_weight("balanced", classes, labels)
    # Weighting each class count by its class weight must give back the
    # number of samples, i.e. the total effect of the samples is preserved.
    # The slice drops the bincount slots for the unused labels 0 and 1.
    counts = np.bincount(labels)[2:]
    assert_almost_equal(np.dot(balanced_weights, counts), labels.shape[0])
    assert_true(
        balanced_weights[0] < balanced_weights[1] < balanced_weights[2])
def test_compute_class_weight_not_present():
    # A mismatch between ``classes`` and the labels in y must raise a
    # ValueError, whichever way round the mismatch is.
    y = np.asarray([0, 0, 0, 1, 1, 2])

    # classes lists a label (3) that never occurs in y
    too_many_classes = np.arange(4)
    for preset in ("auto", "balanced"):
        assert_raises(ValueError, compute_class_weight, preset,
                      too_many_classes, y)

    # y contains a label (2) that classes does not list
    too_few_classes = np.arange(2)
    for preset in ("auto", "balanced"):
        assert_raises(ValueError, compute_class_weight, preset,
                      too_few_classes, y)
    assert_raises(ValueError, compute_class_weight, {0: 1., 1: 2.},
                  too_few_classes, y)
def test_compute_class_weight_dict():
    # Check compute_class_weight when the weights are supplied as a dict.
    classes = np.arange(3)
    y = np.asarray([0, 0, 1, 2])

    # Weights specified by the user should simply be handed back.
    user_weights = {0: 1.0, 1: 2.0, 2: 3.0}
    cw = compute_class_weight(user_weights, classes, y)
    assert_array_almost_equal(np.asarray([1.0, 2.0, 3.0]), cw)

    # A weight keyed by a label that is not in classes must raise a
    # ValueError whose message names the offending label.
    invalid_cases = [
        ('Class label 4 not present.',
         {0: 1.0, 1: 2.0, 2: 3.0, 4: 1.5}),
        ('Class label -1 not present.',
         {-1: 5.0, 0: 1.0, 1: 2.0, 2: 3.0}),
    ]
    for msg, invalid_weights in invalid_cases:
        assert_raise_message(ValueError, msg, compute_class_weight,
                             invalid_weights, classes, y)
def test_compute_class_weight_invariance():
    # With class_weight="balanced" the fitted model should not depend on
    # the class imbalance when the total number of samples is the same.
    # Starting from a balanced two-class dataset with 100 datapoints,
    # three 200-sample variants are built:
    #   * class 1 repeated, giving 150 points of class 1 and 50 of class 0;
    #   * class 0 repeated, giving 50 points of class 1 and 150 of class 0;
    #   * everything repeated, giving 100 points of each class.
    # All three must lead to the same coefficients.
    X, y = make_blobs(centers=2, random_state=0)
    ones_mask = y == 1
    zeros_mask = y == 0

    # original data followed by two extra copies of class 1
    X_1 = np.vstack([X, X[ones_mask], X[ones_mask]])
    y_1 = np.hstack([y, y[ones_mask], y[ones_mask]])
    # original data followed by two extra copies of class 0
    X_0 = np.vstack([X, X[zeros_mask], X[zeros_mask]])
    y_0 = np.hstack([y, y[zeros_mask], y[zeros_mask]])
    # the whole dataset twice
    X_ = np.vstack([X, X])
    y_ = np.hstack([y, y])

    # the fitted coefficients should agree across the three variants
    coef_1 = LogisticRegression(class_weight="balanced").fit(X_1, y_1).coef_
    coef_0 = LogisticRegression(class_weight="balanced").fit(X_0, y_0).coef_
    coef_all = LogisticRegression(class_weight="balanced").fit(X_, y_).coef_
    assert_array_almost_equal(coef_1, coef_0)
    assert_array_almost_equal(coef_all, coef_0)
def test_compute_class_weight_auto_negative():
    # Check compute_class_weight when some class labels are negative.
    classes = np.array([-2, -1, 0])
    n_classes = len(classes)

    # Balanced labels: both presets must give every class a weight of 1.
    y = np.asarray([-1, -1, 0, 0, -2, -2])
    all_ones = np.array([1., 1., 1.])
    weights = assert_warns(DeprecationWarning, compute_class_weight,
                           "auto", classes, y)
    assert_almost_equal(weights.sum(), classes.shape)
    assert_equal(len(weights), n_classes)
    assert_array_almost_equal(weights, all_ones)
    weights = compute_class_weight("balanced", classes, y)
    assert_equal(len(weights), n_classes)
    assert_array_almost_equal(weights, all_ones)

    # Unbalanced labels.
    y = np.asarray([-1, 0, 0, -2, -2, -2])
    weights = assert_warns(DeprecationWarning, compute_class_weight,
                           "auto", classes, y)
    assert_almost_equal(weights.sum(), classes.shape)
    assert_equal(len(weights), n_classes)
    assert_array_almost_equal(weights, np.array([0.545, 1.636, 0.818]),
                              decimal=3)
    weights = compute_class_weight("balanced", classes, y)
    assert_equal(len(weights), n_classes)
    # np.bincount needs non-negative integers, hence the shift by 2.
    counts = np.bincount(y + 2)
    assert_almost_equal(np.dot(weights, counts), y.shape[0])
    assert_array_almost_equal(weights, [2. / 3, 2., 1.])
def test_compute_class_weight_auto_unordered():
    # Check compute_class_weight when ``classes`` is not sorted.
    classes = np.array([1, 0, 3])
    y = np.asarray([1, 0, 0, 3, 3, 3])

    # Deprecated "auto" preset: weights follow the order of ``classes``.
    weights = assert_warns(DeprecationWarning, compute_class_weight,
                           "auto", classes, y)
    assert_almost_equal(weights.sum(), classes.shape)
    assert_equal(len(weights), len(classes))
    assert_array_almost_equal(weights, np.array([1.636, 0.818, 0.545]),
                              decimal=3)

    weights = compute_class_weight("balanced", classes, y)
    # Pick the per-label counts in the same order as ``classes``.
    counts = np.bincount(y)[classes]
    assert_almost_equal(np.dot(weights, counts), y.shape[0])
    assert_array_almost_equal(weights, [2., 1., 2. / 3])
def test_compute_sample_weight():
    # Exercise (and demonstrate) compute_sample_weight on 1-D, column
    # vector and multi-output labels.
    six_ones = [1., 1., 1., 1., 1., 1.]

    # Balanced classes
    y = np.asarray([1, 1, 1, 2, 2, 2])
    weights = assert_warns(DeprecationWarning,
                           compute_sample_weight, "auto", y)
    assert_array_almost_equal(weights, six_ones)
    weights = compute_sample_weight("balanced", y)
    assert_array_almost_equal(weights, six_ones)

    # Weights supplied by the user
    weights = compute_sample_weight({1: 2, 2: 1}, y)
    assert_array_almost_equal(weights, [2., 2., 2., 1., 1., 1.])

    # Balanced classes given as a column vector
    y = np.asarray([[1], [1], [1], [2], [2], [2]])
    weights = assert_warns(DeprecationWarning,
                           compute_sample_weight, "auto", y)
    assert_array_almost_equal(weights, six_ones)
    weights = compute_sample_weight("balanced", y)
    assert_array_almost_equal(weights, six_ones)

    # Unbalanced classes
    y = np.asarray([1, 1, 1, 2, 2, 2, 3])
    weights = assert_warns(DeprecationWarning,
                           compute_sample_weight, "auto", y)
    expected_auto = np.asarray([.6] * 6 + [1.8])
    assert_array_almost_equal(weights, expected_auto)
    weights = compute_sample_weight("balanced", y)
    expected_balanced = np.array([0.7777] * 6 + [2.3333])
    assert_array_almost_equal(weights, expected_balanced, decimal=4)

    # ``None`` as class_weight gives every sample a weight of 1
    weights = compute_sample_weight(None, y)
    assert_array_almost_equal(weights, [1., 1., 1., 1., 1., 1., 1.])

    # Multi-output, balanced classes
    y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
    weights = assert_warns(DeprecationWarning,
                           compute_sample_weight, "auto", y)
    assert_array_almost_equal(weights, six_ones)
    weights = compute_sample_weight("balanced", y)
    assert_array_almost_equal(weights, six_ones)

    # Multi-output, one user-defined weight dict per output
    y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
    weights = compute_sample_weight([{1: 2, 2: 1}, {0: 1, 1: 2}], y)
    assert_array_almost_equal(weights, [2., 2., 2., 2., 2., 2.])

    # Multi-output, unbalanced classes: both columns have the same class
    # distribution as the 1-D unbalanced case, and the expected weights
    # are the squares of the 1-D ones.
    y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1],
                    [3, -1]])
    weights = assert_warns(DeprecationWarning,
                           compute_sample_weight, "auto", y)
    assert_array_almost_equal(weights, expected_auto ** 2)
    weights = compute_sample_weight("balanced", y)
    assert_array_almost_equal(weights, expected_balanced ** 2, decimal=3)
def test_compute_sample_weight_with_subsample():
    # Check compute_sample_weight when a subsample is specified through
    # the third (indices) argument.
    #
    # Every call below passes indices explicitly, so the subsample path is
    # exercised for both the deprecated "auto" preset and "balanced". The
    # "auto" calls in the two "all samples present" sections used to omit
    # the indices and therefore only re-tested the no-subsample path.
    six_ones = [1., 1., 1., 1., 1., 1.]
    missing_last = [1., 1., 1., 1., 1., 1., 0.]
    bootstrap = [0, 1, 1, 2, 2, 3]

    # Balanced classes, all samples present
    y = np.asarray([1, 1, 1, 2, 2, 2])
    sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
                                 "auto", y, range(6))
    assert_array_almost_equal(sample_weight, six_ones)
    sample_weight = compute_sample_weight("balanced", y, range(6))
    assert_array_almost_equal(sample_weight, six_ones)

    # Column vector of balanced classes, all samples present
    y = np.asarray([[1], [1], [1], [2], [2], [2]])
    sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
                                 "auto", y, range(6))
    assert_array_almost_equal(sample_weight, six_ones)
    sample_weight = compute_sample_weight("balanced", y, range(6))
    assert_array_almost_equal(sample_weight, six_ones)

    # A proper subsample: three samples of class 1, one of class 2
    y = np.asarray([1, 1, 1, 2, 2, 2])
    sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
                                 "auto", y, range(4))
    assert_array_almost_equal(sample_weight, [.5, .5, .5, 1.5, 1.5, 1.5])
    sample_weight = compute_sample_weight("balanced", y, range(4))
    assert_array_almost_equal(sample_weight,
                              [2. / 3, 2. / 3, 2. / 3, 2., 2., 2.])

    # A bootstrap subsample (indices drawn with repetition)
    y = np.asarray([1, 1, 1, 2, 2, 2])
    sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
                                 "auto", y, bootstrap)
    expected_auto = np.asarray([1 / 3., 1 / 3., 1 / 3.,
                                5 / 3., 5 / 3., 5 / 3.])
    assert_array_almost_equal(sample_weight, expected_auto)
    sample_weight = compute_sample_weight("balanced", y, bootstrap)
    expected_balanced = np.asarray([0.6, 0.6, 0.6, 3., 3., 3.])
    assert_array_almost_equal(sample_weight, expected_balanced)

    # The same bootstrap subsample for multi-output labels: the expected
    # weights are the squares of the single-output ones.
    y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])
    sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
                                 "auto", y, bootstrap)
    assert_array_almost_equal(sample_weight, expected_auto ** 2)
    sample_weight = compute_sample_weight("balanced", y, bootstrap)
    assert_array_almost_equal(sample_weight, expected_balanced ** 2)

    # A class (3) that is missing from the subsample gets weight 0
    y = np.asarray([1, 1, 1, 2, 2, 2, 3])
    sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
                                 "auto", y, range(6))
    assert_array_almost_equal(sample_weight, missing_last)
    sample_weight = compute_sample_weight("balanced", y, range(6))
    assert_array_almost_equal(sample_weight, missing_last)

    # Multi-output: label 2 of the second output is missing from the
    # subsample, so the last sample gets weight 0
    y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1],
                    [2, 2]])
    sample_weight = assert_warns(DeprecationWarning, compute_sample_weight,
                                 "auto", y, range(6))
    assert_array_almost_equal(sample_weight, missing_last)
    sample_weight = compute_sample_weight("balanced", y, range(6))
    assert_array_almost_equal(sample_weight, missing_last)
def test_compute_sample_weight_errors():
    # Check that compute_sample_weight rejects invalid arguments with a
    # ValueError.
    y = np.asarray([1, 1, 1, 2, 2, 2])
    y_ = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1]])

    # Unknown preset string, with and without a subsample, for both
    # single-output and multi-output labels
    for labels in (y, y_):
        assert_raises(ValueError, compute_sample_weight, "ni", labels)
        assert_raises(ValueError, compute_sample_weight, "ni", labels,
                      range(4))

    # A dict class_weight combined with a subsample
    assert_raises(ValueError, compute_sample_weight, {1: 2, 2: 1}, y,
                  range(4))

    # Multi-output labels with a single dict instead of a list or preset
    assert_raises(ValueError, compute_sample_weight, {1: 2, 2: 1}, y_)

    # Multi-output labels with a list of the wrong length
    assert_raises(ValueError, compute_sample_weight, [{1: 2, 2: 1}], y_)