PageRenderTime 22ms CodeModel.GetById 26ms RepoModel.GetById 0ms app.codeStats 0ms

/sklearn/utils/tests/test_estimator_checks.py

http://github.com/scikit-learn/scikit-learn
Python | 633 lines | 443 code | 124 blank | 66 comment | 35 complexity | eb4952e0c17868ddf6704c3990ffe11a MD5 | raw file
Possible License(s): BSD-3-Clause
  1. import unittest
  2. import sys
  3. import numpy as np
  4. import scipy.sparse as sp
  5. import joblib
  6. from io import StringIO
  7. from sklearn.base import BaseEstimator, ClassifierMixin
  8. from sklearn.utils import deprecated
  9. from sklearn.utils._testing import (assert_raises_regex,
  10. ignore_warnings,
  11. assert_warns, assert_raises,
  12. SkipTest)
  13. from sklearn.utils.estimator_checks import check_estimator, _NotAnArray
  14. from sklearn.utils.estimator_checks \
  15. import check_class_weight_balanced_linear_classifier
  16. from sklearn.utils.estimator_checks import set_random_state
  17. from sklearn.utils.estimator_checks import _set_checking_parameters
  18. from sklearn.utils.estimator_checks import check_estimators_unfitted
  19. from sklearn.utils.estimator_checks import check_fit_score_takes_y
  20. from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
  21. from sklearn.utils.estimator_checks import check_classifier_data_not_an_array
  22. from sklearn.utils.estimator_checks import check_regressor_data_not_an_array
  23. from sklearn.utils.validation import check_is_fitted
  24. from sklearn.utils.estimator_checks import check_outlier_corruption
  25. from sklearn.utils.fixes import _parse_version
  26. from sklearn.ensemble import RandomForestClassifier
  27. from sklearn.linear_model import LinearRegression, SGDClassifier
  28. from sklearn.mixture import GaussianMixture
  29. from sklearn.cluster import MiniBatchKMeans
  30. from sklearn.decomposition import NMF
  31. from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
  32. from sklearn.svm import SVC
  33. from sklearn.neighbors import KNeighborsRegressor
  34. from sklearn.tree import DecisionTreeClassifier
  35. from sklearn.utils.validation import check_array
  36. from sklearn.utils import all_estimators
  37. class CorrectNotFittedError(ValueError):
  38. """Exception class to raise if estimator is used before fitting.
  39. Like NotFittedError, it inherits from ValueError, but not from
  40. AttributeError. Used for testing only.
  41. """
  42. class BaseBadClassifier(ClassifierMixin, BaseEstimator):
  43. def fit(self, X, y):
  44. return self
  45. def predict(self, X):
  46. return np.ones(X.shape[0])
  47. class ChangesDict(BaseEstimator):
  48. def __init__(self, key=0):
  49. self.key = key
  50. def fit(self, X, y=None):
  51. X, y = self._validate_data(X, y)
  52. return self
  53. def predict(self, X):
  54. X = check_array(X)
  55. self.key = 1000
  56. return np.ones(X.shape[0])
  57. class SetsWrongAttribute(BaseEstimator):
  58. def __init__(self, acceptable_key=0):
  59. self.acceptable_key = acceptable_key
  60. def fit(self, X, y=None):
  61. self.wrong_attribute = 0
  62. X, y = self._validate_data(X, y)
  63. return self
  64. class ChangesWrongAttribute(BaseEstimator):
  65. def __init__(self, wrong_attribute=0):
  66. self.wrong_attribute = wrong_attribute
  67. def fit(self, X, y=None):
  68. self.wrong_attribute = 1
  69. X, y = self._validate_data(X, y)
  70. return self
  71. class ChangesUnderscoreAttribute(BaseEstimator):
  72. def fit(self, X, y=None):
  73. self._good_attribute = 1
  74. X, y = self._validate_data(X, y)
  75. return self
  76. class RaisesErrorInSetParams(BaseEstimator):
  77. def __init__(self, p=0):
  78. self.p = p
  79. def set_params(self, **kwargs):
  80. if 'p' in kwargs:
  81. p = kwargs.pop('p')
  82. if p < 0:
  83. raise ValueError("p can't be less than 0")
  84. self.p = p
  85. return super().set_params(**kwargs)
  86. def fit(self, X, y=None):
  87. X, y = self._validate_data(X, y)
  88. return self
  89. class ModifiesValueInsteadOfRaisingError(BaseEstimator):
  90. def __init__(self, p=0):
  91. self.p = p
  92. def set_params(self, **kwargs):
  93. if 'p' in kwargs:
  94. p = kwargs.pop('p')
  95. if p < 0:
  96. p = 0
  97. self.p = p
  98. return super().set_params(**kwargs)
  99. def fit(self, X, y=None):
  100. X, y = self._validate_data(X, y)
  101. return self
  102. class ModifiesAnotherValue(BaseEstimator):
  103. def __init__(self, a=0, b='method1'):
  104. self.a = a
  105. self.b = b
  106. def set_params(self, **kwargs):
  107. if 'a' in kwargs:
  108. a = kwargs.pop('a')
  109. self.a = a
  110. if a is None:
  111. kwargs.pop('b')
  112. self.b = 'method2'
  113. return super().set_params(**kwargs)
  114. def fit(self, X, y=None):
  115. X, y = self._validate_data(X, y)
  116. return self
  117. class NoCheckinPredict(BaseBadClassifier):
  118. def fit(self, X, y):
  119. X, y = self._validate_data(X, y)
  120. return self
  121. class NoSparseClassifier(BaseBadClassifier):
  122. def fit(self, X, y):
  123. X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'])
  124. if sp.issparse(X):
  125. raise ValueError("Nonsensical Error")
  126. return self
  127. def predict(self, X):
  128. X = check_array(X)
  129. return np.ones(X.shape[0])
  130. class CorrectNotFittedErrorClassifier(BaseBadClassifier):
  131. def fit(self, X, y):
  132. X, y = self._validate_data(X, y)
  133. self.coef_ = np.ones(X.shape[1])
  134. return self
  135. def predict(self, X):
  136. check_is_fitted(self)
  137. X = check_array(X)
  138. return np.ones(X.shape[0])
  139. class NoSampleWeightPandasSeriesType(BaseEstimator):
  140. def fit(self, X, y, sample_weight=None):
  141. # Convert data
  142. X, y = self._validate_data(
  143. X, y,
  144. accept_sparse=("csr", "csc"),
  145. multi_output=True,
  146. y_numeric=True)
  147. # Function is only called after we verify that pandas is installed
  148. from pandas import Series
  149. if isinstance(sample_weight, Series):
  150. raise ValueError("Estimator does not accept 'sample_weight'"
  151. "of type pandas.Series")
  152. return self
  153. def predict(self, X):
  154. X = check_array(X)
  155. return np.ones(X.shape[0])
  156. class BadBalancedWeightsClassifier(BaseBadClassifier):
  157. def __init__(self, class_weight=None):
  158. self.class_weight = class_weight
  159. def fit(self, X, y):
  160. from sklearn.preprocessing import LabelEncoder
  161. from sklearn.utils import compute_class_weight
  162. label_encoder = LabelEncoder().fit(y)
  163. classes = label_encoder.classes_
  164. class_weight = compute_class_weight(self.class_weight, classes, y)
  165. # Intentionally modify the balanced class_weight
  166. # to simulate a bug and raise an exception
  167. if self.class_weight == "balanced":
  168. class_weight += 1.
  169. # Simply assigning coef_ to the class_weight
  170. self.coef_ = class_weight
  171. return self
  172. class BadTransformerWithoutMixin(BaseEstimator):
  173. def fit(self, X, y=None):
  174. X = self._validate_data(X)
  175. return self
  176. def transform(self, X):
  177. X = check_array(X)
  178. return X
  179. class NotInvariantPredict(BaseEstimator):
  180. def fit(self, X, y):
  181. # Convert data
  182. X, y = self._validate_data(
  183. X, y,
  184. accept_sparse=("csr", "csc"),
  185. multi_output=True,
  186. y_numeric=True)
  187. return self
  188. def predict(self, X):
  189. # return 1 if X has more than one element else return 0
  190. X = check_array(X)
  191. if X.shape[0] > 1:
  192. return np.ones(X.shape[0])
  193. return np.zeros(X.shape[0])
  194. class LargeSparseNotSupportedClassifier(BaseEstimator):
  195. def fit(self, X, y):
  196. X, y = self._validate_data(
  197. X, y,
  198. accept_sparse=("csr", "csc", "coo"),
  199. accept_large_sparse=True,
  200. multi_output=True,
  201. y_numeric=True)
  202. if sp.issparse(X):
  203. if X.getformat() == "coo":
  204. if X.row.dtype == "int64" or X.col.dtype == "int64":
  205. raise ValueError(
  206. "Estimator doesn't support 64-bit indices")
  207. elif X.getformat() in ["csc", "csr"]:
  208. assert "int64" not in (X.indices.dtype, X.indptr.dtype),\
  209. "Estimator doesn't support 64-bit indices"
  210. return self
  211. class SparseTransformer(BaseEstimator):
  212. def fit(self, X, y=None):
  213. self.X_shape_ = self._validate_data(X).shape
  214. return self
  215. def fit_transform(self, X, y=None):
  216. return self.fit(X, y).transform(X)
  217. def transform(self, X):
  218. X = check_array(X)
  219. if X.shape[1] != self.X_shape_[1]:
  220. raise ValueError('Bad number of features')
  221. return sp.csr_matrix(X)
  222. class EstimatorInconsistentForPandas(BaseEstimator):
  223. def fit(self, X, y):
  224. try:
  225. from pandas import DataFrame
  226. if isinstance(X, DataFrame):
  227. self.value_ = X.iloc[0, 0]
  228. else:
  229. X = check_array(X)
  230. self.value_ = X[1, 0]
  231. return self
  232. except ImportError:
  233. X = check_array(X)
  234. self.value_ = X[1, 0]
  235. return self
  236. def predict(self, X):
  237. X = check_array(X)
  238. return np.array([self.value_] * X.shape[0])
  239. class UntaggedBinaryClassifier(DecisionTreeClassifier):
  240. # Toy classifier that only supports binary classification, will fail tests.
  241. def fit(self, X, y, sample_weight=None):
  242. super().fit(X, y, sample_weight)
  243. if np.all(self.n_classes_ > 2):
  244. raise ValueError('Only 2 classes are supported')
  245. return self
  246. class TaggedBinaryClassifier(UntaggedBinaryClassifier):
  247. # Toy classifier that only supports binary classification.
  248. def _more_tags(self):
  249. return {'binary_only': True}
  250. class RequiresPositiveYRegressor(LinearRegression):
  251. def fit(self, X, y):
  252. X, y = self._validate_data(X, y, multi_output=True)
  253. if (y <= 0).any():
  254. raise ValueError('negative y values not supported!')
  255. return super().fit(X, y)
  256. def _more_tags(self):
  257. return {"requires_positive_y": True}
  258. def test_not_an_array_array_function():
  259. np_version = _parse_version(np.__version__)
  260. if np_version < (1, 17):
  261. raise SkipTest("array_function protocol not supported in numpy <1.17")
  262. not_array = _NotAnArray(np.ones(10))
  263. msg = "Don't want to call array_function sum!"
  264. assert_raises_regex(TypeError, msg, np.sum, not_array)
  265. # always returns True
  266. assert np.may_share_memory(not_array, None)
  267. def test_check_fit_score_takes_y_works_on_deprecated_fit():
  268. # Tests that check_fit_score_takes_y works on a class with
  269. # a deprecated fit method
  270. class TestEstimatorWithDeprecatedFitMethod(BaseEstimator):
  271. @deprecated("Deprecated for the purpose of testing "
  272. "check_fit_score_takes_y")
  273. def fit(self, X, y):
  274. return self
  275. check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())
  276. @ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
  277. def test_check_estimator():
  278. # tests that the estimator actually fails on "bad" estimators.
  279. # not a complete test of all checks, which are very extensive.
  280. # check that we have a set_params and can clone
  281. msg = "it does not implement a 'get_params' method"
  282. assert_raises_regex(TypeError, msg, check_estimator, object)
  283. msg = "object has no attribute '_get_tags'"
  284. assert_raises_regex(AttributeError, msg, check_estimator, object())
  285. # check that values returned by get_params match set_params
  286. msg = "get_params result does not match what was passed to set_params"
  287. assert_raises_regex(AssertionError, msg, check_estimator,
  288. ModifiesValueInsteadOfRaisingError())
  289. assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams())
  290. assert_raises_regex(AssertionError, msg, check_estimator,
  291. ModifiesAnotherValue())
  292. # check that we have a fit method
  293. msg = "object has no attribute 'fit'"
  294. assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)
  295. assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator())
  296. # check that fit does input validation
  297. msg = "ValueError not raised"
  298. assert_raises_regex(AssertionError, msg, check_estimator,
  299. BaseBadClassifier)
  300. assert_raises_regex(AssertionError, msg, check_estimator,
  301. BaseBadClassifier())
  302. # check that sample_weights in fit accepts pandas.Series type
  303. try:
  304. from pandas import Series # noqa
  305. msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
  306. "'sample_weight' parameter is of type pandas.Series")
  307. assert_raises_regex(
  308. ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType)
  309. except ImportError:
  310. pass
  311. # check that predict does input validation (doesn't accept dicts in input)
  312. msg = "Estimator doesn't check for NaN and inf in predict"
  313. assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict)
  314. assert_raises_regex(AssertionError, msg, check_estimator,
  315. NoCheckinPredict())
  316. # check that estimator state does not change
  317. # at transform/predict/predict_proba time
  318. msg = 'Estimator changes __dict__ during predict'
  319. assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict)
  320. # check that `fit` only changes attribures that
  321. # are private (start with an _ or end with a _).
  322. msg = ('Estimator ChangesWrongAttribute should not change or mutate '
  323. 'the parameter wrong_attribute from 0 to 1 during fit.')
  324. assert_raises_regex(AssertionError, msg,
  325. check_estimator, ChangesWrongAttribute)
  326. check_estimator(ChangesUnderscoreAttribute)
  327. # check that `fit` doesn't add any public attribute
  328. msg = (r'Estimator adds public attribute\(s\) during the fit method.'
  329. ' Estimators are only allowed to add private attributes'
  330. ' either started with _ or ended'
  331. ' with _ but wrong_attribute added')
  332. assert_raises_regex(AssertionError, msg,
  333. check_estimator, SetsWrongAttribute)
  334. # check for invariant method
  335. name = NotInvariantPredict.__name__
  336. method = 'predict'
  337. msg = ("{method} of {name} is not invariant when applied "
  338. "to a subset.").format(method=method, name=name)
  339. assert_raises_regex(AssertionError, msg,
  340. check_estimator, NotInvariantPredict)
  341. # check for sparse matrix input handling
  342. name = NoSparseClassifier.__name__
  343. msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
  344. # the check for sparse input handling prints to the stdout,
  345. # instead of raising an error, so as not to remove the original traceback.
  346. # that means we need to jump through some hoops to catch it.
  347. old_stdout = sys.stdout
  348. string_buffer = StringIO()
  349. sys.stdout = string_buffer
  350. try:
  351. check_estimator(NoSparseClassifier)
  352. except:
  353. pass
  354. finally:
  355. sys.stdout = old_stdout
  356. assert msg in string_buffer.getvalue()
  357. # Large indices test on bad estimator
  358. msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to '
  359. r'support \S{3}_64 matrix, and is not failing gracefully.*')
  360. assert_raises_regex(AssertionError, msg, check_estimator,
  361. LargeSparseNotSupportedClassifier)
  362. # does error on binary_only untagged estimator
  363. msg = 'Only 2 classes are supported'
  364. assert_raises_regex(ValueError, msg, check_estimator,
  365. UntaggedBinaryClassifier)
  366. # non-regression test for estimators transforming to sparse data
  367. check_estimator(SparseTransformer())
  368. # doesn't error on actual estimator
  369. check_estimator(LogisticRegression)
  370. check_estimator(LogisticRegression(C=0.01))
  371. check_estimator(MultiTaskElasticNet)
  372. check_estimator(MultiTaskElasticNet())
  373. # doesn't error on binary_only tagged estimator
  374. check_estimator(TaggedBinaryClassifier)
  375. # Check regressor with requires_positive_y estimator tag
  376. msg = 'negative y values not supported!'
  377. assert_raises_regex(ValueError, msg, check_estimator,
  378. RequiresPositiveYRegressor)
  379. def test_check_outlier_corruption():
  380. # should raise AssertionError
  381. decision = np.array([0., 1., 1.5, 2.])
  382. assert_raises(AssertionError, check_outlier_corruption, 1, 2, decision)
  383. # should pass
  384. decision = np.array([0., 1., 1., 2.])
  385. check_outlier_corruption(1, 2, decision)
  386. def test_check_estimator_transformer_no_mixin():
  387. # check that TransformerMixin is not required for transformer tests to run
  388. assert_raises_regex(AttributeError, '.*fit_transform.*',
  389. check_estimator, BadTransformerWithoutMixin())
  390. def test_check_estimator_clones():
  391. # check that check_estimator doesn't modify the estimator it receives
  392. from sklearn.datasets import load_iris
  393. iris = load_iris()
  394. for Estimator in [GaussianMixture, LinearRegression,
  395. RandomForestClassifier, NMF, SGDClassifier,
  396. MiniBatchKMeans]:
  397. with ignore_warnings(category=FutureWarning):
  398. # when 'est = SGDClassifier()'
  399. est = Estimator()
  400. _set_checking_parameters(est)
  401. set_random_state(est)
  402. # without fitting
  403. old_hash = joblib.hash(est)
  404. check_estimator(est)
  405. assert old_hash == joblib.hash(est)
  406. with ignore_warnings(category=FutureWarning):
  407. # when 'est = SGDClassifier()'
  408. est = Estimator()
  409. _set_checking_parameters(est)
  410. set_random_state(est)
  411. # with fitting
  412. est.fit(iris.data + 10, iris.target)
  413. old_hash = joblib.hash(est)
  414. check_estimator(est)
  415. assert old_hash == joblib.hash(est)
  416. def test_check_estimators_unfitted():
  417. # check that a ValueError/AttributeError is raised when calling predict
  418. # on an unfitted estimator
  419. msg = "NotFittedError not raised by predict"
  420. assert_raises_regex(AssertionError, msg, check_estimators_unfitted,
  421. "estimator", NoSparseClassifier())
  422. # check that CorrectNotFittedError inherit from either ValueError
  423. # or AttributeError
  424. check_estimators_unfitted("estimator", CorrectNotFittedErrorClassifier())
  425. def test_check_no_attributes_set_in_init():
  426. class NonConformantEstimatorPrivateSet(BaseEstimator):
  427. def __init__(self):
  428. self.you_should_not_set_this_ = None
  429. class NonConformantEstimatorNoParamSet(BaseEstimator):
  430. def __init__(self, you_should_set_this_=None):
  431. pass
  432. assert_raises_regex(AssertionError,
  433. "Estimator estimator_name should not set any"
  434. " attribute apart from parameters during init."
  435. r" Found attributes \['you_should_not_set_this_'\].",
  436. check_no_attributes_set_in_init,
  437. 'estimator_name',
  438. NonConformantEstimatorPrivateSet())
  439. assert_raises_regex(AssertionError,
  440. "Estimator estimator_name should store all "
  441. "parameters as an attribute during init. "
  442. "Did not find attributes "
  443. r"\['you_should_set_this_'\].",
  444. check_no_attributes_set_in_init,
  445. 'estimator_name',
  446. NonConformantEstimatorNoParamSet())
  447. def test_check_estimator_pairwise():
  448. # check that check_estimator() works on estimator with _pairwise
  449. # kernel or metric
  450. # test precomputed kernel
  451. est = SVC(kernel='precomputed')
  452. check_estimator(est)
  453. # test precomputed metric
  454. est = KNeighborsRegressor(metric='precomputed')
  455. check_estimator(est)
  456. def test_check_classifier_data_not_an_array():
  457. assert_raises_regex(AssertionError,
  458. 'Not equal to tolerance',
  459. check_classifier_data_not_an_array,
  460. 'estimator_name',
  461. EstimatorInconsistentForPandas())
  462. def test_check_regressor_data_not_an_array():
  463. assert_raises_regex(AssertionError,
  464. 'Not equal to tolerance',
  465. check_regressor_data_not_an_array,
  466. 'estimator_name',
  467. EstimatorInconsistentForPandas())
  468. @ignore_warnings("Passing a class is depr", category=FutureWarning) # 0.24
  469. def test_check_estimator_required_parameters_skip():
  470. # TODO: remove whole test in 0.24 since passes classes to check_estimator()
  471. # isn't supported anymore
  472. class MyEstimator(BaseEstimator):
  473. _required_parameters = ["special_parameter"]
  474. def __init__(self, special_parameter):
  475. self.special_parameter = special_parameter
  476. assert_raises_regex(SkipTest, r"Can't instantiate estimator MyEstimator "
  477. r"which requires parameters "
  478. r"\['special_parameter'\]",
  479. check_estimator, MyEstimator)
  480. def run_tests_without_pytest():
  481. """Runs the tests in this file without using pytest.
  482. """
  483. main_module = sys.modules['__main__']
  484. test_functions = [getattr(main_module, name) for name in dir(main_module)
  485. if name.startswith('test_')]
  486. test_cases = [unittest.FunctionTestCase(fn) for fn in test_functions]
  487. suite = unittest.TestSuite()
  488. suite.addTests(test_cases)
  489. runner = unittest.TextTestRunner()
  490. runner.run(suite)
  491. def test_check_class_weight_balanced_linear_classifier():
  492. # check that ill-computed balanced weights raises an exception
  493. assert_raises_regex(AssertionError,
  494. "Classifier estimator_name is not computing"
  495. " class_weight=balanced properly.",
  496. check_class_weight_balanced_linear_classifier,
  497. 'estimator_name',
  498. BadBalancedWeightsClassifier)
  499. def test_all_estimators_all_public():
  500. # all_estimator should not fail when pytest is not installed and return
  501. # only public estimators
  502. estimators = all_estimators()
  503. for est in estimators:
  504. assert not est.__class__.__name__.startswith("_")
  505. if __name__ == '__main__':
  506. # This module is run as a script to check that we have no dependency on
  507. # pytest for estimator checks.
  508. run_tests_without_pytest()