
/sklearn/utils/estimator_checks.py

http://github.com/scikit-learn/scikit-learn
import types
import warnings
import sys
import traceback
import pickle
import re
from copy import deepcopy
from functools import partial
from itertools import chain
from inspect import signature

import numpy as np
from scipy import sparse
from scipy.stats import rankdata
import joblib

from . import IS_PYPY
from .. import config_context
from ._testing import assert_raises, _get_args
from ._testing import assert_raises_regex
from ._testing import assert_raise_message
from ._testing import assert_array_equal
from ._testing import assert_array_almost_equal
from ._testing import assert_allclose
from ._testing import assert_allclose_dense_sparse
from ._testing import assert_warns_message
from ._testing import set_random_state
from ._testing import SkipTest
from ._testing import ignore_warnings
from ._testing import create_memmap_backed_data
from . import is_scalar_nan
from ..discriminant_analysis import LinearDiscriminantAnalysis
from ..linear_model import Ridge
from ..base import (clone, ClusterMixin, is_classifier, is_regressor,
                    RegressorMixin, is_outlier_detector, BaseEstimator)
from ..metrics import accuracy_score, adjusted_rand_score, f1_score
from ..random_projection import BaseRandomProjection
from ..feature_selection import SelectKBest
from ..pipeline import make_pipeline
from ..exceptions import DataConversionWarning
from ..exceptions import NotFittedError
from ..exceptions import SkipTestWarning
from ..model_selection import train_test_split
from ..model_selection import ShuffleSplit
from ..model_selection._validation import _safe_split
from ..metrics.pairwise import (rbf_kernel, linear_kernel, pairwise_distances)

from . import shuffle
from . import deprecated
from .validation import has_fit_parameter, _num_samples
from ..preprocessing import StandardScaler
from ..datasets import (load_iris, load_boston, make_blobs,
                        make_multilabel_classification, make_regression)


BOSTON = None
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']

def _yield_checks(name, estimator):
    tags = estimator._get_tags()
    yield check_no_attributes_set_in_init
    yield check_estimators_dtypes
    yield check_fit_score_takes_y
    yield check_sample_weights_pandas_series
    yield check_sample_weights_not_an_array
    yield check_sample_weights_list
    yield check_sample_weights_shape
    yield check_sample_weights_invariance
    yield check_estimators_fit_returns_self
    yield partial(check_estimators_fit_returns_self, readonly_memmap=True)

    # Check that all estimators yield informative messages when
    # trained on empty datasets
    if not tags["no_validation"]:
        yield check_complex_data
        yield check_dtype_object
        yield check_estimators_empty_data_messages

    if name not in CROSS_DECOMPOSITION:
        # cross-decomposition's "transform" returns X and Y
        yield check_pipeline_consistency

    if not tags["allow_nan"] and not tags["no_validation"]:
        # Test that all estimators check their input for NaN's and infs
        yield check_estimators_nan_inf

    if _is_pairwise(estimator):
        # Check that pairwise estimator throws error on non-square input
        yield check_nonsquare_error

    yield check_estimators_overwrite_params
    if hasattr(estimator, 'sparsify'):
        yield check_sparsify_coefficients

    yield check_estimator_sparse_data

    # Test that estimators can be pickled, and once pickled
    # give the same answer as before.
    yield check_estimators_pickle

def _yield_classifier_checks(name, classifier):
    tags = classifier._get_tags()

    # test classifiers can handle non-array data and pandas objects
    yield check_classifier_data_not_an_array
    # test classifiers trained on a single label always return this label
    yield check_classifiers_one_label
    yield check_classifiers_classes
    yield check_estimators_partial_fit_n_features
    if tags["multioutput"]:
        yield check_classifier_multioutput
    # basic consistency testing
    yield check_classifiers_train
    yield partial(check_classifiers_train, readonly_memmap=True)
    yield partial(check_classifiers_train, readonly_memmap=True,
                  X_dtype='float32')
    yield check_classifiers_regression_target
    if tags["multilabel"]:
        yield check_classifiers_multilabel_representation_invariance
    if not tags["no_validation"]:
        yield check_supervised_y_no_nan
        yield check_supervised_y_2d
    if tags["requires_fit"]:
        yield check_estimators_unfitted
    if 'class_weight' in classifier.get_params().keys():
        yield check_class_weight_classifiers

    yield check_non_transformer_estimators_n_iter
    # test if predict_proba is a monotonic transformation of decision_function
    yield check_decision_proba_consistency

@ignore_warnings(category=FutureWarning)
def check_supervised_y_no_nan(name, estimator_orig):
    # Checks that the Estimator targets are not NaN.
    estimator = clone(estimator_orig)
    rng = np.random.RandomState(888)
    X = rng.randn(10, 5)
    y = np.full(10, np.inf)
    y = _enforce_estimator_tags_y(estimator, y)

    errmsg = "Input contains NaN, infinity or a value too large for " \
             "dtype('float64')."
    try:
        estimator.fit(X, y)
    except ValueError as e:
        if str(e) != errmsg:
            raise ValueError("Estimator {0} raised error as expected, but "
                             "does not match expected error message"
                             .format(name))
    else:
        raise ValueError("Estimator {0} should have raised error on fitting "
                         "array y with NaN value.".format(name))

def _yield_regressor_checks(name, regressor):
    tags = regressor._get_tags()
    # TODO: test with intercept
    # TODO: test with multiple responses
    # basic testing
    yield check_regressors_train
    yield partial(check_regressors_train, readonly_memmap=True)
    yield partial(check_regressors_train, readonly_memmap=True,
                  X_dtype='float32')
    yield check_regressor_data_not_an_array
    yield check_estimators_partial_fit_n_features
    if tags["multioutput"]:
        yield check_regressor_multioutput
    yield check_regressors_no_decision_function
    if not tags["no_validation"]:
        yield check_supervised_y_2d
    yield check_supervised_y_no_nan
    if name != 'CCA':
        # check that the regressor handles int input
        yield check_regressors_int
    if tags["requires_fit"]:
        yield check_estimators_unfitted
    yield check_non_transformer_estimators_n_iter

def _yield_transformer_checks(name, transformer):
    # All transformers should either deal with sparse data or raise an
    # exception with type TypeError and an intelligible error message
    if not transformer._get_tags()["no_validation"]:
        yield check_transformer_data_not_an_array
    # these don't actually fit the data, so don't raise errors
    yield check_transformer_general
    yield partial(check_transformer_general, readonly_memmap=True)
    if not transformer._get_tags()["stateless"]:
        yield check_transformers_unfitted
    # Dependent on external solvers and hence accessing the iter
    # param is non-trivial.
    external_solver = ['Isomap', 'KernelPCA', 'LocallyLinearEmbedding',
                       'RandomizedLasso', 'LogisticRegressionCV']
    if name not in external_solver:
        yield check_transformer_n_iter

def _yield_clustering_checks(name, clusterer):
    yield check_clusterer_compute_labels_predict
    if name not in ('WardAgglomeration', "FeatureAgglomeration"):
        # this is clustering on the features
        # let's not test that here.
        yield check_clustering
        yield partial(check_clustering, readonly_memmap=True)
        yield check_estimators_partial_fit_n_features
    yield check_non_transformer_estimators_n_iter

def _yield_outliers_checks(name, estimator):
    # checks for outlier detectors that have a fit_predict method
    if hasattr(estimator, 'fit_predict'):
        yield check_outliers_fit_predict
    # checks for estimators that can be used on a test set
    if hasattr(estimator, 'predict'):
        yield check_outliers_train
        yield partial(check_outliers_train, readonly_memmap=True)
        # test outlier detectors can handle non-array data
        yield check_classifier_data_not_an_array
        # test if NotFittedError is raised
        if estimator._get_tags()["requires_fit"]:
            yield check_estimators_unfitted

def _yield_all_checks(name, estimator):
    tags = estimator._get_tags()
    if "2darray" not in tags["X_types"]:
        warnings.warn("Can't test estimator {} which requires input "
                      "of type {}".format(name, tags["X_types"]),
                      SkipTestWarning)
        return
    if tags["_skip_test"]:
        warnings.warn("Explicit SKIP via _skip_test tag for estimator "
                      "{}.".format(name),
                      SkipTestWarning)
        return

    for check in _yield_checks(name, estimator):
        yield check
    if is_classifier(estimator):
        for check in _yield_classifier_checks(name, estimator):
            yield check
    if is_regressor(estimator):
        for check in _yield_regressor_checks(name, estimator):
            yield check
    if hasattr(estimator, 'transform'):
        for check in _yield_transformer_checks(name, estimator):
            yield check
    if isinstance(estimator, ClusterMixin):
        for check in _yield_clustering_checks(name, estimator):
            yield check
    if is_outlier_detector(estimator):
        for check in _yield_outliers_checks(name, estimator):
            yield check
    yield check_fit2d_predict1d
    yield check_methods_subset_invariance
    yield check_fit2d_1sample
    yield check_fit2d_1feature
    yield check_fit1d
    yield check_get_params_invariance
    yield check_set_params
    yield check_dict_unchanged
    yield check_dont_overwrite_parameters
    yield check_fit_idempotent
    if not tags["no_validation"]:
        yield check_n_features_in
    if tags["requires_y"]:
        yield check_requires_y_none
    if tags["requires_positive_X"]:
        yield check_fit_non_negative
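
# Illustrative usage (not part of the upstream file): every check yielded
# above is a callable taking (name, estimator). A minimal sketch, assuming
# LogisticRegression is available:
#
#   >>> from sklearn.linear_model import LogisticRegression
#   >>> est = LogisticRegression()
#   >>> for check in _yield_all_checks('LogisticRegression', est):
#   ...     check('LogisticRegression', est)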

def _set_check_estimator_ids(obj):
    """Create pytest ids for checks.

    When `obj` is an estimator, this returns the pprint version of the
    estimator (with `print_changed_only=True`). When `obj` is a function, the
    name of the function is returned with its keyword arguments.

    `_set_check_estimator_ids` is designed to be used as the `id` in
    `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)`
    is yielding estimators and checks.

    Parameters
    ----------
    obj : estimator or function
        Items generated by `check_estimator`

    Returns
    -------
    id : string or None

    See also
    --------
    check_estimator
    """
    if callable(obj):
        if not isinstance(obj, partial):
            return obj.__name__

        if not obj.keywords:
            return obj.func.__name__

        kwstring = ",".join(["{}={}".format(k, v)
                             for k, v in obj.keywords.items()])
        return "{}({})".format(obj.func.__name__, kwstring)
    if hasattr(obj, "get_params"):
        with config_context(print_changed_only=True):
            return re.sub(r"\s", "", str(obj))
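
# For illustration (not in the upstream file), the ids this helper produces
# look like the following, per the branches above:
#
#   >>> _set_check_estimator_ids(check_estimators_dtypes)
#   'check_estimators_dtypes'
#   >>> _set_check_estimator_ids(
#   ...     partial(check_classifiers_train, readonly_memmap=True))
#   'check_classifiers_train(readonly_memmap=True)'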

def _construct_instance(Estimator):
    """Construct Estimator instance if possible"""
    required_parameters = getattr(Estimator, "_required_parameters", [])
    if len(required_parameters):
        if required_parameters in (["estimator"], ["base_estimator"]):
            if issubclass(Estimator, RegressorMixin):
                estimator = Estimator(Ridge())
            else:
                estimator = Estimator(LinearDiscriminantAnalysis())
        else:
            raise SkipTest("Can't instantiate estimator {} which requires "
                           "parameters {}".format(Estimator.__name__,
                                                  required_parameters))
    else:
        estimator = Estimator()
    return estimator

# TODO: probably not needed anymore in 0.24 since _generate_class_checks
# should be removed too. Just put this in check_estimator()
def _generate_instance_checks(name, estimator):
    """Generate instance checks."""
    yield from ((estimator, partial(check, name))
                for check in _yield_all_checks(name, estimator))


# TODO: remove this in 0.24
def _generate_class_checks(Estimator):
    """Generate class checks."""
    name = Estimator.__name__
    yield (Estimator, partial(check_parameters_default_constructible, name))
    estimator = _construct_instance(Estimator)
    yield from _generate_instance_checks(name, estimator)

def _mark_xfail_checks(estimator, check, pytest):
    """Mark (estimator, check) pairs with xfail according to the
    _xfail_checks tag"""
    if isinstance(estimator, type):
        # try to construct estimator instance, if it is unable to then
        # return the estimator class, ignoring the tag
        # TODO: remove this if block in 0.24 since passing instances isn't
        # supported anymore
        try:
            estimator = _construct_instance(estimator)
        except Exception:
            return estimator, check

    xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
    check_name = _set_check_estimator_ids(check)

    if check_name not in xfail_checks:
        # check isn't part of the xfail_checks tags, just return it
        return estimator, check
    else:
        # check is in the tag, mark it as xfail for pytest
        reason = xfail_checks[check_name]
        return pytest.param(estimator, check,
                            marks=pytest.mark.xfail(reason=reason))
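
# Sketch of the tag this helper consumes (illustrative, not upstream code;
# the estimator name is hypothetical). An estimator opts a check out by
# mapping the check id to a reason string in its `_xfail_checks` tag:
#
#   class MyEstimator(BaseEstimator):
#       def _more_tags(self):
#           return {'_xfail_checks': {
#               'check_methods_subset_invariance':
#                   'predictions are not invariant on subsets (known issue)',
#           }}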

def parametrize_with_checks(estimators):
    """Pytest specific decorator for parametrizing estimator checks.

    The `id` of each check is set to be a pprint version of the estimator
    and the name of the check with its keyword arguments.
    This allows using `pytest -k` to specify which tests to run::

        pytest test_check_estimators.py -k check_estimators_fit_returns_self

    Parameters
    ----------
    estimators : list of estimator objects or classes
        Estimators to generate checks for.

        .. deprecated:: 0.23
           Passing a class is deprecated from version 0.23, and won't be
           supported in 0.24. Pass an instance instead.

    Returns
    -------
    decorator : `pytest.mark.parametrize`

    Examples
    --------
    >>> from sklearn.utils.estimator_checks import parametrize_with_checks
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.tree import DecisionTreeRegressor

    >>> @parametrize_with_checks([LogisticRegression(),
    ...                           DecisionTreeRegressor()])
    ... def test_sklearn_compatible_estimator(estimator, check):
    ...     check(estimator)
    """
    import pytest

    if any(isinstance(est, type) for est in estimators):
        # TODO: remove class support in 0.24 and update docstrings
        msg = ("Passing a class is deprecated since version 0.23 "
               "and won't be supported in 0.24. "
               "Please pass an instance instead.")
        warnings.warn(msg, FutureWarning)

    checks_generator = chain.from_iterable(
        check_estimator(estimator, generate_only=True)
        for estimator in estimators)

    checks_with_marks = (
        _mark_xfail_checks(estimator, check, pytest)
        for estimator, check in checks_generator)

    return pytest.mark.parametrize("estimator, check", checks_with_marks,
                                   ids=_set_check_estimator_ids)

def check_estimator(Estimator, generate_only=False):
    """Check if estimator adheres to scikit-learn conventions.

    This function will run an extensive test-suite for input validation,
    shapes, etc, making sure that the estimator complies with `scikit-learn`
    conventions as detailed in :ref:`rolling_your_own_estimator`.

    Additional tests for classifiers, regressors, clustering or transformers
    will be run if the Estimator class inherits from the corresponding mixin
    from sklearn.base.

    This test can be applied to classes or instances.
    Classes currently have some additional tests related to construction,
    while passing instances allows the testing of multiple options. However,
    support for classes is deprecated since version 0.23 and will be removed
    in version 0.24 (class checks will still be run on the instances).

    Setting `generate_only=True` returns a generator that yields (estimator,
    check) tuples where the check can be called independently from each
    other, i.e. `check(estimator)`. This allows all checks to be run
    independently and report the checks that are failing.

    scikit-learn provides a pytest specific decorator,
    :func:`~sklearn.utils.parametrize_with_checks`, making it easier to test
    multiple estimators.

    Parameters
    ----------
    estimator : estimator object
        Estimator to check. Estimator is a class object or instance.

        .. deprecated:: 0.23
           Passing a class is deprecated from version 0.23, and won't be
           supported in 0.24. Pass an instance instead.

    generate_only : bool, optional (default=False)
        When `False`, checks are evaluated when `check_estimator` is called.
        When `True`, `check_estimator` returns a generator that yields
        (estimator, check) tuples. The check is run by calling
        `check(estimator)`.

        .. versionadded:: 0.22

    Returns
    -------
    checks_generator : generator
        Generator that yields (estimator, check) tuples. Returned when
        `generate_only=True`.
    """
    # TODO: remove class support in 0.24 and update docstrings
    if isinstance(Estimator, type):
        # got a class
        msg = ("Passing a class is deprecated since version 0.23 "
               "and won't be supported in 0.24. "
               "Please pass an instance instead.")
        warnings.warn(msg, FutureWarning)
        checks_generator = _generate_class_checks(Estimator)
    else:
        # got an instance
        estimator = Estimator
        name = type(estimator).__name__
        checks_generator = _generate_instance_checks(name, estimator)

    if generate_only:
        return checks_generator

    for estimator, check in checks_generator:
        try:
            check(estimator)
        except SkipTest as exception:
            # the only SkipTest thrown currently results from not
            # being able to import pandas.
            warnings.warn(str(exception), SkipTestWarning)
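
# Illustrative usage (not part of the upstream file), following the
# docstring above:
#
#   >>> from sklearn.linear_model import LogisticRegression
#   >>> check_estimator(LogisticRegression())  # run all checks eagerly
#   >>> for est, check in check_estimator(LogisticRegression(),
#   ...                                   generate_only=True):
#   ...     check(est)  # run and report each check independently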

def _boston_subset(n_samples=200):
    global BOSTON
    if BOSTON is None:
        X, y = load_boston(return_X_y=True)
        X, y = shuffle(X, y, random_state=0)
        X, y = X[:n_samples], y[:n_samples]
        X = StandardScaler().fit_transform(X)
        BOSTON = X, y
    return BOSTON


@deprecated("set_checking_parameters is deprecated in version "
            "0.22 and will be removed in version 0.24.")
def set_checking_parameters(estimator):
    _set_checking_parameters(estimator)

def _set_checking_parameters(estimator):
    # set parameters to speed up some estimators and
    # avoid deprecated behaviour
    params = estimator.get_params()
    name = estimator.__class__.__name__
    if ("n_iter" in params and name != "TSNE"):
        estimator.set_params(n_iter=5)
    if "max_iter" in params:
        if estimator.max_iter is not None:
            estimator.set_params(max_iter=min(5, estimator.max_iter))
        # LinearSVR, LinearSVC
        if estimator.__class__.__name__ in ['LinearSVR', 'LinearSVC']:
            estimator.set_params(max_iter=20)
        # NMF
        if estimator.__class__.__name__ == 'NMF':
            estimator.set_params(max_iter=100)
        # MLP
        if estimator.__class__.__name__ in ['MLPClassifier', 'MLPRegressor']:
            estimator.set_params(max_iter=100)
    if "n_resampling" in params:
        # randomized lasso
        estimator.set_params(n_resampling=5)
    if "n_estimators" in params:
        estimator.set_params(n_estimators=min(5, estimator.n_estimators))
    if "max_trials" in params:
        # RANSAC
        estimator.set_params(max_trials=10)
    if "n_init" in params:
        # K-Means
        estimator.set_params(n_init=2)

    if name == 'TruncatedSVD':
        # TruncatedSVD doesn't run with n_components = n_features
        # This is ugly :-/
        estimator.n_components = 1

    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = min(estimator.n_clusters, 2)

    if hasattr(estimator, "n_best"):
        estimator.n_best = 1

    if name == "SelectFdr":
        # be tolerant of noisy datasets (not actually speed)
        estimator.set_params(alpha=.5)

    if name == "TheilSenRegressor":
        estimator.max_subpopulation = 100

    if isinstance(estimator, BaseRandomProjection):
        # Due to the jl lemma and often very few samples, the number
        # of components of the random matrix projection will probably be
        # greater than the number of features.
        # So we impose a smaller number (avoid "auto" mode)
        estimator.set_params(n_components=2)

    if isinstance(estimator, SelectKBest):
        # SelectKBest has a default of k=10
        # which is more features than we have in most cases.
        estimator.set_params(k=1)

    if name in ('HistGradientBoostingClassifier',
                'HistGradientBoostingRegressor'):
        # The default min_samples_leaf (20) isn't appropriate for the small
        # datasets the checks use (only very shallow trees would be built).
        estimator.set_params(min_samples_leaf=5)

    # Speed-up by reducing the number of CV or splits for CV estimators
    loo_cv = ['RidgeCV']
    if name not in loo_cv and hasattr(estimator, 'cv'):
        estimator.set_params(cv=3)
    if hasattr(estimator, 'n_splits'):
        estimator.set_params(n_splits=3)

    if name == 'OneHotEncoder':
        estimator.set_params(handle_unknown='ignore')

class _NotAnArray:
    """An object that is convertible to an array

    Parameters
    ----------
    data : array_like
        The data.
    """

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None):
        return self.data

    def __array_function__(self, func, types, args, kwargs):
        if func.__name__ == "may_share_memory":
            return True
        raise TypeError("Don't want to call array_function {}!".format(
            func.__name__))


@deprecated("NotAnArray is deprecated in version "
            "0.22 and will be removed in version 0.24.")
class NotAnArray(_NotAnArray):
    # TODO: remove in 0.24
    pass
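
# Illustrative (not part of the upstream file): _NotAnArray exercises code
# paths that must fall back to np.asarray instead of assuming an ndarray:
#
#   >>> arr = _NotAnArray([[1, 2], [3, 4]])
#   >>> np.asarray(arr).shape
#   (2, 2)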

def _is_pairwise(estimator):
    """Returns True if estimator has a _pairwise attribute set to True.

    Parameters
    ----------
    estimator : object
        Estimator object to test.

    Returns
    -------
    out : bool
        True if _pairwise is set to True and False otherwise.
    """
    return bool(getattr(estimator, "_pairwise", False))


def _is_pairwise_metric(estimator):
    """Returns True if estimator accepts pairwise metric.

    Parameters
    ----------
    estimator : object
        Estimator object to test.

    Returns
    -------
    out : bool
        True if `metric` is set to 'precomputed' and False otherwise.
    """
    metric = getattr(estimator, "metric", None)
    return bool(metric == 'precomputed')

@deprecated("pairwise_estimator_convert_X is deprecated in version "
            "0.22 and will be removed in version 0.24.")
def pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel):
    return _pairwise_estimator_convert_X(X, estimator, kernel)


def _pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel):
    if _is_pairwise_metric(estimator):
        return pairwise_distances(X, metric='euclidean')
    if _is_pairwise(estimator):
        return kernel(X, X)
    return X
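
# For illustration (not upstream): for a pairwise estimator, an (n_samples,
# n_features) X is replaced by an (n_samples, n_samples) matrix, by default
# the linear kernel X @ X.T; non-pairwise estimators get X back unchanged:
#
#   >>> X = np.random.RandomState(0).rand(5, 3)
#   >>> _pairwise_estimator_convert_X(X, Ridge()).shape  # not pairwise
#   (5, 3)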

def _generate_sparse_matrix(X_csr):
    """Generate sparse matrices with {32,64}bit indices of diverse format

    Parameters
    ----------
    X_csr: CSR Matrix
        Input matrix in CSR format

    Returns
    -------
    out: iter(Matrices)
        In format['dok', 'lil', 'dia', 'bsr', 'csr', 'csc', 'coo',
        'coo_64', 'csc_64', 'csr_64']
    """

    assert X_csr.format == 'csr'
    yield 'csr', X_csr.copy()
    for sparse_format in ['dok', 'lil', 'dia', 'bsr', 'csc', 'coo']:
        yield sparse_format, X_csr.asformat(sparse_format)

    # Generate large indices matrix only if it's supported by scipy
    X_coo = X_csr.asformat('coo')
    X_coo.row = X_coo.row.astype('int64')
    X_coo.col = X_coo.col.astype('int64')
    yield "coo_64", X_coo

    for sparse_format in ['csc', 'csr']:
        X = X_csr.asformat(sparse_format)
        X.indices = X.indices.astype('int64')
        X.indptr = X.indptr.astype('int64')
        yield sparse_format + "_64", X
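
# Illustrative (not upstream): per the yields above, the generator emits
# ten (format, matrix) pairs for each input:
#
#   >>> X_csr = sparse.csr_matrix(np.eye(3))
#   >>> [fmt for fmt, _ in _generate_sparse_matrix(X_csr)]
#   ['csr', 'dok', 'lil', 'dia', 'bsr', 'csc', 'coo',
#    'coo_64', 'csc_64', 'csr_64']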

def check_estimator_sparse_data(name, estimator_orig):
    rng = np.random.RandomState(0)
    X = rng.rand(40, 10)
    X[X < .8] = 0
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    X_csr = sparse.csr_matrix(X)
    tags = estimator_orig._get_tags()
    if tags['binary_only']:
        y = (2 * rng.rand(40)).astype(np.int)
    else:
        y = (4 * rng.rand(40)).astype(np.int)
    # catch deprecation warnings
    with ignore_warnings(category=FutureWarning):
        estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)
    for matrix_format, X in _generate_sparse_matrix(X_csr):
        # catch deprecation warnings
        with ignore_warnings(category=FutureWarning):
            estimator = clone(estimator_orig)
            if name in ['Scaler', 'StandardScaler']:
                estimator.set_params(with_mean=False)
        # fit and predict
        try:
            with ignore_warnings(category=FutureWarning):
                estimator.fit(X, y)
            if hasattr(estimator, "predict"):
                pred = estimator.predict(X)
                if tags['multioutput_only']:
                    assert pred.shape == (X.shape[0], 1)
                else:
                    assert pred.shape == (X.shape[0],)
            if hasattr(estimator, 'predict_proba'):
                probs = estimator.predict_proba(X)
                if tags['binary_only']:
                    expected_probs_shape = (X.shape[0], 2)
                else:
                    expected_probs_shape = (X.shape[0], 4)
                assert probs.shape == expected_probs_shape
        except (TypeError, ValueError) as e:
            if 'sparse' not in repr(e).lower():
                if "64" in matrix_format:
                    msg = ("Estimator %s doesn't seem to support %s matrix, "
                           "and is not failing gracefully, e.g. by using "
                           "check_array(X, accept_large_sparse=False)")
                    raise AssertionError(msg % (name, matrix_format))
                else:
                    print("Estimator %s doesn't seem to fail gracefully on "
                          "sparse data: the error message should state "
                          "explicitly that sparse input is not supported "
                          "if this is not the case." % name)
                    raise
        except Exception:
            print("Estimator %s doesn't seem to fail gracefully on "
                  "sparse data: it should raise a TypeError if sparse input "
                  "is explicitly not supported." % name)
            raise

@ignore_warnings(category=FutureWarning)
def check_sample_weights_pandas_series(name, estimator_orig):
    # check that estimators will accept a 'sample_weight' parameter of
    # type pandas.Series in the 'fit' function.
    estimator = clone(estimator_orig)
    if has_fit_parameter(estimator, "sample_weight"):
        try:
            import pandas as pd
            X = np.array([[1, 1], [1, 2], [1, 3], [1, 4],
                          [2, 1], [2, 2], [2, 3], [2, 4],
                          [3, 1], [3, 2], [3, 3], [3, 4]])
            X = pd.DataFrame(_pairwise_estimator_convert_X(X, estimator_orig))
            y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
            weights = pd.Series([1] * 12)
            if estimator._get_tags()["multioutput_only"]:
                y = pd.DataFrame(y)
            try:
                estimator.fit(X, y, sample_weight=weights)
            except ValueError:
                raise ValueError("Estimator {0} raises error if "
                                 "'sample_weight' parameter is of "
                                 "type pandas.Series".format(name))
        except ImportError:
            raise SkipTest("pandas is not installed: not testing for "
                           "input of type pandas.Series to class weight.")

@ignore_warnings(category=(FutureWarning))
def check_sample_weights_not_an_array(name, estimator_orig):
    # check that estimators will accept a 'sample_weight' parameter of
    # type _NotAnArray in the 'fit' function.
    estimator = clone(estimator_orig)
    if has_fit_parameter(estimator, "sample_weight"):
        X = np.array([[1, 1], [1, 2], [1, 3], [1, 4],
                      [2, 1], [2, 2], [2, 3], [2, 4],
                      [3, 1], [3, 2], [3, 3], [3, 4]])
        X = _NotAnArray(_pairwise_estimator_convert_X(X, estimator_orig))
        y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
        weights = _NotAnArray([1] * 12)
        if estimator._get_tags()["multioutput_only"]:
            y = _NotAnArray(y.data.reshape(-1, 1))
        estimator.fit(X, y, sample_weight=weights)

@ignore_warnings(category=(FutureWarning))
def check_sample_weights_list(name, estimator_orig):
    # check that estimators will accept a 'sample_weight' parameter of
    # type list in the 'fit' function.
    if has_fit_parameter(estimator_orig, "sample_weight"):
        estimator = clone(estimator_orig)
        rnd = np.random.RandomState(0)
        n_samples = 30
        X = _pairwise_estimator_convert_X(rnd.uniform(size=(n_samples, 3)),
                                          estimator_orig)
        if estimator._get_tags()['binary_only']:
            y = np.arange(n_samples) % 2
        else:
            y = np.arange(n_samples) % 3
        y = _enforce_estimator_tags_y(estimator, y)
        sample_weight = [3] * n_samples
        # Test that estimators don't raise any exception
        estimator.fit(X, y, sample_weight=sample_weight)

@ignore_warnings(category=FutureWarning)
def check_sample_weights_shape(name, estimator_orig):
    # check that estimators raise an error if sample_weight
    # shape mismatches the input
    if (has_fit_parameter(estimator_orig, "sample_weight") and
            not (hasattr(estimator_orig, "_pairwise")
                 and estimator_orig._pairwise)):
        estimator = clone(estimator_orig)
        X = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
                      [2, 1], [2, 1], [2, 1], [2, 1],
                      [3, 3], [3, 3], [3, 3], [3, 3],
                      [4, 1], [4, 1], [4, 1], [4, 1]])
        y = np.array([1, 1, 1, 1, 2, 2, 2, 2,
                      1, 1, 1, 1, 2, 2, 2, 2])
        y = _enforce_estimator_tags_y(estimator, y)

        estimator.fit(X, y, sample_weight=np.ones(len(y)))

        assert_raises(ValueError, estimator.fit, X, y,
                      sample_weight=np.ones(2 * len(y)))

        assert_raises(ValueError, estimator.fit, X, y,
                      sample_weight=np.ones((len(y), 2)))

@ignore_warnings(category=FutureWarning)
def check_sample_weights_invariance(name, estimator_orig):
    # check that the estimators yield same results for
    # unit weights and no weights
    if (has_fit_parameter(estimator_orig, "sample_weight") and
            not (hasattr(estimator_orig, "_pairwise")
                 and estimator_orig._pairwise)):
        # We skip pairwise because the data is not pairwise

        estimator1 = clone(estimator_orig)
        estimator2 = clone(estimator_orig)
        set_random_state(estimator1, random_state=0)
        set_random_state(estimator2, random_state=0)

        X = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
                      [2, 1], [2, 1], [2, 1], [2, 1],
                      [3, 3], [3, 3], [3, 3], [3, 3],
                      [4, 1], [4, 1], [4, 1], [4, 1]],
                     dtype=np.dtype('float'))
        y = np.array([1, 1, 1, 1, 2, 2, 2, 2,
                      1, 1, 1, 1, 2, 2, 2, 2], dtype=np.dtype('int'))
        y = _enforce_estimator_tags_y(estimator1, y)

        estimator1.fit(X, y=y, sample_weight=np.ones(shape=len(y)))
        estimator2.fit(X, y=y, sample_weight=None)

        for method in ["predict", "transform"]:
            if hasattr(estimator_orig, method):
                X_pred1 = getattr(estimator1, method)(X)
                X_pred2 = getattr(estimator2, method)(X)
                if sparse.issparse(X_pred1):
                    X_pred1 = X_pred1.toarray()
                    X_pred2 = X_pred2.toarray()
                assert_allclose(X_pred1, X_pred2,
                                err_msg="For %s sample_weight=None is not"
                                        " equivalent to sample_weight=ones"
                                        % name)

@ignore_warnings(category=(FutureWarning, UserWarning))
def check_dtype_object(name, estimator_orig):
    # check that estimators treat dtype object as numeric if possible
    rng = np.random.RandomState(0)
    X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig)
    X = X.astype(object)
    tags = estimator_orig._get_tags()
    if tags['binary_only']:
        y = (X[:, 0] * 2).astype(np.int)
    else:
        y = (X[:, 0] * 4).astype(np.int)
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)

    estimator.fit(X, y)
    if hasattr(estimator, "predict"):
        estimator.predict(X)

    if hasattr(estimator, "transform"):
        estimator.transform(X)

    try:
        estimator.fit(X, y.astype(object))
    except Exception as e:
        if "Unknown label type" not in str(e):
            raise

    if 'string' not in tags['X_types']:
        X[0, 0] = {'foo': 'bar'}
        msg = "argument must be a string.* number"
        assert_raises_regex(TypeError, msg, estimator.fit, X, y)
    else:
        # Estimators supporting string will not call np.asarray to convert the
        # data to numeric and therefore, the error will not be raised.
        # Checking for each element dtype in the input array will be costly.
        # Refer to #11401 for full discussion.
        estimator.fit(X, y)

def check_complex_data(name, estimator_orig):
    # check that estimators raise an exception on providing complex data
    X = np.random.sample(10) + 1j * np.random.sample(10)
    X = X.reshape(-1, 1)
    y = np.random.sample(10) + 1j * np.random.sample(10)
    estimator = clone(estimator_orig)
    assert_raises_regex(ValueError, "Complex data not supported",
                        estimator.fit, X, y)

@ignore_warnings
def check_dict_unchanged(name, estimator_orig):
    # SpectralCoclustering raises
    # "ValueError: Found array with 0 feature(s) (shape=(23, 0))
    # while a minimum of 1 is required."
    if name in ['SpectralCoclustering']:
        return
    rnd = np.random.RandomState(0)
    if name in ['RANSACRegressor']:
        X = 3 * rnd.uniform(size=(20, 3))
    else:
        X = 2 * rnd.uniform(size=(20, 3))
    X = _pairwise_estimator_convert_X(X, estimator_orig)

    y = X[:, 0].astype(np.int)
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)
    if hasattr(estimator, "n_components"):
        estimator.n_components = 1

    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1

    if hasattr(estimator, "n_best"):
        estimator.n_best = 1

    set_random_state(estimator, 1)

    estimator.fit(X, y)
    for method in ["predict", "transform", "decision_function",
                   "predict_proba"]:
        if hasattr(estimator, method):
            dict_before = estimator.__dict__.copy()
            getattr(estimator, method)(X)
            assert estimator.__dict__ == dict_before, (
                'Estimator changes __dict__ during %s' % method)

@deprecated("is_public_parameter is deprecated in version "
            "0.22 and will be removed in version 0.24.")
def is_public_parameter(attr):
    return _is_public_parameter(attr)


def _is_public_parameter(attr):
    return not (attr.startswith('_') or attr.endswith('_'))
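
# Illustrative (not upstream): attributes that start or end with an
# underscore are considered private here:
#
#   >>> _is_public_parameter('alpha')
#   True
#   >>> _is_public_parameter('coef_')
#   False
#   >>> _is_public_parameter('_cache')
#   False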

@ignore_warnings(category=FutureWarning)
def check_dont_overwrite_parameters(name, estimator_orig):
    # check that fit method only changes or sets private attributes
    if hasattr(estimator_orig.__init__, "deprecated_original"):
        # to not check deprecated classes
        return
    estimator = clone(estimator_orig)
    rnd = np.random.RandomState(0)
    X = 3 * rnd.uniform(size=(20, 3))
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    y = X[:, 0].astype(np.int)
    if estimator._get_tags()['binary_only']:
        y[y == 2] = 1
    y = _enforce_estimator_tags_y(estimator, y)

    if hasattr(estimator, "n_components"):
        estimator.n_components = 1
    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1

    set_random_state(estimator, 1)
    dict_before_fit = estimator.__dict__.copy()
    estimator.fit(X, y)

    dict_after_fit = estimator.__dict__

    public_keys_after_fit = [key for key in dict_after_fit.keys()
                             if _is_public_parameter(key)]

    attrs_added_by_fit = [key for key in public_keys_after_fit
                          if key not in dict_before_fit.keys()]

    # check that fit doesn't add any public attribute
    assert not attrs_added_by_fit, (
        'Estimator adds public attribute(s) during'
        ' the fit method.'
        ' Estimators are only allowed to add private attributes'
        ' that either start with _ or end'
        ' with _, but %s was added'
        % ', '.join(attrs_added_by_fit))

    # check that fit doesn't change any public attribute
    attrs_changed_by_fit = [key for key in public_keys_after_fit
                            if (dict_before_fit[key]
                                is not dict_after_fit[key])]

    assert not attrs_changed_by_fit, (
        'Estimator changes public attribute(s) during'
        ' the fit method. Estimators are only allowed'
        ' to change attributes that start'
        ' or end with _, but'
        ' %s changed'
        % ', '.join(attrs_changed_by_fit))

@ignore_warnings(category=FutureWarning)
def check_fit2d_predict1d(name, estimator_orig):
    # check by fitting a 2d array and predicting with a 1d array
    rnd = np.random.RandomState(0)
    X = 3 * rnd.uniform(size=(20, 3))
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    y = X[:, 0].astype(np.int)
    tags = estimator_orig._get_tags()
    if tags['binary_only']:
        y[y == 2] = 1
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)

    if hasattr(estimator, "n_components"):
        estimator.n_components = 1
    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1

    set_random_state(estimator, 1)
    estimator.fit(X, y)

    if tags["no_validation"]:
        # FIXME this is a bit loose
        return

    for method in ["predict", "transform", "decision_function",
                   "predict_proba"]:
        if hasattr(estimator, method):
            assert_raise_message(ValueError, "Reshape your data",
                                 getattr(estimator, method), X[0])

def _apply_on_subsets(func, X):
    # apply function on the whole set and on mini batches
    result_full = func(X)
    n_features = X.shape[1]
    result_by_batch = [func(batch.reshape(1, n_features))
                       for batch in X]

    # func can output tuple (e.g. score_samples)
    if type(result_full) == tuple:
        result_full = result_full[0]
        result_by_batch = list(map(lambda x: x[0], result_by_batch))

    if sparse.issparse(result_full):
        result_full = result_full.A
        result_by_batch = [x.A for x in result_by_batch]

    return np.ravel(result_full), np.ravel(result_by_batch)
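
# Illustrative (not upstream): the helper compares a function applied to the
# full set against the same function applied row by row:
#
#   >>> full, batched = _apply_on_subsets(lambda X: X.sum(axis=1),
#   ...                                   np.arange(6.).reshape(3, 2))
#   >>> np.allclose(full, batched)
#   True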

@ignore_warnings(category=FutureWarning)
def check_methods_subset_invariance(name, estimator_orig):
    # check that method gives invariant results if applied
    # on mini batches or the whole set
    rnd = np.random.RandomState(0)
    X = 3 * rnd.uniform(size=(20, 3))
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    y = X[:, 0].astype(np.int)
    if estimator_orig._get_tags()['binary_only']:
        y[y == 2] = 1
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)

    if hasattr(estimator, "n_components"):
        estimator.n_components = 1
    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1

    set_random_state(estimator, 1)
    estimator.fit(X, y)

    for method in ["predict", "transform", "decision_function",
                   "score_samples", "predict_proba"]:

        msg = ("{method} of {name} is not invariant when applied "
               "to a subset.").format(method=method, name=name)

        if hasattr(estimator, method):
            result_full, result_by_batch = _apply_on_subsets(
                getattr(estimator, method), X)
            assert_allclose(result_full, result_by_batch,
                            atol=1e-7, err_msg=msg)

@ignore_warnings
def check_fit2d_1sample(name, estimator_orig):
    # Check that fitting a 2d array with only one sample either works or
    # returns an informative message. The error message should either mention
    # the number of samples or the number of classes.
    rnd = np.random.RandomState(0)
    X = 3 * rnd.uniform(size=(1, 10))
    X = _pairwise_estimator_convert_X(X, estimator_orig)

    y = X[:, 0].astype(np.int)
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)

    if hasattr(estimator, "n_components"):
        estimator.n_components = 1
    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1

    set_random_state(estimator, 1)

    # min_cluster_size cannot be less than the data size for OPTICS.
    if name == 'OPTICS':
        estimator.set_params(min_samples=1)

    msgs = ["1 sample", "n_samples = 1", "n_samples=1", "one sample",
            "1 class", "one class"]

    try:
        estimator.fit(X, y)
    except ValueError as e:
        if all(msg not in repr(e) for msg in msgs):
            raise e

@ignore_warnings
def check_fit2d_1feature(name, estimator_orig):
    # check fitting a 2d array with only 1 feature either works or returns
    # informative message
    rnd = np.random.RandomState(0)
    X = 3 * rnd.uniform(size=(10, 1))
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    y = X[:, 0].astype(np.int)
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)

    if hasattr(estimator, "n_components"):
        estimator.n_components = 1
    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1
    # ensure two labels in subsample for RandomizedLogisticRegression
    if name == 'RandomizedLogisticRegression':
        estimator.sample_fraction = 1
    # ensure non skipped trials for RANSACRegressor
    if name == 'RANSACRegressor':
        estimator.residual_threshold = 0.5

    y = _enforce_estimator_tags_y(estimator, y)
    set_random_state(estimator, 1)

    msgs = ["1 feature(s)", "n_features = 1", "n_features=1"]

    try:
        estimator.fit(X, y)
    except ValueError as e:
        if all(msg not in repr(e) for msg in msgs):
            raise e

@ignore_warnings
def check_fit1d(name, estimator_orig):
    # check fitting 1d X array raises a ValueError
    rnd = np.random.RandomState(0)
    X = 3 * rnd.uniform(size=(20))
    y = X.astype(np.int)
    estimator = clone(estimator_orig)
    tags = estimator._get_tags()
    if tags["no_validation"]:
        # FIXME this is a bit loose
        return

    y = _enforce_estimator_tags_y(estimator, y)

    if hasattr(estimator, "n_components"):
        estimator.n_components = 1
    if hasattr(estimator, "n_clusters"):
        estimator.n_clusters = 1

    set_random_state(estimator, 1)
    assert_raises(ValueError, estimator.fit, X, y)

@ignore_warnings(category=FutureWarning)
def check_transformer_general(name, transformer, readonly_memmap=False):
    X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
                      random_state=0, n_features=2, cluster_std=0.1)
    X = StandardScaler().fit_transform(X)
    X -= X.min()
    X = _pairwise_estimator_convert_X(X, transformer)

    if readonly_memmap:
        X, y = create_memmap_backed_data([X, y])

    _check_transformer(name, transformer, X, y)

@ignore_warnings(category=FutureWarning)
def check_transformer_data_not_an_array(name, transformer):
    X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
                      random_state=0, n_features=2, cluster_std=0.1)
    X = StandardScaler().fit_transform(X)
    # We need to make sure that we have non negative data, for things
    # like NMF
    X -= X.min() - .1
    X = _pairwise_estimator_convert_X(X, transformer)
    this_X = _NotAnArray(X)
    this_y = _NotAnArray(np.asarray(y))
    _check_transformer(name, transformer, this_X, this_y)
    # try the same with some list
    _check_transformer(name, transformer, X.tolist(), y.tolist())

@ignore_warnings(category=FutureWarning)
def check_transformers_unfitted(name, transformer):
    X, y = _boston_subset()

    transformer = clone(transformer)
    with assert_raises((AttributeError, ValueError), msg="The unfitted "
                       "transformer {} does not raise an error when "
                       "transform is called. Perhaps use "
                       "check_is_fitted in transform.".format(name)):
        transformer.transform(X)

def _check_transformer(name, transformer_orig, X, y):
    n_samples, n_features = np.asarray(X).shape
    transformer = clone(transformer_orig)
    set_random_state(transformer)

    # fit
    if name in CROSS_DECOMPOSITION:
        y_ = np.c_[np.asarray(y), np.asarray(y)]
        y_[::2, 1] *= 2
        if isinstance(X, _NotAnArray):
            y_ = _NotAnArray(y_)
    else:
        y_ = y

    transformer.fit(X, y_)
    # fit_transform method should work on non fitted estimator
    transformer_clone = clone(transformer)
    X_pred = transformer_clone.fit_transform(X, y=y_)

    if isinstance(X_pred, tuple):
        for x_pred in X_pred:
            assert x_pred.shape[0] == n_samples
    else:
        # check for consistent n_samples
        assert X_pred.shape[0] == n_samples

    if hasattr(transformer, 'transform'):
        if name in CROSS_DECOMPOSITION:
            X_pred2 = transformer.transform(X, y_)
            X_pred3 = transformer.fit_transform(X, y=y_)
        else:
            X_pred2 = transformer.transform(X)
            X_pred3 = transformer.fit_transform(X, y=y_)

        if transformer_orig._get_tags()['non_deterministic']:
            msg = name + ' is non deterministic'
            raise SkipTest(msg)
        if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple):
            for x_pred, x_pred2, x_pred3 in zip(X_pred, X_pred2, X_pred3):
                assert_allclose_dense_sparse(
                    x_pred, x_pred2, atol=1e-2,
                    err_msg="fit_transform and transform outcomes "
                            "not consistent in %s"
                    % transformer)
                assert_allclose_dense_sparse(
                    x_pred, x_pred3, atol=1e-2,
                    err_msg="consecutive fit_transform outcomes "
                            "not consistent in %s"
                    % transformer)
        else:
            assert_allclose_dense_sparse(
                X_pred, X_pred2,
                err_msg="fit_transform and transform outcomes "
                        "not consistent in %s"
                % transformer, atol=1e-2)
            assert_allclose_dense_sparse(
                X_pred, X_pred3, atol=1e-2,
                err_msg="consecutive fit_transform outcomes "
                        "not consistent in %s"
                % transformer)
            assert _num_samples(X_pred2) == n_samples
            assert _num_samples(X_pred3) == n_samples

        # raises error on malformed input for transform
        if hasattr(X, 'shape') and \
           not transformer._get_tags()["stateless"] and \
           X.ndim == 2 and X.shape[1] > 1:

            # If it's not an array, it does not have a 'T' property
            with assert_raises(ValueError, msg="The transformer {} does "
                               "not raise an error when the number of "
                               "features in transform is different from"
                               " the number of features in "
                               "fit.".format(name)):
                transformer.transform(X[:, :-1])

@ignore_warnings
def check_pipeline_consistency(name, estimator_orig):
    if estimator_orig._get_tags()['non_deterministic']:
        msg = name + ' is non deterministic'
        raise SkipTest(msg)

    # check that make_pipeline(est) gives same score as est
    X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
                      random_state=0, n_features=2, cluster_std=0.1)
    X -= X.min()
    X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)
    set_random_state(estimator)
    pipeline = make_pipeline(estimator)
    estimator.fit(X, y)
    pipeline.fit(X, y)

    funcs = ["score", "fit_transform"]

    for func_name in funcs:
        func = getattr(estimator, func_name, None)
        if func is not None:
            func_pipeline = getattr(pipeline, func_name)
            result = func(X, y)
            result_pipe = func_pipeline(X, y)
            assert_allclose_dense_sparse(result, result_pipe)

@ignore_warnings
def check_fit_score_takes_y(name, estimator_orig):
    # check that all estimators accept an optional y
    # in fit and score so they can be used in pipelines
    rnd = np.random.RandomState(0)
    n_samples = 30
    X = rnd.uniform(size=(n_samples, 3))
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    if estimator_orig._get_tags()['binary_only']:
        y = np.arange(n_samples) % 2
    else:
        y = np.arange(n_samples) % 3
    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)
    set_random_state(estimator)

    funcs = ["fit", "score", "partial_fit", "fit_predict", "fit_transform"]
    for func_name in funcs:
        func = getattr(estimator, func_name, None)
        if func is not None:
            func(X, y)
            args = [p.name for p in signature(func).parameters.values()]
            if args[0] == "self":
                # if_delegate_has_method makes methods into functions
                # with an explicit "self", so need to shift arguments
                args = args[1:]
            assert args[1] in ["y", "Y"], (
                "Expected y or Y as second argument for method "
                "%s of %s. Got arguments: %r."
                % (func_name, type(estimator).__name__, args))

@ignore_warnings
def check_estimators_dtypes(name, estimator_orig):
    rnd = np.random.RandomState(0)
    X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32)
    X_train_32 = _pairwise_estimator_convert_X(X_train_32, estimator_orig)
    X_train_64 = X_train_32.astype(np.float64)
    X_train_int_64 = X_train_32.astype(np.int64)
    X_train_int_32 = X_train_32.astype(np.int32)
    y = X_train_int_64[:, 0]
    if estimator_orig._get_tags()['binary_only']:
        y[y == 2] = 1
    y = _enforce_estimator_tags_y(estimator_orig, y)

    methods = ["predict", "transform", "decision_function", "predict_proba"]

    for X_train in [X_train_32, X_train_64, X_train_int_64, X_train_int_32]:
        estimator = clone(estimator_orig)
        set_random_state(estimator, 1)
        estimator.fit(X_train, y)

        for method in methods:
            if hasattr(estimator, method):
                getattr(estimator, method)(X_train)

@ignore_warnings(category=FutureWarning)
def check_estimators_empty_data_messages(name, estimator_orig):
    e = clone(estimator_orig)
    set_random_state(e, 1)

    X_zero_samples = np.empty(0).reshape(0, 3)
    # The precise message can change depending on whether X or y is
    # validated first. Let us test the type of exception only:
    with assert_raises(ValueError, msg="The estimator {} does not"
                       " raise an error when empty data is used "
                       "to train. Perhaps use "
                       "check_array in train.".format(name)):
        e.fit(X_zero_samples, [])

    X_zero_features = np.empty(0).reshape(3, 0)
    # the following y should be accepted by both classifiers and regressors
    # and ignored by unsupervised models
    y = _enforce_estimator_tags_y(e, np.array([1, 0, 1]))
    msg = (r"0 feature\(s\) \(shape=\(3, 0\)\) while a minimum of \d* "
           "is required.")
    assert_raises_regex(ValueError, msg, e.fit, X_zero_features, y)

@ignore_warnings(category=FutureWarning)
def check_estimators_nan_inf(name, estimator_orig):
    # Checks that the estimator raises an error when X contains NaN or inf.
    rnd = np.random.RandomState(0)
    X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)),
                                                   estimator_orig)
    X_train_nan = rnd.uniform(size=(10, 3))
    X_train_nan[0, 0] = np.nan
    X_train_inf = rnd.uniform(size=(10, 3))
    X_train_inf[0, 0] = np.inf
    y = np.ones(10)
    y[:5] = 0
    y = _enforce_estimator_tags_y(estimator_orig, y)
    error_string_fit = "Estimator doesn't check for NaN and inf in fit."
    error_string_predict = ("Estimator doesn't check for NaN and inf in"
                            " predict.")
    error_string_transform = ("Estimator doesn't check for NaN and inf in"
                              " transform.")
    for X_train in [X_train_nan, X_train_inf]:
        # catch deprecation warnings
        with ignore_warnings(category=FutureWarning):
            estimator = clone(estimator_orig)
            set_random_state(estimator, 1)
            # try to fit
            try:
                estimator.fit(X_train, y)
            except ValueError as e:
                if 'inf' not in repr(e) and 'NaN' not in repr(e):
                    print(error_string_fit, estimator, e)
                    traceback.print_exc(file=sys.stdout)
                    raise e
            except Exception as exc:
                print(error_string_fit, estimator, exc)
                traceback.print_exc(file=sys.stdout)
                raise exc
            else:
                raise AssertionError(error_string_fit, estimator)
            # actually fit
            estimator.fit(X_train_finite, y)

            # predict
            if hasattr(estimator, "predict"):
                try:
                    estimator.predict(X_train)
                except ValueError as e:
                    if 'inf' not in repr(e) and 'NaN' not in repr(e):
                        print(error_string_predict, estimator, e)
                        traceback.print_exc(file=sys.stdout)
                        raise e
                except Exception as exc:
                    print(error_string_predict, estimator, exc)
                    traceback.print_exc(file=sys.stdout)
                else:
                    raise AssertionError(error_string_predict, estimator)

            # transform
            if hasattr(estimator, "transform"):
                try:
                    estimator.transform(X_train)
                except ValueError as e:
                    if 'inf' not in repr(e) and 'NaN' not in repr(e):
                        print(error_string_transform, estimator, e)
                        traceback.print_exc(file=sys.stdout)
                        raise e
                except Exception as exc:
                    print(error_string_transform, estimator, exc)
                    traceback.print_exc(file=sys.stdout)
                else:
                    raise AssertionError(error_string_transform, estimator)

@ignore_warnings
def check_nonsquare_error(name, estimator_orig):
    """Test that error is thrown when non-square data provided"""
    X, y = make_blobs(n_samples=20, n_features=10)
    estimator = clone(estimator_orig)

    with assert_raises(ValueError, msg="The pairwise estimator {}"
                       " does not raise an error on non-square data"
                       .format(name)):
        estimator.fit(X, y)
  1283. @ignore_warnings
  1284. def check_estimators_pickle(name, estimator_orig):
  1285. """Test that we can pickle all estimators"""
  1286. check_methods = ["predict", "transform", "decision_function",
  1287. "predict_proba"]
  1288. X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
  1289. random_state=0, n_features=2, cluster_std=0.1)
# some estimators can't handle negative feature values
  1291. X -= X.min()
  1292. X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
  1293. tags = estimator_orig._get_tags()
  1294. # include NaN values when the estimator should deal with them
  1295. if tags['allow_nan']:
  1296. # set randomly 10 elements to np.nan
  1297. rng = np.random.RandomState(42)
  1298. mask = rng.choice(X.size, 10, replace=False)
  1299. X.reshape(-1)[mask] = np.nan
  1300. estimator = clone(estimator_orig)
  1301. y = _enforce_estimator_tags_y(estimator, y)
  1302. set_random_state(estimator)
  1303. estimator.fit(X, y)
  1304. result = dict()
  1305. for method in check_methods:
  1306. if hasattr(estimator, method):
  1307. result[method] = getattr(estimator, method)(X)
  1308. # pickle and unpickle!
  1309. pickled_estimator = pickle.dumps(estimator)
  1310. if estimator.__module__.startswith('sklearn.'):
  1311. assert b"version" in pickled_estimator
  1312. unpickled_estimator = pickle.loads(pickled_estimator)
  1317. for method in result:
  1318. unpickled_result = getattr(unpickled_estimator, method)(X)
  1319. assert_allclose_dense_sparse(result[method], unpickled_result)
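# NOTE (illustrative, not part of the check API): a minimal, self-contained
# sketch of the round-trip property verified above -- predictions must be
# unchanged by pickling. The choice of LogisticRegression is arbitrary.
def _example_pickle_roundtrip():
    from sklearn.linear_model import LogisticRegression
    X, y = make_blobs(n_samples=30, centers=2, random_state=0)
    clf = LogisticRegression(random_state=0).fit(X, y)
    # serialize and deserialize, then compare predictions
    clf_roundtrip = pickle.loads(pickle.dumps(clf))
    assert_array_equal(clf.predict(X), clf_roundtrip.predict(X))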
  1320. @ignore_warnings(category=FutureWarning)
  1321. def check_estimators_partial_fit_n_features(name, estimator_orig):
# check that an error is raised if the number of features changes between calls to partial_fit.
  1323. if not hasattr(estimator_orig, 'partial_fit'):
  1324. return
  1325. estimator = clone(estimator_orig)
  1326. X, y = make_blobs(n_samples=50, random_state=1)
  1327. X -= X.min()
  1328. try:
  1329. if is_classifier(estimator):
  1330. classes = np.unique(y)
  1331. estimator.partial_fit(X, y, classes=classes)
  1332. else:
  1333. estimator.partial_fit(X, y)
  1334. except NotImplementedError:
  1335. return
  1336. with assert_raises(ValueError,
  1337. msg="The estimator {} does not raise an"
  1338. " error when the number of features"
  1339. " changes between calls to "
  1340. "partial_fit.".format(name)):
  1341. estimator.partial_fit(X[:, :-1], y)
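# NOTE (illustrative sketch of the failure mode targeted above): a second
# call to partial_fit with a different number of features must raise a
# ValueError. SGDClassifier is just one estimator exposing partial_fit.
def _example_partial_fit_feature_mismatch():
    from sklearn.linear_model import SGDClassifier
    X, y = make_blobs(n_samples=50, random_state=1)
    clf = SGDClassifier(random_state=0)
    clf.partial_fit(X, y, classes=np.unique(y))
    try:
        clf.partial_fit(X[:, :-1], y)  # one feature fewer than before
    except ValueError:
        pass  # expected: feature count changed between calls
    else:
        raise AssertionError("no error on changed feature count")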
  1342. @ignore_warnings(category=FutureWarning)
  1343. def check_classifier_multioutput(name, estimator):
  1344. n_samples, n_labels, n_classes = 42, 5, 3
  1345. tags = estimator._get_tags()
  1346. estimator = clone(estimator)
  1347. X, y = make_multilabel_classification(random_state=42,
  1348. n_samples=n_samples,
  1349. n_labels=n_labels,
  1350. n_classes=n_classes)
  1351. estimator.fit(X, y)
  1352. y_pred = estimator.predict(X)
  1353. assert y_pred.shape == (n_samples, n_classes), (
  1354. "The shape of the prediction for multioutput data is "
  1355. "incorrect. Expected {}, got {}."
.format((n_samples, n_classes), y_pred.shape))
  1357. assert y_pred.dtype.kind == 'i'
  1358. if hasattr(estimator, "decision_function"):
  1359. decision = estimator.decision_function(X)
  1360. assert isinstance(decision, np.ndarray)
  1361. assert decision.shape == (n_samples, n_classes), (
  1362. "The shape of the decision function output for "
  1363. "multioutput data is incorrect. Expected {}, got {}."
  1364. .format((n_samples, n_classes), decision.shape))
dec_pred = (decision > 0).astype(int)
  1366. dec_exp = estimator.classes_[dec_pred]
  1367. assert_array_equal(dec_exp, y_pred)
  1368. if hasattr(estimator, "predict_proba"):
  1369. y_prob = estimator.predict_proba(X)
  1370. if isinstance(y_prob, list) and not tags['poor_score']:
  1371. for i in range(n_classes):
  1372. assert y_prob[i].shape == (n_samples, 2), (
  1373. "The shape of the probability for multioutput data is"
  1374. " incorrect. Expected {}, got {}."
  1375. .format((n_samples, 2), y_prob[i].shape))
  1376. assert_array_equal(
np.argmax(y_prob[i], axis=1).astype(int),
  1378. y_pred[:, i]
  1379. )
  1380. elif not tags['poor_score']:
  1381. assert y_prob.shape == (n_samples, n_classes), (
  1382. "The shape of the probability for multioutput data is"
  1383. " incorrect. Expected {}, got {}."
  1384. .format((n_samples, n_classes), y_prob.shape))
  1385. assert_array_equal(y_prob.round().astype(int), y_pred)
  1386. if (hasattr(estimator, "decision_function") and
  1387. hasattr(estimator, "predict_proba")):
  1388. for i in range(n_classes):
  1389. y_proba = estimator.predict_proba(X)[:, i]
  1390. y_decision = estimator.decision_function(X)
  1391. assert_array_equal(rankdata(y_proba), rankdata(y_decision[:, i]))
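# NOTE (illustrative): a small sketch of the prediction-shape contract
# checked above, using a tree classifier (an arbitrary choice) that
# natively supports multilabel y.
def _example_multilabel_predict_shape():
    from sklearn.tree import DecisionTreeClassifier
    X, y = make_multilabel_classification(n_samples=42, n_classes=3,
                                          random_state=42)
    y_pred = DecisionTreeClassifier(random_state=0).fit(X, y).predict(X)
    assert y_pred.shape == y.shape  # (n_samples, n_classes)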
  1392. @ignore_warnings(category=FutureWarning)
  1393. def check_regressor_multioutput(name, estimator):
  1394. estimator = clone(estimator)
  1395. n_samples = n_features = 10
  1396. if not _is_pairwise_metric(estimator):
  1397. n_samples = n_samples + 1
  1398. X, y = make_regression(random_state=42, n_targets=5,
  1399. n_samples=n_samples, n_features=n_features)
X = _pairwise_estimator_convert_X(X, estimator)
  1401. estimator.fit(X, y)
  1402. y_pred = estimator.predict(X)
assert y_pred.dtype == np.dtype('float64'), (
"Multioutput predictions by a regressor are expected to have"
" a floating-point dtype. Got {} instead".format(y_pred.dtype))
assert y_pred.shape == y.shape, (
"The shape of the prediction for multioutput data is incorrect."
" Expected {}, got {}.".format(y.shape, y_pred.shape))
  1409. @ignore_warnings(category=FutureWarning)
  1410. def check_clustering(name, clusterer_orig, readonly_memmap=False):
  1411. clusterer = clone(clusterer_orig)
  1412. X, y = make_blobs(n_samples=50, random_state=1)
  1413. X, y = shuffle(X, y, random_state=7)
  1414. X = StandardScaler().fit_transform(X)
  1415. rng = np.random.RandomState(7)
  1416. X_noise = np.concatenate([X, rng.uniform(low=-3, high=3, size=(5, 2))])
  1417. if readonly_memmap:
  1418. X, y, X_noise = create_memmap_backed_data([X, y, X_noise])
  1419. n_samples, n_features = X.shape
  1420. # catch deprecation and neighbors warnings
  1421. if hasattr(clusterer, "n_clusters"):
  1422. clusterer.set_params(n_clusters=3)
  1423. set_random_state(clusterer)
  1424. if name == 'AffinityPropagation':
  1425. clusterer.set_params(preference=-100)
  1426. clusterer.set_params(max_iter=100)
  1427. # fit
  1428. clusterer.fit(X)
  1429. # with lists
  1430. clusterer.fit(X.tolist())
  1431. pred = clusterer.labels_
  1432. assert pred.shape == (n_samples,)
  1433. assert adjusted_rand_score(pred, y) > 0.4
  1434. if clusterer._get_tags()['non_deterministic']:
  1435. return
  1436. set_random_state(clusterer)
  1437. with warnings.catch_warnings(record=True):
  1438. pred2 = clusterer.fit_predict(X)
  1439. assert_array_equal(pred, pred2)
  1440. # fit_predict(X) and labels_ should be of type int
  1441. assert pred.dtype in [np.dtype('int32'), np.dtype('int64')]
  1442. assert pred2.dtype in [np.dtype('int32'), np.dtype('int64')]
  1443. # Add noise to X to test the possible values of the labels
  1444. labels = clusterer.fit_predict(X_noise)
  1445. # There should be at least one sample in every cluster. Equivalently
  1446. # labels_ should contain all the consecutive values between its
  1447. # min and its max.
  1448. labels_sorted = np.unique(labels)
  1449. assert_array_equal(labels_sorted, np.arange(labels_sorted[0],
  1450. labels_sorted[-1] + 1))
  1451. # Labels are expected to start at 0 (no noise) or -1 (if noise)
  1452. assert labels_sorted[0] in [0, -1]
# Labels should be at most n_clusters - 1
  1454. if hasattr(clusterer, 'n_clusters'):
  1455. n_clusters = getattr(clusterer, 'n_clusters')
  1456. assert n_clusters - 1 >= labels_sorted[-1]
  1457. # else labels should be less than max(labels_) which is necessarily true
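# NOTE (illustrative): a sketch of the labelling convention asserted
# above -- fit_predict returns consecutive integer labels starting at 0
# (KMeans, an arbitrary choice here, has no noise label).
def _example_cluster_label_convention():
    from sklearn.cluster import KMeans
    X, _ = make_blobs(n_samples=50, random_state=1)
    labels = KMeans(n_clusters=3, random_state=0).fit_predict(X)
    assert_array_equal(np.unique(labels), np.arange(3))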
  1458. @ignore_warnings(category=FutureWarning)
  1459. def check_clusterer_compute_labels_predict(name, clusterer_orig):
  1460. """Check that predict is invariant of compute_labels"""
  1461. X, y = make_blobs(n_samples=20, random_state=0)
  1462. clusterer = clone(clusterer_orig)
  1463. set_random_state(clusterer)
  1464. if hasattr(clusterer, "compute_labels"):
  1465. # MiniBatchKMeans
  1466. X_pred1 = clusterer.fit(X).predict(X)
  1467. clusterer.set_params(compute_labels=False)
  1468. X_pred2 = clusterer.fit(X).predict(X)
  1469. assert_array_equal(X_pred1, X_pred2)
  1470. @ignore_warnings(category=FutureWarning)
  1471. def check_classifiers_one_label(name, classifier_orig):
  1472. error_string_fit = "Classifier can't train when only one class is present."
  1473. error_string_predict = ("Classifier can't predict when only one class is "
  1474. "present.")
  1475. rnd = np.random.RandomState(0)
  1476. X_train = rnd.uniform(size=(10, 3))
  1477. X_test = rnd.uniform(size=(10, 3))
  1478. y = np.ones(10)
  1479. # catch deprecation warnings
  1480. with ignore_warnings(category=FutureWarning):
  1481. classifier = clone(classifier_orig)
  1482. # try to fit
  1483. try:
  1484. classifier.fit(X_train, y)
  1485. except ValueError as e:
  1486. if 'class' not in repr(e):
  1487. print(error_string_fit, classifier, e)
  1488. traceback.print_exc(file=sys.stdout)
  1489. raise e
  1490. else:
  1491. return
  1492. except Exception as exc:
  1493. print(error_string_fit, classifier, exc)
  1494. traceback.print_exc(file=sys.stdout)
  1495. raise exc
  1496. # predict
  1497. try:
  1498. assert_array_equal(classifier.predict(X_test), y)
  1499. except Exception as exc:
  1500. print(error_string_predict, classifier, exc)
  1501. raise exc
  1502. @ignore_warnings # Warnings are raised by decision function
  1503. def check_classifiers_train(name, classifier_orig, readonly_memmap=False,
  1504. X_dtype='float64'):
  1505. X_m, y_m = make_blobs(n_samples=300, random_state=0)
  1506. X_m = X_m.astype(X_dtype)
  1507. X_m, y_m = shuffle(X_m, y_m, random_state=7)
  1508. X_m = StandardScaler().fit_transform(X_m)
  1509. # generate binary problem from multi-class one
  1510. y_b = y_m[y_m != 2]
  1511. X_b = X_m[y_m != 2]
  1512. if name in ['BernoulliNB', 'MultinomialNB', 'ComplementNB',
  1513. 'CategoricalNB']:
  1514. X_m -= X_m.min()
  1515. X_b -= X_b.min()
  1516. if readonly_memmap:
  1517. X_m, y_m, X_b, y_b = create_memmap_backed_data([X_m, y_m, X_b, y_b])
  1518. problems = [(X_b, y_b)]
  1519. tags = classifier_orig._get_tags()
  1520. if not tags['binary_only']:
  1521. problems.append((X_m, y_m))
  1522. for (X, y) in problems:
  1523. classes = np.unique(y)
  1524. n_classes = len(classes)
  1525. n_samples, n_features = X.shape
  1526. classifier = clone(classifier_orig)
  1527. X = _pairwise_estimator_convert_X(X, classifier)
  1528. y = _enforce_estimator_tags_y(classifier, y)
  1529. set_random_state(classifier)
  1530. # raises error on malformed input for fit
  1531. if not tags["no_validation"]:
  1532. with assert_raises(
  1533. ValueError,
  1534. msg="The classifier {} does not "
  1535. "raise an error when incorrect/malformed input "
  1536. "data for fit is passed. The number of training "
  1537. "examples is not the same as the number of labels. "
  1538. "Perhaps use check_X_y in fit.".format(name)):
  1539. classifier.fit(X, y[:-1])
  1540. # fit
  1541. classifier.fit(X, y)
  1542. # with lists
  1543. classifier.fit(X.tolist(), y.tolist())
  1544. assert hasattr(classifier, "classes_")
  1545. y_pred = classifier.predict(X)
  1546. assert y_pred.shape == (n_samples,)
  1547. # training set performance
  1548. if not tags['poor_score']:
  1549. assert accuracy_score(y, y_pred) > 0.83
  1550. # raises error on malformed input for predict
msg_pairwise = (
"The classifier {} does not raise an error when shape of X in "
"{} is not equal to (n_test_samples, n_training_samples)")
  1554. msg = ("The classifier {} does not raise an error when the number of "
  1555. "features in {} is different from the number of features in "
  1556. "fit.")
  1557. if not tags["no_validation"]:
  1558. if _is_pairwise(classifier):
  1559. with assert_raises(ValueError,
  1560. msg=msg_pairwise.format(name, "predict")):
  1561. classifier.predict(X.reshape(-1, 1))
  1562. else:
  1563. with assert_raises(ValueError,
  1564. msg=msg.format(name, "predict")):
  1565. classifier.predict(X.T)
  1566. if hasattr(classifier, "decision_function"):
  1567. try:
  1568. # decision_function agrees with predict
  1569. decision = classifier.decision_function(X)
  1570. if n_classes == 2:
  1571. if not tags["multioutput_only"]:
  1572. assert decision.shape == (n_samples,)
  1573. else:
  1574. assert decision.shape == (n_samples, 1)
dec_pred = (decision.ravel() > 0).astype(int)
  1576. assert_array_equal(dec_pred, y_pred)
  1577. else:
  1578. assert decision.shape == (n_samples, n_classes)
  1579. assert_array_equal(np.argmax(decision, axis=1), y_pred)
  1580. # raises error on malformed input for decision_function
  1581. if not tags["no_validation"]:
  1582. if _is_pairwise(classifier):
  1583. with assert_raises(ValueError, msg=msg_pairwise.format(
  1584. name, "decision_function")):
  1585. classifier.decision_function(X.reshape(-1, 1))
  1586. else:
  1587. with assert_raises(ValueError, msg=msg.format(
  1588. name, "decision_function")):
  1589. classifier.decision_function(X.T)
  1590. except NotImplementedError:
  1591. pass
  1592. if hasattr(classifier, "predict_proba"):
  1593. # predict_proba agrees with predict
  1594. y_prob = classifier.predict_proba(X)
  1595. assert y_prob.shape == (n_samples, n_classes)
  1596. assert_array_equal(np.argmax(y_prob, axis=1), y_pred)
  1597. # check that probas for all classes sum to one
  1598. assert_array_almost_equal(np.sum(y_prob, axis=1),
  1599. np.ones(n_samples))
  1600. if not tags["no_validation"]:
  1601. # raises error on malformed input for predict_proba
  1602. if _is_pairwise(classifier_orig):
  1603. with assert_raises(ValueError, msg=msg_pairwise.format(
  1604. name, "predict_proba")):
  1605. classifier.predict_proba(X.reshape(-1, 1))
  1606. else:
  1607. with assert_raises(ValueError, msg=msg.format(
  1608. name, "predict_proba")):
  1609. classifier.predict_proba(X.T)
  1610. if hasattr(classifier, "predict_log_proba"):
  1611. # predict_log_proba is a transformation of predict_proba
  1612. y_log_prob = classifier.predict_log_proba(X)
  1613. assert_allclose(y_log_prob, np.log(y_prob), 8, atol=1e-9)
  1614. assert_array_equal(np.argsort(y_log_prob), np.argsort(y_prob))
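# NOTE (hedged sketch of the binary decision_function/predict agreement
# tested above; LinearSVC is an arbitrary choice): classes_ is sorted, so
# thresholding the decision values at 0 recovers the predicted classes.
def _example_decision_function_matches_predict():
    from sklearn.svm import LinearSVC
    X, y = make_blobs(n_samples=100, centers=2, random_state=0)
    clf = LinearSVC(random_state=0).fit(X, y)
    dec_pred = (clf.decision_function(X) > 0).astype(int)
    assert_array_equal(clf.classes_[dec_pred], clf.predict(X))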
  1615. def check_outlier_corruption(num_outliers, expected_outliers, decision):
  1616. # Check for deviation from the precise given contamination level that may
  1617. # be due to ties in the anomaly scores.
  1618. if num_outliers < expected_outliers:
  1619. start = num_outliers
  1620. end = expected_outliers + 1
  1621. else:
  1622. start = expected_outliers
  1623. end = num_outliers + 1
  1624. # ensure that all values in the 'critical area' are tied,
  1625. # leading to the observed discrepancy between provided
  1626. # and actual contamination levels.
  1627. sorted_decision = np.sort(decision)
  1628. msg = ('The number of predicted outliers is not equal to the expected '
  1629. 'number of outliers and this difference is not explained by the '
  1630. 'number of ties in the decision_function values')
  1631. assert len(np.unique(sorted_decision[start:end])) == 1, msg
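# NOTE (worked example of the tie logic above, with hand-built scores):
# 3 predicted vs 5 expected outliers is acceptable only if positions 3..5
# of the sorted decision values are all tied, as they are here.
def _example_outlier_ties_are_benign():
    decision = np.array([0., 0., 0., 1., 1., 1., 2., 3., 4., 5.])
    check_outlier_corruption(num_outliers=3, expected_outliers=5,
                             decision=decision)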
  1632. def check_outliers_train(name, estimator_orig, readonly_memmap=True):
  1633. n_samples = 300
  1634. X, _ = make_blobs(n_samples=n_samples, random_state=0)
  1635. X = shuffle(X, random_state=7)
  1636. if readonly_memmap:
  1637. X = create_memmap_backed_data(X)
  1638. n_samples, n_features = X.shape
  1639. estimator = clone(estimator_orig)
  1640. set_random_state(estimator)
  1641. # fit
  1642. estimator.fit(X)
  1643. # with lists
  1644. estimator.fit(X.tolist())
  1645. y_pred = estimator.predict(X)
  1646. assert y_pred.shape == (n_samples,)
  1647. assert y_pred.dtype.kind == 'i'
  1648. assert_array_equal(np.unique(y_pred), np.array([-1, 1]))
  1649. decision = estimator.decision_function(X)
  1650. scores = estimator.score_samples(X)
  1651. for output in [decision, scores]:
  1652. assert output.dtype == np.dtype('float')
  1653. assert output.shape == (n_samples,)
  1654. # raises error on malformed input for predict
  1655. assert_raises(ValueError, estimator.predict, X.T)
  1656. # decision_function agrees with predict
dec_pred = (decision >= 0).astype(int)
  1658. dec_pred[dec_pred == 0] = -1
  1659. assert_array_equal(dec_pred, y_pred)
  1660. # raises error on malformed input for decision_function
  1661. assert_raises(ValueError, estimator.decision_function, X.T)
  1662. # decision_function is a translation of score_samples
  1663. y_dec = scores - estimator.offset_
  1664. assert_allclose(y_dec, decision)
  1665. # raises error on malformed input for score_samples
  1666. assert_raises(ValueError, estimator.score_samples, X.T)
  1667. # contamination parameter (not for OneClassSVM which has the nu parameter)
  1668. if (hasattr(estimator, 'contamination')
  1669. and not hasattr(estimator, 'novelty')):
  1670. # proportion of outliers equal to contamination parameter when not
  1671. # set to 'auto'. This is true for the training set and cannot thus be
  1672. # checked as follows for estimators with a novelty parameter such as
  1673. # LocalOutlierFactor (tested in check_outliers_fit_predict)
  1674. expected_outliers = 30
  1675. contamination = expected_outliers / n_samples
  1676. estimator.set_params(contamination=contamination)
  1677. estimator.fit(X)
  1678. y_pred = estimator.predict(X)
  1679. num_outliers = np.sum(y_pred != 1)
  1680. # num_outliers should be equal to expected_outliers unless
  1681. # there are ties in the decision_function values. this can
  1682. # only be tested for estimators with a decision_function
  1683. # method, i.e. all estimators except LOF which is already
  1684. # excluded from this if branch.
  1685. if num_outliers != expected_outliers:
  1686. decision = estimator.decision_function(X)
  1687. check_outlier_corruption(num_outliers, expected_outliers, decision)
  1688. # raises error when contamination is a scalar and not in [0,1]
  1689. for contamination in [-0.5, 2.3]:
  1690. estimator.set_params(contamination=contamination)
  1691. assert_raises(ValueError, estimator.fit, X)
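# NOTE (illustrative estimator choice, and barring ties in the anomaly
# scores): a sketch of the contamination contract checked above -- the
# fraction of predicted outliers on the training set matches the parameter.
def _example_contamination_fraction():
    from sklearn.ensemble import IsolationForest
    X, _ = make_blobs(n_samples=300, random_state=0)
    est = IsolationForest(contamination=0.1, random_state=0).fit(X)
    assert np.sum(est.predict(X) == -1) == 30  # 10% of 300 samples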
@ignore_warnings(category=FutureWarning)
  1693. def check_classifiers_multilabel_representation_invariance(name,
  1694. classifier_orig):
  1695. X, y = make_multilabel_classification(n_samples=100, n_features=20,
  1696. n_classes=5, n_labels=3,
  1697. length=50, allow_unlabeled=True,
  1698. random_state=0)
  1699. X_train, y_train = X[:80], y[:80]
  1700. X_test = X[80:]
  1701. y_train_list_of_lists = y_train.tolist()
  1702. y_train_list_of_arrays = list(y_train)
  1703. classifier = clone(classifier_orig)
  1704. set_random_state(classifier)
  1705. y_pred = classifier.fit(X_train, y_train).predict(X_test)
  1706. y_pred_list_of_lists = classifier.fit(
  1707. X_train, y_train_list_of_lists).predict(X_test)
  1708. y_pred_list_of_arrays = classifier.fit(
  1709. X_train, y_train_list_of_arrays).predict(X_test)
  1710. assert_array_equal(y_pred, y_pred_list_of_arrays)
  1711. assert_array_equal(y_pred, y_pred_list_of_lists)
  1712. assert y_pred.dtype == y_pred_list_of_arrays.dtype
  1713. assert y_pred.dtype == y_pred_list_of_lists.dtype
  1714. assert type(y_pred) == type(y_pred_list_of_arrays)
  1715. assert type(y_pred) == type(y_pred_list_of_lists)
  1716. @ignore_warnings(category=FutureWarning)
  1717. def check_estimators_fit_returns_self(name, estimator_orig,
  1718. readonly_memmap=False):
  1719. """Check if self is returned when calling fit"""
  1720. if estimator_orig._get_tags()['binary_only']:
  1721. n_centers = 2
  1722. else:
  1723. n_centers = 3
  1724. X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
  1725. # some want non-negative input
  1726. X -= X.min()
  1727. X = _pairwise_estimator_convert_X(X, estimator_orig)
  1728. estimator = clone(estimator_orig)
  1729. y = _enforce_estimator_tags_y(estimator, y)
  1730. if readonly_memmap:
  1731. X, y = create_memmap_backed_data([X, y])
  1732. set_random_state(estimator)
  1733. assert estimator.fit(X, y) is estimator
  1734. @ignore_warnings
  1735. def check_estimators_unfitted(name, estimator_orig):
  1736. """Check that predict raises an exception in an unfitted estimator.
  1737. Unfitted estimators should raise a NotFittedError.
  1738. """
  1739. # Common test for Regressors, Classifiers and Outlier detection estimators
  1740. X, y = _boston_subset()
  1741. estimator = clone(estimator_orig)
  1742. for method in ('decision_function', 'predict', 'predict_proba',
  1743. 'predict_log_proba'):
  1744. if hasattr(estimator, method):
  1745. assert_raises(NotFittedError, getattr(estimator, method), X)
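# NOTE (minimal illustrative sketch of the contract above): predicting
# with an unfitted estimator must raise NotFittedError. The choice of
# LogisticRegression is arbitrary.
def _example_unfitted_predict_raises():
    from sklearn.linear_model import LogisticRegression
    assert_raises(NotFittedError,
                  LogisticRegression().predict, np.ones((3, 2)))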
  1746. @ignore_warnings(category=FutureWarning)
  1747. def check_supervised_y_2d(name, estimator_orig):
  1748. tags = estimator_orig._get_tags()
  1749. if tags['multioutput_only']:
  1750. # These only work on 2d, so this test makes no sense
  1751. return
  1752. rnd = np.random.RandomState(0)
  1753. n_samples = 30
  1754. X = _pairwise_estimator_convert_X(
  1755. rnd.uniform(size=(n_samples, 3)), estimator_orig
  1756. )
  1757. if tags['binary_only']:
  1758. y = np.arange(n_samples) % 2
  1759. else:
  1760. y = np.arange(n_samples) % 3
  1761. y = _enforce_estimator_tags_y(estimator_orig, y)
  1762. estimator = clone(estimator_orig)
  1763. set_random_state(estimator)
  1764. # fit
  1765. estimator.fit(X, y)
  1766. y_pred = estimator.predict(X)
  1767. set_random_state(estimator)
  1768. # Check that when a 2D y is given, a DataConversionWarning is
  1769. # raised
  1770. with warnings.catch_warnings(record=True) as w:
  1771. warnings.simplefilter("always", DataConversionWarning)
  1772. warnings.simplefilter("ignore", RuntimeWarning)
  1773. estimator.fit(X, y[:, np.newaxis])
  1774. y_pred_2d = estimator.predict(X)
  1775. msg = "expected 1 DataConversionWarning, got: %s" % (
  1776. ", ".join([str(w_x) for w_x in w]))
  1777. if not tags['multioutput']:
  1778. # check that we warned if we don't support multi-output
  1779. assert len(w) > 0, msg
  1780. assert "DataConversionWarning('A column-vector y" \
  1781. " was passed when a 1d array was expected" in msg
  1782. assert_allclose(y_pred.ravel(), y_pred_2d.ravel())
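# NOTE (illustrative): a sketch of the DataConversionWarning behaviour
# exercised above when a column-vector y is passed to a single-output
# classifier; the estimator choice is arbitrary.
def _example_column_vector_y_warns():
    from sklearn.linear_model import LogisticRegression
    rng = np.random.RandomState(0)
    X, y = rng.uniform(size=(30, 3)), np.arange(30) % 2
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always", DataConversionWarning)
        LogisticRegression().fit(X, y[:, np.newaxis])
    assert any(issubclass(w_.category, DataConversionWarning) for w_ in w)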
  1783. @ignore_warnings
  1784. def check_classifiers_predictions(X, y, name, classifier_orig):
  1785. classes = np.unique(y)
  1786. classifier = clone(classifier_orig)
  1787. if name == 'BernoulliNB':
  1788. X = X > X.mean()
  1789. set_random_state(classifier)
  1790. classifier.fit(X, y)
  1791. y_pred = classifier.predict(X)
  1792. if hasattr(classifier, "decision_function"):
  1793. decision = classifier.decision_function(X)
  1794. assert isinstance(decision, np.ndarray)
  1795. if len(classes) == 2:
dec_pred = (decision.ravel() > 0).astype(int)
  1797. dec_exp = classifier.classes_[dec_pred]
  1798. assert_array_equal(dec_exp, y_pred,
  1799. err_msg="decision_function does not match "
  1800. "classifier for %r: expected '%s', got '%s'" %
  1801. (classifier, ", ".join(map(str, dec_exp)),
  1802. ", ".join(map(str, y_pred))))
  1803. elif getattr(classifier, 'decision_function_shape', 'ovr') == 'ovr':
  1804. decision_y = np.argmax(decision, axis=1).astype(int)
  1805. y_exp = classifier.classes_[decision_y]
  1806. assert_array_equal(y_exp, y_pred,
  1807. err_msg="decision_function does not match "
  1808. "classifier for %r: expected '%s', got '%s'" %
  1809. (classifier, ", ".join(map(str, y_exp)),
  1810. ", ".join(map(str, y_pred))))
  1811. # training set performance
  1812. if name != "ComplementNB":
  1813. # This is a pathological data set for ComplementNB.
# For some specific cases 'ComplementNB' predicts fewer classes
# than expected
  1816. assert_array_equal(np.unique(y), np.unique(y_pred))
  1817. assert_array_equal(classes, classifier.classes_,
  1818. err_msg="Unexpected classes_ attribute for %r: "
  1819. "expected '%s', got '%s'" %
  1820. (classifier, ", ".join(map(str, classes)),
  1821. ", ".join(map(str, classifier.classes_))))
  1822. # TODO: remove in 0.24
  1823. @deprecated("choose_check_classifiers_labels is deprecated in version "
  1824. "0.22 and will be removed in version 0.24.")
  1825. def choose_check_classifiers_labels(name, y, y_names):
  1826. return _choose_check_classifiers_labels(name, y, y_names)
  1827. def _choose_check_classifiers_labels(name, y, y_names):
  1828. return y if name in ["LabelPropagation", "LabelSpreading"] else y_names
  1829. def check_classifiers_classes(name, classifier_orig):
  1830. X_multiclass, y_multiclass = make_blobs(n_samples=30, random_state=0,
  1831. cluster_std=0.1)
  1832. X_multiclass, y_multiclass = shuffle(X_multiclass, y_multiclass,
  1833. random_state=7)
  1834. X_multiclass = StandardScaler().fit_transform(X_multiclass)
# We need to make sure that we have non-negative data, for things
  1836. # like NMF
  1837. X_multiclass -= X_multiclass.min() - .1
  1838. X_binary = X_multiclass[y_multiclass != 2]
  1839. y_binary = y_multiclass[y_multiclass != 2]
  1840. X_multiclass = _pairwise_estimator_convert_X(X_multiclass, classifier_orig)
  1841. X_binary = _pairwise_estimator_convert_X(X_binary, classifier_orig)
  1842. labels_multiclass = ["one", "two", "three"]
  1843. labels_binary = ["one", "two"]
  1844. y_names_multiclass = np.take(labels_multiclass, y_multiclass)
  1845. y_names_binary = np.take(labels_binary, y_binary)
  1846. problems = [(X_binary, y_binary, y_names_binary)]
  1847. if not classifier_orig._get_tags()['binary_only']:
  1848. problems.append((X_multiclass, y_multiclass, y_names_multiclass))
  1849. for X, y, y_names in problems:
  1850. for y_names_i in [y_names, y_names.astype('O')]:
  1851. y_ = _choose_check_classifiers_labels(name, y, y_names_i)
  1852. check_classifiers_predictions(X, y_, name, classifier_orig)
  1853. labels_binary = [-1, 1]
  1854. y_names_binary = np.take(labels_binary, y_binary)
  1855. y_binary = _choose_check_classifiers_labels(name, y_binary, y_names_binary)
  1856. check_classifiers_predictions(X_binary, y_binary, name, classifier_orig)
  1857. @ignore_warnings(category=FutureWarning)
  1858. def check_regressors_int(name, regressor_orig):
  1859. X, _ = _boston_subset()
  1860. X = _pairwise_estimator_convert_X(X[:50], regressor_orig)
  1861. rnd = np.random.RandomState(0)
  1862. y = rnd.randint(3, size=X.shape[0])
  1863. y = _enforce_estimator_tags_y(regressor_orig, y)
  1864. rnd = np.random.RandomState(0)
  1865. # separate estimators to control random seeds
  1866. regressor_1 = clone(regressor_orig)
  1867. regressor_2 = clone(regressor_orig)
  1868. set_random_state(regressor_1)
  1869. set_random_state(regressor_2)
  1870. if name in CROSS_DECOMPOSITION:
  1871. y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])
  1872. y_ = y_.T
  1873. else:
  1874. y_ = y
  1875. # fit
  1876. regressor_1.fit(X, y_)
  1877. pred1 = regressor_1.predict(X)
regressor_2.fit(X, y_.astype(np.float64))
  1879. pred2 = regressor_2.predict(X)
  1880. assert_allclose(pred1, pred2, atol=1e-2, err_msg=name)
  1881. @ignore_warnings(category=FutureWarning)
  1882. def check_regressors_train(name, regressor_orig, readonly_memmap=False,
  1883. X_dtype=np.float64):
  1884. X, y = _boston_subset()
  1885. X = X.astype(X_dtype)
  1886. X = _pairwise_estimator_convert_X(X, regressor_orig)
  1887. y = StandardScaler().fit_transform(y.reshape(-1, 1)) # X is already scaled
  1888. y = y.ravel()
  1889. regressor = clone(regressor_orig)
  1890. y = _enforce_estimator_tags_y(regressor, y)
  1891. if name in CROSS_DECOMPOSITION:
  1892. rnd = np.random.RandomState(0)
  1893. y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])
  1894. y_ = y_.T
  1895. else:
  1896. y_ = y
  1897. if readonly_memmap:
  1898. X, y, y_ = create_memmap_backed_data([X, y, y_])
  1899. if not hasattr(regressor, 'alphas') and hasattr(regressor, 'alpha'):
  1900. # linear regressors need to set alpha, but not generalized CV ones
  1901. regressor.alpha = 0.01
  1902. if name == 'PassiveAggressiveRegressor':
  1903. regressor.C = 0.01
  1904. # raises error on malformed input for fit
with assert_raises(ValueError, msg="The regressor {} does not"
" raise an error when incorrect/malformed input "
"data for fit is passed. The number of training "
"examples is not the same as the number of "
"labels. Perhaps use check_X_y in fit.".format(name)):
  1910. regressor.fit(X, y[:-1])
  1911. # fit
  1912. set_random_state(regressor)
  1913. regressor.fit(X, y_)
  1914. regressor.fit(X.tolist(), y_.tolist())
  1915. y_pred = regressor.predict(X)
  1916. assert y_pred.shape == y_.shape
  1917. # TODO: find out why PLS and CCA fail. RANSAC is random
  1918. # and furthermore assumes the presence of outliers, hence
  1919. # skipped
  1920. if not regressor._get_tags()["poor_score"]:
  1921. assert regressor.score(X, y_) > 0.5
  1922. @ignore_warnings
  1923. def check_regressors_no_decision_function(name, regressor_orig):
# check that regressors raise a deprecation warning if they still expose decision_function, predict_proba or predict_log_proba
  1925. rng = np.random.RandomState(0)
  1926. regressor = clone(regressor_orig)
  1927. X = rng.normal(size=(10, 4))
  1928. X = _pairwise_estimator_convert_X(X, regressor_orig)
  1929. y = _enforce_estimator_tags_y(regressor, X[:, 0])
  1930. if hasattr(regressor, "n_components"):
  1931. # FIXME CCA, PLS is not robust to rank 1 effects
  1932. regressor.n_components = 1
  1933. regressor.fit(X, y)
  1934. funcs = ["decision_function", "predict_proba", "predict_log_proba"]
  1935. for func_name in funcs:
  1936. func = getattr(regressor, func_name, None)
  1937. if func is None:
  1938. # doesn't have function
  1939. continue
  1940. # has function. Should raise deprecation warning
  1941. msg = func_name
  1942. assert_warns_message(FutureWarning, msg, func, X)
  1943. @ignore_warnings(category=FutureWarning)
  1944. def check_class_weight_classifiers(name, classifier_orig):
  1945. if classifier_orig._get_tags()['binary_only']:
  1946. problems = [2]
  1947. else:
  1948. problems = [2, 3]
  1949. for n_centers in problems:
  1950. # create a very noisy dataset
  1951. X, y = make_blobs(centers=n_centers, random_state=0, cluster_std=20)
  1952. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
  1953. random_state=0)
  1954. # can't use gram_if_pairwise() here, setting up gram matrix manually
  1955. if _is_pairwise(classifier_orig):
  1956. X_test = rbf_kernel(X_test, X_train)
  1957. X_train = rbf_kernel(X_train, X_train)
  1958. n_centers = len(np.unique(y_train))
  1959. if n_centers == 2:
  1960. class_weight = {0: 1000, 1: 0.0001}
  1961. else:
  1962. class_weight = {0: 1000, 1: 0.0001, 2: 0.0001}
  1963. classifier = clone(classifier_orig).set_params(
  1964. class_weight=class_weight)
  1965. if hasattr(classifier, "n_iter"):
  1966. classifier.set_params(n_iter=100)
  1967. if hasattr(classifier, "max_iter"):
  1968. classifier.set_params(max_iter=1000)
  1969. if hasattr(classifier, "min_weight_fraction_leaf"):
  1970. classifier.set_params(min_weight_fraction_leaf=0.01)
  1971. if hasattr(classifier, "n_iter_no_change"):
  1972. classifier.set_params(n_iter_no_change=20)
  1973. set_random_state(classifier)
  1974. classifier.fit(X_train, y_train)
  1975. y_pred = classifier.predict(X_test)
  1976. # XXX: Generally can use 0.89 here. On Windows, LinearSVC gets
  1977. # 0.88 (Issue #9111)
  1978. assert np.mean(y_pred == 0) > 0.87
  1979. @ignore_warnings(category=FutureWarning)
  1980. def check_class_weight_balanced_classifiers(name, classifier_orig, X_train,
  1981. y_train, X_test, y_test, weights):
  1982. classifier = clone(classifier_orig)
  1983. if hasattr(classifier, "n_iter"):
  1984. classifier.set_params(n_iter=100)
  1985. if hasattr(classifier, "max_iter"):
  1986. classifier.set_params(max_iter=1000)
  1987. set_random_state(classifier)
  1988. classifier.fit(X_train, y_train)
  1989. y_pred = classifier.predict(X_test)
  1990. classifier.set_params(class_weight='balanced')
  1991. classifier.fit(X_train, y_train)
  1992. y_pred_balanced = classifier.predict(X_test)
  1993. assert (f1_score(y_test, y_pred_balanced, average='weighted') >
  1994. f1_score(y_test, y_pred, average='weighted'))
  1995. @ignore_warnings(category=FutureWarning)
  1996. def check_class_weight_balanced_linear_classifier(name, Classifier):
  1997. """Test class weights with non-contiguous class labels."""
  1998. # this is run on classes, not instances, though this should be changed
  1999. X = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
  2000. [1.0, 1.0], [1.0, 0.0]])
  2001. y = np.array([1, 1, 1, -1, -1])
  2002. classifier = Classifier()
  2003. if hasattr(classifier, "n_iter"):
# This is a very small dataset; the default n_iter is unlikely to be
# enough for convergence
  2006. classifier.set_params(n_iter=1000)
  2007. if hasattr(classifier, "max_iter"):
  2008. classifier.set_params(max_iter=1000)
  2009. if hasattr(classifier, 'cv'):
  2010. classifier.set_params(cv=3)
  2011. set_random_state(classifier)
  2012. # Let the model compute the class frequencies
  2013. classifier.set_params(class_weight='balanced')
  2014. coef_balanced = classifier.fit(X, y).coef_.copy()
  2015. # Count each label occurrence to reweight manually
  2016. n_samples = len(y)
  2017. n_classes = float(len(np.unique(y)))
  2018. class_weight = {1: n_samples / (np.sum(y == 1) * n_classes),
  2019. -1: n_samples / (np.sum(y == -1) * n_classes)}
  2020. classifier.set_params(class_weight=class_weight)
  2021. coef_manual = classifier.fit(X, y).coef_.copy()
  2022. assert_allclose(coef_balanced, coef_manual,
  2023. err_msg="Classifier %s is not computing"
  2024. " class_weight=balanced properly."
  2025. % name)
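# NOTE (worked numeric example of the 'balanced' heuristic reproduced
# above): weights are n_samples / (n_classes * class_count), so for
# y = [1, 1, 1, -1, -1] they are {-1: 5/4, 1: 5/6}.
def _example_balanced_class_weights():
    from sklearn.utils.class_weight import compute_class_weight
    y = np.array([1, 1, 1, -1, -1])
    weights = compute_class_weight('balanced', classes=np.array([-1, 1]),
                                   y=y)
    assert_allclose(weights, [5 / (2 * 2), 5 / (2 * 3)])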
  2026. @ignore_warnings(category=FutureWarning)
  2027. def check_estimators_overwrite_params(name, estimator_orig):
  2028. if estimator_orig._get_tags()['binary_only']:
  2029. n_centers = 2
  2030. else:
  2031. n_centers = 3
  2032. X, y = make_blobs(random_state=0, n_samples=21, centers=n_centers)
  2033. # some want non-negative input
  2034. X -= X.min()
  2035. X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
  2036. estimator = clone(estimator_orig)
  2037. y = _enforce_estimator_tags_y(estimator, y)
  2038. set_random_state(estimator)
  2039. # Make a physical copy of the original estimator parameters before fitting.
  2040. params = estimator.get_params()
  2041. original_params = deepcopy(params)
  2042. # Fit the model
  2043. estimator.fit(X, y)
  2044. # Compare the state of the model parameters with the original parameters
  2045. new_params = estimator.get_params()
  2046. for param_name, original_value in original_params.items():
  2047. new_value = new_params[param_name]
  2048. # We should never change or mutate the internal state of input
  2049. # parameters by default. To check this we use the joblib.hash function
  2050. # that introspects recursively any subobjects to compute a checksum.
  2051. # The only exception to this rule of immutable constructor parameters
  2052. # is possible RandomState instance but in this check we explicitly
  2053. # fixed the random_state params recursively to be integer seeds.
  2054. assert joblib.hash(new_value) == joblib.hash(original_value), (
  2055. "Estimator %s should not change or mutate "
  2056. " the parameter %s from %s to %s during fit."
  2057. % (name, param_name, original_value, new_value))
  2058. @ignore_warnings(category=FutureWarning)
  2059. def check_no_attributes_set_in_init(name, estimator_orig):
  2060. """Check setting during init. """
  2061. estimator = clone(estimator_orig)
  2062. if hasattr(type(estimator).__init__, "deprecated_original"):
  2063. return
  2064. init_params = _get_args(type(estimator).__init__)
  2065. if IS_PYPY:
  2066. # __init__ signature has additional objects in PyPy
  2067. for key in ['obj']:
  2068. if key in init_params:
  2069. init_params.remove(key)
  2070. parents_init_params = [param for params_parent in
  2071. (_get_args(parent) for parent in
  2072. type(estimator).__mro__)
  2073. for param in params_parent]
  2074. # Test for no setting apart from parameters during init
  2075. invalid_attr = (set(vars(estimator)) - set(init_params)
  2076. - set(parents_init_params))
  2077. assert not invalid_attr, (
  2078. "Estimator %s should not set any attribute apart"
  2079. " from parameters during init. Found attributes %s."
  2080. % (name, sorted(invalid_attr)))
  2081. # Ensure that each parameter is set in init
  2082. invalid_attr = set(init_params) - set(vars(estimator)) - {"self"}
  2083. assert not invalid_attr, (
  2084. "Estimator %s should store all parameters"
  2085. " as an attribute during init. Did not find "
  2086. "attributes %s."
  2087. % (name, sorted(invalid_attr)))
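# NOTE (hedged sketch of what the check above enforces; the class below is
# purely illustrative): __init__ stores each constructor argument verbatim
# and sets nothing else; validation is deferred to fit.
class _ExampleCompliantEstimator(BaseEstimator):
    def __init__(self, alpha=1.0):
        self.alpha = alpha  # no validation, no derived attributes here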
  2088. @ignore_warnings(category=FutureWarning)
  2089. def check_sparsify_coefficients(name, estimator_orig):
  2090. X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1],
  2091. [-1, -2], [2, 2], [-2, -2]])
  2092. y = [1, 1, 1, 2, 2, 2, 3, 3, 3]
  2093. est = clone(estimator_orig)
  2094. est.fit(X, y)
  2095. pred_orig = est.predict(X)
  2096. # test sparsify with dense inputs
  2097. est.sparsify()
  2098. assert sparse.issparse(est.coef_)
  2099. pred = est.predict(X)
  2100. assert_array_equal(pred, pred_orig)
  2101. # pickle and unpickle with sparse coef_
  2102. est = pickle.loads(pickle.dumps(est))
  2103. assert sparse.issparse(est.coef_)
  2104. pred = est.predict(X)
  2105. assert_array_equal(pred, pred_orig)
  2106. @ignore_warnings(category=FutureWarning)
  2107. def check_classifier_data_not_an_array(name, estimator_orig):
  2108. X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1],
  2109. [0, 3], [1, 0], [2, 0], [4, 4], [2, 3], [3, 2]])
  2110. X = _pairwise_estimator_convert_X(X, estimator_orig)
  2111. y = [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2]
  2112. y = _enforce_estimator_tags_y(estimator_orig, y)
  2113. for obj_type in ["NotAnArray", "PandasDataframe"]:
  2114. check_estimators_data_not_an_array(name, estimator_orig, X, y,
  2115. obj_type)
  2116. @ignore_warnings(category=FutureWarning)
  2117. def check_regressor_data_not_an_array(name, estimator_orig):
  2118. X, y = _boston_subset(n_samples=50)
  2119. X = _pairwise_estimator_convert_X(X, estimator_orig)
  2120. y = _enforce_estimator_tags_y(estimator_orig, y)
  2121. for obj_type in ["NotAnArray", "PandasDataframe"]:
  2122. check_estimators_data_not_an_array(name, estimator_orig, X, y,
  2123. obj_type)
  2124. @ignore_warnings(category=FutureWarning)
  2125. def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type):
  2126. if name in CROSS_DECOMPOSITION:
  2127. raise SkipTest("Skipping check_estimators_data_not_an_array "
  2128. "for cross decomposition module as estimators "
  2129. "are not deterministic.")
  2130. # separate estimators to control random seeds
  2131. estimator_1 = clone(estimator_orig)
  2132. estimator_2 = clone(estimator_orig)
  2133. set_random_state(estimator_1)
  2134. set_random_state(estimator_2)
  2135. if obj_type not in ["NotAnArray", 'PandasDataframe']:
  2136. raise ValueError("Data type {0} not supported".format(obj_type))
  2137. if obj_type == "NotAnArray":
  2138. y_ = _NotAnArray(np.asarray(y))
  2139. X_ = _NotAnArray(np.asarray(X))
  2140. else:
  2141. # Here pandas objects (Series and DataFrame) are tested explicitly
  2142. # because some estimators may handle them (especially their indexing)
  2143. # specially.
  2144. try:
  2145. import pandas as pd
  2146. y_ = np.asarray(y)
  2147. if y_.ndim == 1:
  2148. y_ = pd.Series(y_)
  2149. else:
  2150. y_ = pd.DataFrame(y_)
  2151. X_ = pd.DataFrame(np.asarray(X))
  2152. except ImportError:
  2153. raise SkipTest("pandas is not installed: not checking estimators "
  2154. "for pandas objects.")
  2155. # fit
  2156. estimator_1.fit(X_, y_)
  2157. pred1 = estimator_1.predict(X_)
  2158. estimator_2.fit(X, y)
  2159. pred2 = estimator_2.predict(X)
  2160. assert_allclose(pred1, pred2, atol=1e-2, err_msg=name)
  2161. def check_parameters_default_constructible(name, Estimator):
  2162. # this check works on classes, not instances
  2163. # test default-constructibility
  2164. # get rid of deprecation warnings
  2165. if isinstance(Estimator, BaseEstimator):
  2166. # Convert estimator instance to its class
  2167. # TODO: Always convert to class in 0.24, because check_estimator() will
  2168. # only accept instances, not classes
  2169. Estimator = Estimator.__class__
  2170. with ignore_warnings(category=FutureWarning):
  2171. estimator = _construct_instance(Estimator)
  2172. # test cloning
  2173. clone(estimator)
  2174. # test __repr__
  2175. repr(estimator)
  2176. # test that set_params returns self
  2177. assert estimator.set_params() is estimator
  2178. # test if init does nothing but set parameters
  2179. # this is important for grid_search etc.
  2180. # We get the default parameters from init and then
  2181. # compare these against the actual values of the attributes.
  2182. # this comes from getattr. Gets rid of deprecation decorator.
  2183. init = getattr(estimator.__init__, 'deprecated_original',
  2184. estimator.__init__)
  2185. try:
  2186. def param_filter(p):
  2187. """Identify hyper parameters of an estimator"""
  2188. return (p.name != 'self' and
  2189. p.kind != p.VAR_KEYWORD and
  2190. p.kind != p.VAR_POSITIONAL)
  2191. init_params = [p for p in signature(init).parameters.values()
  2192. if param_filter(p)]
  2193. except (TypeError, ValueError):
  2194. # init is not a python function.
  2195. # true for mixins
  2196. return
  2197. params = estimator.get_params()
# skip required parameters, which may need a non-default argument
  2199. init_params = init_params[len(getattr(
  2200. estimator, '_required_parameters', [])):]
  2201. for init_param in init_params:
  2202. assert init_param.default != init_param.empty, (
  2203. "parameter %s for %s has no default value"
  2204. % (init_param.name, type(estimator).__name__))
  2205. if type(init_param.default) is type:
  2206. assert init_param.default in [np.float64, np.int64]
  2207. else:
  2208. assert (type(init_param.default) in
  2209. [str, int, float, bool, tuple, type(None),
  2210. np.float64, types.FunctionType, joblib.Memory])
  2211. if init_param.name not in params.keys():
  2212. # deprecated parameter, not in get_params
  2213. assert init_param.default is None
  2214. continue
  2215. param_value = params[init_param.name]
  2216. if isinstance(param_value, np.ndarray):
  2217. assert_array_equal(param_value, init_param.default)
  2218. else:
  2219. if is_scalar_nan(param_value):
# Allows setting default parameters to np.nan
  2221. assert param_value is init_param.default, init_param.name
  2222. else:
  2223. assert param_value == init_param.default, init_param.name
  2224. # TODO: remove in 0.24
  2225. @deprecated("enforce_estimator_tags_y is deprecated in version "
  2226. "0.22 and will be removed in version 0.24.")
  2227. def enforce_estimator_tags_y(estimator, y):
  2228. return _enforce_estimator_tags_y(estimator, y)
  2229. def _enforce_estimator_tags_y(estimator, y):
  2230. # Estimators with a `requires_positive_y` tag only accept strictly positive
  2231. # data
  2232. if estimator._get_tags()["requires_positive_y"]:
  2233. # Create strictly positive y. The minimal increment above 0 is 1, as
  2234. # y could be of integer dtype.
  2235. y += 1 + abs(y.min())
# Estimators with a `multioutput_only` tag raise ValueError if y is 1-D.
# Convert y into a 2-D array for those estimators.
  2238. if estimator._get_tags()["multioutput_only"]:
  2239. return np.reshape(y, (-1, 1))
  2240. return y
  2241. def _enforce_estimator_tags_x(estimator, X):
  2242. # Estimators with a `_pairwise` tag only accept
  2243. # X of shape (`n_samples`, `n_samples`)
  2244. if hasattr(estimator, '_pairwise'):
  2245. X = X.dot(X.T)
  2246. # Estimators with `1darray` in `X_types` tag only accept
  2247. # X of shape (`n_samples`,)
  2248. if '1darray' in estimator._get_tags()['X_types']:
  2249. X = X[:, 0]
  2250. # Estimators with a `requires_positive_X` tag only accept
  2251. # strictly positive data
  2252. if estimator._get_tags()['requires_positive_X']:
  2253. X -= X.min()
  2254. return X
  2255. @ignore_warnings(category=FutureWarning)
  2256. def check_non_transformer_estimators_n_iter(name, estimator_orig):
# Test that estimators that are not transformers and that have a
# max_iter parameter report an n_iter_ attribute of at least 1.
# The models below depend on external solvers like libsvm, for which
# accessing the iteration count is non-trivial.
  2261. not_run_check_n_iter = ['Ridge', 'SVR', 'NuSVR', 'NuSVC',
  2262. 'RidgeClassifier', 'SVC', 'RandomizedLasso',
  2263. 'LogisticRegressionCV', 'LinearSVC',
  2264. 'LogisticRegression']
  2265. # Tested in test_transformer_n_iter
  2266. not_run_check_n_iter += CROSS_DECOMPOSITION
  2267. if name in not_run_check_n_iter:
  2268. return
# LassoLars stops early for the default alpha=1.0 on the iris dataset.
  2270. if name == 'LassoLars':
  2271. estimator = clone(estimator_orig).set_params(alpha=0.)
  2272. else:
  2273. estimator = clone(estimator_orig)
  2274. if hasattr(estimator, 'max_iter'):
  2275. iris = load_iris()
  2276. X, y_ = iris.data, iris.target
  2277. y_ = _enforce_estimator_tags_y(estimator, y_)
  2278. set_random_state(estimator, 0)
  2279. estimator.fit(X, y_)
  2280. assert estimator.n_iter_ >= 1
  2281. @ignore_warnings(category=FutureWarning)
  2282. def check_transformer_n_iter(name, estimator_orig):
# Test that transformers with a max_iter parameter report an
# n_iter_ attribute of at least 1.
  2285. estimator = clone(estimator_orig)
  2286. if hasattr(estimator, "max_iter"):
  2287. if name in CROSS_DECOMPOSITION:
  2288. # Check using default data
  2289. X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
  2290. y_ = [[0.1, -0.2], [0.9, 1.1], [0.1, -0.5], [0.3, -0.2]]
  2291. else:
  2292. X, y_ = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
  2293. random_state=0, n_features=2, cluster_std=0.1)
  2294. X -= X.min() - 0.1
  2295. set_random_state(estimator, 0)
  2296. estimator.fit(X, y_)
  2297. # These return a n_iter per component.
  2298. if name in CROSS_DECOMPOSITION:
  2299. for iter_ in estimator.n_iter_:
  2300. assert iter_ >= 1
  2301. else:
  2302. assert estimator.n_iter_ >= 1
  2303. @ignore_warnings(category=FutureWarning)
  2304. def check_get_params_invariance(name, estimator_orig):
  2305. # Checks if get_params(deep=False) is a subset of get_params(deep=True)
  2306. e = clone(estimator_orig)
  2307. shallow_params = e.get_params(deep=False)
  2308. deep_params = e.get_params(deep=True)
  2309. assert all(item in deep_params.items() for item in
  2310. shallow_params.items())
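# NOTE (illustrative): a sketch of the shallow/deep relation asserted
# above using a pipeline (an arbitrary choice), whose deep params include
# the nested estimators' parameters.
def _example_get_params_shallow_subset_of_deep():
    pipe = make_pipeline(StandardScaler(), Ridge())
    shallow = pipe.get_params(deep=False)
    deep = pipe.get_params(deep=True)
    assert all(item in deep.items() for item in shallow.items())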
  2311. @ignore_warnings(category=FutureWarning)
  2312. def check_set_params(name, estimator_orig):
  2313. # Check that get_params() returns the same thing
  2314. # before and after set_params() with some fuzz
  2315. estimator = clone(estimator_orig)
  2316. orig_params = estimator.get_params(deep=False)
  2317. msg = ("get_params result does not match what was passed to set_params")
  2318. estimator.set_params(**orig_params)
  2319. curr_params = estimator.get_params(deep=False)
  2320. assert set(orig_params.keys()) == set(curr_params.keys()), msg
  2321. for k, v in curr_params.items():
  2322. assert orig_params[k] is v, msg
  2323. # some fuzz values
  2324. test_values = [-np.inf, np.inf, None]
  2325. test_params = deepcopy(orig_params)
  2326. for param_name in orig_params.keys():
  2327. default_value = orig_params[param_name]
  2328. for value in test_values:
  2329. test_params[param_name] = value
  2330. try:
  2331. estimator.set_params(**test_params)
  2332. except (TypeError, ValueError) as e:
  2333. e_type = e.__class__.__name__
  2334. # Exception occurred, possibly parameter validation
  2335. warnings.warn("{0} occurred during set_params of param {1} on "
  2336. "{2}. It is recommended to delay parameter "
  2337. "validation until fit.".format(e_type,
  2338. param_name,
  2339. name))
  2340. change_warning_msg = "Estimator's parameters changed after " \
  2341. "set_params raised {}".format(e_type)
  2342. params_before_exception = curr_params
  2343. curr_params = estimator.get_params(deep=False)
  2344. try:
  2345. assert (set(params_before_exception.keys()) ==
  2346. set(curr_params.keys()))
  2347. for k, v in curr_params.items():
  2348. assert params_before_exception[k] is v
  2349. except AssertionError:
  2350. warnings.warn(change_warning_msg)
  2351. else:
  2352. curr_params = estimator.get_params(deep=False)
  2353. assert (set(test_params.keys()) ==
  2354. set(curr_params.keys())), msg
  2355. for k, v in curr_params.items():
  2356. assert test_params[k] is v, msg
  2357. test_params[param_name] = default_value
  2358. @ignore_warnings(category=FutureWarning)
  2359. def check_classifiers_regression_target(name, estimator_orig):
  2360. # Check if classifier throws an exception when fed regression targets
  2361. X, y = load_boston(return_X_y=True)
  2362. e = clone(estimator_orig)
  2363. msg = 'Unknown label type: '
  2364. if not e._get_tags()["no_validation"]:
  2365. assert_raises_regex(ValueError, msg, e.fit, X, y)
  2366. @ignore_warnings(category=FutureWarning)
  2367. def check_decision_proba_consistency(name, estimator_orig):
  2368. # Check whether an estimator having both decision_function and
  2369. # predict_proba methods has outputs with perfect rank correlation.
  2370. centers = [(2, 2), (4, 4)]
X, y = make_blobs(n_samples=100, random_state=0, n_features=2,
  2372. centers=centers, cluster_std=1.0, shuffle=True)
  2373. X_test = np.random.randn(20, 2) + 4
  2374. estimator = clone(estimator_orig)
  2375. if (hasattr(estimator, "decision_function") and
  2376. hasattr(estimator, "predict_proba")):
  2377. estimator.fit(X, y)
  2378. # Since the link function from decision_function() to predict_proba()
  2379. # is sometimes not precise enough (typically expit), we round to the
  2380. # 10th decimal to avoid numerical issues.
  2381. a = estimator.predict_proba(X_test)[:, 1].round(decimals=10)
  2382. b = estimator.decision_function(X_test).round(decimals=10)
  2383. assert_array_equal(rankdata(a), rankdata(b))
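# NOTE (illustrative sketch of the rank-correlation property tested
# above): for a logistic model (an arbitrary choice), predict_proba is a
# monotone transform (expit) of decision_function on overlapping blobs,
# so the rankings must agree.
def _example_proba_decision_rank_agreement():
    from sklearn.linear_model import LogisticRegression
    X, y = make_blobs(n_samples=100, centers=[[0, 0], [1, 1]],
                      random_state=0)
    clf = LogisticRegression(random_state=0).fit(X, y)
    assert_array_equal(rankdata(clf.predict_proba(X)[:, 1]),
                       rankdata(clf.decision_function(X)))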
  2384. def check_outliers_fit_predict(name, estimator_orig):
  2385. # Check fit_predict for outlier detectors.
  2386. n_samples = 300
  2387. X, _ = make_blobs(n_samples=n_samples, random_state=0)
  2388. X = shuffle(X, random_state=7)
  2389. n_samples, n_features = X.shape
  2390. estimator = clone(estimator_orig)
  2391. set_random_state(estimator)
  2392. y_pred = estimator.fit_predict(X)
  2393. assert y_pred.shape == (n_samples,)
  2394. assert y_pred.dtype.kind == 'i'
  2395. assert_array_equal(np.unique(y_pred), np.array([-1, 1]))
  2396. # check fit_predict = fit.predict when the estimator has both a predict and
  2397. # a fit_predict method. recall that it is already assumed here that the
  2398. # estimator has a fit_predict method
  2399. if hasattr(estimator, 'predict'):
  2400. y_pred_2 = estimator.fit(X).predict(X)
  2401. assert_array_equal(y_pred, y_pred_2)
  2402. if hasattr(estimator, "contamination"):
  2403. # proportion of outliers equal to contamination parameter when not
  2404. # set to 'auto'
  2405. expected_outliers = 30
contamination = expected_outliers / n_samples
  2407. estimator.set_params(contamination=contamination)
  2408. y_pred = estimator.fit_predict(X)
  2409. num_outliers = np.sum(y_pred != 1)
  2410. # num_outliers should be equal to expected_outliers unless
  2411. # there are ties in the decision_function values. this can
  2412. # only be tested for estimators with a decision_function
  2413. # method
  2414. if (num_outliers != expected_outliers and
  2415. hasattr(estimator, 'decision_function')):
  2416. decision = estimator.decision_function(X)
  2417. check_outlier_corruption(num_outliers, expected_outliers, decision)
  2418. # raises error when contamination is a scalar and not in [0,1]
  2419. for contamination in [-0.5, 2.3]:
  2420. estimator.set_params(contamination=contamination)
  2421. assert_raises(ValueError, estimator.fit_predict, X)
  2422. def check_fit_non_negative(name, estimator_orig):
# Check that a proper error is raised when X contains negative values
# and the requires_positive_X tag is set
  2425. X = np.array([[-1., 1], [-1., 1]])
  2426. y = np.array([1, 2])
  2427. estimator = clone(estimator_orig)
  2428. assert_raises_regex(ValueError, "Negative values in data passed to",
  2429. estimator.fit, X, y)
  2430. def check_fit_idempotent(name, estimator_orig):
  2431. # Check that est.fit(X) is the same as est.fit(X).fit(X). Ideally we would
  2432. # check that the estimated parameters during training (e.g. coefs_) are
  2433. # the same, but having a universal comparison function for those
  2434. # attributes is difficult and full of edge cases. So instead we check that
  2435. # predict(), predict_proba(), decision_function() and transform() return
  2436. # the same results.
  2437. check_methods = ["predict", "transform", "decision_function",
  2438. "predict_proba"]
  2439. rng = np.random.RandomState(0)
  2440. estimator = clone(estimator_orig)
  2441. set_random_state(estimator)
  2442. if 'warm_start' in estimator.get_params().keys():
  2443. estimator.set_params(warm_start=False)
  2444. n_samples = 100
  2445. X = rng.normal(loc=100, size=(n_samples, 2))
  2446. X = _pairwise_estimator_convert_X(X, estimator)
  2447. if is_regressor(estimator_orig):
  2448. y = rng.normal(size=n_samples)
  2449. else:
  2450. y = rng.randint(low=0, high=2, size=n_samples)
  2451. y = _enforce_estimator_tags_y(estimator, y)
  2452. train, test = next(ShuffleSplit(test_size=.2, random_state=rng).split(X))
  2453. X_train, y_train = _safe_split(estimator, X, y, train)
  2454. X_test, y_test = _safe_split(estimator, X, y, test, train)
  2455. # Fit for the first time
  2456. estimator.fit(X_train, y_train)
  2457. result = {method: getattr(estimator, method)(X_test)
  2458. for method in check_methods
  2459. if hasattr(estimator, method)}
  2460. # Fit again
  2461. set_random_state(estimator)
  2462. estimator.fit(X_train, y_train)
  2463. for method in check_methods:
  2464. if hasattr(estimator, method):
  2465. new_result = getattr(estimator, method)(X_test)
  2466. if np.issubdtype(new_result.dtype, np.floating):
  2467. tol = 2*np.finfo(new_result.dtype).eps
  2468. else:
  2469. tol = 2*np.finfo(np.float64).eps
  2470. assert_allclose_dense_sparse(
  2471. result[method], new_result,
  2472. atol=max(tol, 1e-9), rtol=max(tol, 1e-7),
  2473. err_msg="Idempotency check failed for method {}".format(method)
  2474. )
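# NOTE (minimal illustrative sketch of the idempotency property above,
# with a plain deterministic linear model): refitting on identical data
# must not change predictions.
def _example_fit_idempotent():
    rng = np.random.RandomState(0)
    X, y = rng.normal(size=(100, 2)), rng.normal(size=100)
    est = Ridge().fit(X, y)
    pred_first = est.predict(X)
    pred_refit = est.fit(X, y).predict(X)
    assert_allclose(pred_first, pred_refit)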
  2475. def check_n_features_in(name, estimator_orig):
  2476. # Make sure that n_features_in_ attribute doesn't exist until fit is
  2477. # called, and that its value is correct.
  2478. rng = np.random.RandomState(0)
  2479. estimator = clone(estimator_orig)
  2480. set_random_state(estimator)
  2481. if 'warm_start' in estimator.get_params():
  2482. estimator.set_params(warm_start=False)
  2483. n_samples = 100
  2484. X = rng.normal(loc=100, size=(n_samples, 2))
  2485. X = _pairwise_estimator_convert_X(X, estimator)
  2486. if is_regressor(estimator_orig):
  2487. y = rng.normal(size=n_samples)
  2488. else:
  2489. y = rng.randint(low=0, high=2, size=n_samples)
  2490. y = _enforce_estimator_tags_y(estimator, y)
  2491. assert not hasattr(estimator, 'n_features_in_')
  2492. estimator.fit(X, y)
  2493. if hasattr(estimator, 'n_features_in_'):
  2494. assert estimator.n_features_in_ == X.shape[1]
  2495. else:
  2496. warnings.warn(
  2497. "As of scikit-learn 0.23, estimators should expose a "
  2498. "n_features_in_ attribute, unless the 'no_validation' tag is "
  2499. "True. This attribute should be equal to the number of features "
  2500. "passed to the fit method. "
  2501. "An error will be raised from version 0.25 when calling "
  2502. "check_estimator(). "
  2503. "See SLEP010: "
  2504. "https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep010/proposal.html", # noqa
  2505. FutureWarning
  2506. )
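# NOTE (sketch of the attribute contract above, assuming an estimator
# updated for scikit-learn 0.23, here Ridge): n_features_in_ appears only
# after fit and equals the training feature count.
def _example_n_features_in_after_fit():
    rng = np.random.RandomState(0)
    X, y = rng.normal(size=(20, 4)), rng.normal(size=20)
    est = Ridge()
    assert not hasattr(est, 'n_features_in_')
    assert est.fit(X, y).n_features_in_ == 4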
  2507. def check_requires_y_none(name, estimator_orig):
  2508. # Make sure that an estimator with requires_y=True fails gracefully when
  2509. # given y=None
  2510. rng = np.random.RandomState(0)
  2511. estimator = clone(estimator_orig)
  2512. set_random_state(estimator)
  2513. n_samples = 100
  2514. X = rng.normal(loc=100, size=(n_samples, 2))
  2515. X = _pairwise_estimator_convert_X(X, estimator)
  2516. warning_msg = ("As of scikit-learn 0.23, estimators should have a "
  2517. "'requires_y' tag set to the appropriate value. "
  2518. "The default value of the tag is False. "
  2519. "An error will be raised from version 0.25 when calling "
  2520. "check_estimator() if the tag isn't properly set.")
  2521. expected_err_msgs = (
  2522. "requires y to be passed, but the target y is None",
  2523. "Expected array-like (array or non-string sequence), got None",
  2524. "y should be a 1d array"
  2525. )
  2526. try:
  2527. estimator.fit(X, None)
  2528. except ValueError as ve:
  2529. if not any(msg in str(ve) for msg in expected_err_msgs):
  2530. warnings.warn(warning_msg, FutureWarning)