PageRenderTime 52ms CodeModel.GetById 15ms RepoModel.GetById 0ms app.codeStats 1ms

/sklearn/utils/estimator_checks.py

http://github.com/scikit-learn/scikit-learn
Python | 3047 lines | 2528 code | 259 blank | 260 comment | 227 complexity | 75f5c0f5bd11b82772c418f6bdfe02bc MD5 | raw file
Possible License(s): BSD-3-Clause

Large files are truncated, but you can click here to view the full file

  1. import types
  2. import warnings
  3. import sys
  4. import traceback
  5. import pickle
  6. import re
  7. from copy import deepcopy
  8. from functools import partial
  9. from itertools import chain
  10. from inspect import signature
  11. import numpy as np
  12. from scipy import sparse
  13. from scipy.stats import rankdata
  14. import joblib
  15. from . import IS_PYPY
  16. from .. import config_context
  17. from ._testing import assert_raises, _get_args
  18. from ._testing import assert_raises_regex
  19. from ._testing import assert_raise_message
  20. from ._testing import assert_array_equal
  21. from ._testing import assert_array_almost_equal
  22. from ._testing import assert_allclose
  23. from ._testing import assert_allclose_dense_sparse
  24. from ._testing import assert_warns_message
  25. from ._testing import set_random_state
  26. from ._testing import SkipTest
  27. from ._testing import ignore_warnings
  28. from ._testing import create_memmap_backed_data
  29. from . import is_scalar_nan
  30. from ..discriminant_analysis import LinearDiscriminantAnalysis
  31. from ..linear_model import Ridge
  32. from ..base import (clone, ClusterMixin, is_classifier, is_regressor,
  33. RegressorMixin, is_outlier_detector, BaseEstimator)
  34. from ..metrics import accuracy_score, adjusted_rand_score, f1_score
  35. from ..random_projection import BaseRandomProjection
  36. from ..feature_selection import SelectKBest
  37. from ..pipeline import make_pipeline
  38. from ..exceptions import DataConversionWarning
  39. from ..exceptions import NotFittedError
  40. from ..exceptions import SkipTestWarning
  41. from ..model_selection import train_test_split
  42. from ..model_selection import ShuffleSplit
  43. from ..model_selection._validation import _safe_split
  44. from ..metrics.pairwise import (rbf_kernel, linear_kernel, pairwise_distances)
  45. from .import shuffle
  46. from .import deprecated
  47. from .validation import has_fit_parameter, _num_samples
  48. from ..preprocessing import StandardScaler
  49. from ..datasets import (load_iris, load_boston, make_blobs,
  50. make_multilabel_classification, make_regression)
# Module-level cache filled on first use by _boston_subset() with an (X, y)
# tuple; stays None until then.
BOSTON = None
# Names of the cross-decomposition estimators; _yield_checks does not yield
# check_pipeline_consistency for them.
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
  53. def _yield_checks(name, estimator):
  54. tags = estimator._get_tags()
  55. yield check_no_attributes_set_in_init
  56. yield check_estimators_dtypes
  57. yield check_fit_score_takes_y
  58. yield check_sample_weights_pandas_series
  59. yield check_sample_weights_not_an_array
  60. yield check_sample_weights_list
  61. yield check_sample_weights_shape
  62. yield check_sample_weights_invariance
  63. yield check_estimators_fit_returns_self
  64. yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
  65. # Check that all estimator yield informative messages when
  66. # trained on empty datasets
  67. if not tags["no_validation"]:
  68. yield check_complex_data
  69. yield check_dtype_object
  70. yield check_estimators_empty_data_messages
  71. if name not in CROSS_DECOMPOSITION:
  72. # cross-decomposition's "transform" returns X and Y
  73. yield check_pipeline_consistency
  74. if not tags["allow_nan"] and not tags["no_validation"]:
  75. # Test that all estimators check their input for NaN's and infs
  76. yield check_estimators_nan_inf
  77. if _is_pairwise(estimator):
  78. # Check that pairwise estimator throws error on non-square input
  79. yield check_nonsquare_error
  80. yield check_estimators_overwrite_params
  81. if hasattr(estimator, 'sparsify'):
  82. yield check_sparsify_coefficients
  83. yield check_estimator_sparse_data
  84. # Test that estimators can be pickled, and once pickled
  85. # give the same answer as before.
  86. yield check_estimators_pickle
  87. def _yield_classifier_checks(name, classifier):
  88. tags = classifier._get_tags()
  89. # test classifiers can handle non-array data and pandas objects
  90. yield check_classifier_data_not_an_array
  91. # test classifiers trained on a single label always return this label
  92. yield check_classifiers_one_label
  93. yield check_classifiers_classes
  94. yield check_estimators_partial_fit_n_features
  95. if tags["multioutput"]:
  96. yield check_classifier_multioutput
  97. # basic consistency testing
  98. yield check_classifiers_train
  99. yield partial(check_classifiers_train, readonly_memmap=True)
  100. yield partial(check_classifiers_train, readonly_memmap=True,
  101. X_dtype='float32')
  102. yield check_classifiers_regression_target
  103. if tags["multilabel"]:
  104. yield check_classifiers_multilabel_representation_invariance
  105. if not tags["no_validation"]:
  106. yield check_supervised_y_no_nan
  107. yield check_supervised_y_2d
  108. if tags["requires_fit"]:
  109. yield check_estimators_unfitted
  110. if 'class_weight' in classifier.get_params().keys():
  111. yield check_class_weight_classifiers
  112. yield check_non_transformer_estimators_n_iter
  113. # test if predict_proba is a monotonic transformation of decision_function
  114. yield check_decision_proba_consistency
  115. @ignore_warnings(category=FutureWarning)
  116. def check_supervised_y_no_nan(name, estimator_orig):
  117. # Checks that the Estimator targets are not NaN.
  118. estimator = clone(estimator_orig)
  119. rng = np.random.RandomState(888)
  120. X = rng.randn(10, 5)
  121. y = np.full(10, np.inf)
  122. y = _enforce_estimator_tags_y(estimator, y)
  123. errmsg = "Input contains NaN, infinity or a value too large for " \
  124. "dtype('float64')."
  125. try:
  126. estimator.fit(X, y)
  127. except ValueError as e:
  128. if str(e) != errmsg:
  129. raise ValueError("Estimator {0} raised error as expected, but "
  130. "does not match expected error message"
  131. .format(name))
  132. else:
  133. raise ValueError("Estimator {0} should have raised error on fitting "
  134. "array y with NaN value.".format(name))
  135. def _yield_regressor_checks(name, regressor):
  136. tags = regressor._get_tags()
  137. # TODO: test with intercept
  138. # TODO: test with multiple responses
  139. # basic testing
  140. yield check_regressors_train
  141. yield partial(check_regressors_train, readonly_memmap=True)
  142. yield partial(check_regressors_train, readonly_memmap=True,
  143. X_dtype='float32')
  144. yield check_regressor_data_not_an_array
  145. yield check_estimators_partial_fit_n_features
  146. if tags["multioutput"]:
  147. yield check_regressor_multioutput
  148. yield check_regressors_no_decision_function
  149. if not tags["no_validation"]:
  150. yield check_supervised_y_2d
  151. yield check_supervised_y_no_nan
  152. if name != 'CCA':
  153. # check that the regressor handles int input
  154. yield check_regressors_int
  155. if tags["requires_fit"]:
  156. yield check_estimators_unfitted
  157. yield check_non_transformer_estimators_n_iter
  158. def _yield_transformer_checks(name, transformer):
  159. # All transformers should either deal with sparse data or raise an
  160. # exception with type TypeError and an intelligible error message
  161. if not transformer._get_tags()["no_validation"]:
  162. yield check_transformer_data_not_an_array
  163. # these don't actually fit the data, so don't raise errors
  164. yield check_transformer_general
  165. yield partial(check_transformer_general, readonly_memmap=True)
  166. if not transformer._get_tags()["stateless"]:
  167. yield check_transformers_unfitted
  168. # Dependent on external solvers and hence accessing the iter
  169. # param is non-trivial.
  170. external_solver = ['Isomap', 'KernelPCA', 'LocallyLinearEmbedding',
  171. 'RandomizedLasso', 'LogisticRegressionCV']
  172. if name not in external_solver:
  173. yield check_transformer_n_iter
  174. def _yield_clustering_checks(name, clusterer):
  175. yield check_clusterer_compute_labels_predict
  176. if name not in ('WardAgglomeration', "FeatureAgglomeration"):
  177. # this is clustering on the features
  178. # let's not test that here.
  179. yield check_clustering
  180. yield partial(check_clustering, readonly_memmap=True)
  181. yield check_estimators_partial_fit_n_features
  182. yield check_non_transformer_estimators_n_iter
  183. def _yield_outliers_checks(name, estimator):
  184. # checks for outlier detectors that have a fit_predict method
  185. if hasattr(estimator, 'fit_predict'):
  186. yield check_outliers_fit_predict
  187. # checks for estimators that can be used on a test set
  188. if hasattr(estimator, 'predict'):
  189. yield check_outliers_train
  190. yield partial(check_outliers_train, readonly_memmap=True)
  191. # test outlier detectors can handle non-array data
  192. yield check_classifier_data_not_an_array
  193. # test if NotFittedError is raised
  194. if estimator._get_tags()["requires_fit"]:
  195. yield check_estimators_unfitted
  196. def _yield_all_checks(name, estimator):
  197. tags = estimator._get_tags()
  198. if "2darray" not in tags["X_types"]:
  199. warnings.warn("Can't test estimator {} which requires input "
  200. " of type {}".format(name, tags["X_types"]),
  201. SkipTestWarning)
  202. return
  203. if tags["_skip_test"]:
  204. warnings.warn("Explicit SKIP via _skip_test tag for estimator "
  205. "{}.".format(name),
  206. SkipTestWarning)
  207. return
  208. for check in _yield_checks(name, estimator):
  209. yield check
  210. if is_classifier(estimator):
  211. for check in _yield_classifier_checks(name, estimator):
  212. yield check
  213. if is_regressor(estimator):
  214. for check in _yield_regressor_checks(name, estimator):
  215. yield check
  216. if hasattr(estimator, 'transform'):
  217. for check in _yield_transformer_checks(name, estimator):
  218. yield check
  219. if isinstance(estimator, ClusterMixin):
  220. for check in _yield_clustering_checks(name, estimator):
  221. yield check
  222. if is_outlier_detector(estimator):
  223. for check in _yield_outliers_checks(name, estimator):
  224. yield check
  225. yield check_fit2d_predict1d
  226. yield check_methods_subset_invariance
  227. yield check_fit2d_1sample
  228. yield check_fit2d_1feature
  229. yield check_fit1d
  230. yield check_get_params_invariance
  231. yield check_set_params
  232. yield check_dict_unchanged
  233. yield check_dont_overwrite_parameters
  234. yield check_fit_idempotent
  235. if not tags["no_validation"]:
  236. yield check_n_features_in
  237. if tags["requires_y"]:
  238. yield check_requires_y_none
  239. if tags["requires_positive_X"]:
  240. yield check_fit_non_negative
  241. def _set_check_estimator_ids(obj):
  242. """Create pytest ids for checks.
  243. When `obj` is an estimator, this returns the pprint version of the
  244. estimator (with `print_changed_only=True`). When `obj` is a function, the
  245. name of the function is returned with its keyworld arguments.
  246. `_set_check_estimator_ids` is designed to be used as the `id` in
  247. `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)`
  248. is yielding estimators and checks.
  249. Parameters
  250. ----------
  251. obj : estimator or function
  252. Items generated by `check_estimator`
  253. Returns
  254. -------
  255. id : string or None
  256. See also
  257. --------
  258. check_estimator
  259. """
  260. if callable(obj):
  261. if not isinstance(obj, partial):
  262. return obj.__name__
  263. if not obj.keywords:
  264. return obj.func.__name__
  265. kwstring = ",".join(["{}={}".format(k, v)
  266. for k, v in obj.keywords.items()])
  267. return "{}({})".format(obj.func.__name__, kwstring)
  268. if hasattr(obj, "get_params"):
  269. with config_context(print_changed_only=True):
  270. return re.sub(r"\s", "", str(obj))
  271. def _construct_instance(Estimator):
  272. """Construct Estimator instance if possible"""
  273. required_parameters = getattr(Estimator, "_required_parameters", [])
  274. if len(required_parameters):
  275. if required_parameters in (["estimator"], ["base_estimator"]):
  276. if issubclass(Estimator, RegressorMixin):
  277. estimator = Estimator(Ridge())
  278. else:
  279. estimator = Estimator(LinearDiscriminantAnalysis())
  280. else:
  281. raise SkipTest("Can't instantiate estimator {} which requires "
  282. "parameters {}".format(Estimator.__name__,
  283. required_parameters))
  284. else:
  285. estimator = Estimator()
  286. return estimator
  287. # TODO: probably not needed anymore in 0.24 since _generate_class_checks should
  288. # be removed too. Just put this in check_estimator()
  289. def _generate_instance_checks(name, estimator):
  290. """Generate instance checks."""
  291. yield from ((estimator, partial(check, name))
  292. for check in _yield_all_checks(name, estimator))
  293. # TODO: remove this in 0.24
  294. def _generate_class_checks(Estimator):
  295. """Generate class checks."""
  296. name = Estimator.__name__
  297. yield (Estimator, partial(check_parameters_default_constructible, name))
  298. estimator = _construct_instance(Estimator)
  299. yield from _generate_instance_checks(name, estimator)
  300. def _mark_xfail_checks(estimator, check, pytest):
  301. """Mark (estimator, check) pairs with xfail according to the
  302. _xfail_checks_ tag"""
  303. if isinstance(estimator, type):
  304. # try to construct estimator instance, if it is unable to then
  305. # return the estimator class, ignoring the tag
  306. # TODO: remove this if block in 0.24 since passing instances isn't
  307. # supported anymore
  308. try:
  309. estimator = _construct_instance(estimator)
  310. except Exception:
  311. return estimator, check
  312. xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
  313. check_name = _set_check_estimator_ids(check)
  314. if check_name not in xfail_checks:
  315. # check isn't part of the xfail_checks tags, just return it
  316. return estimator, check
  317. else:
  318. # check is in the tag, mark it as xfail for pytest
  319. reason = xfail_checks[check_name]
  320. return pytest.param(estimator, check,
  321. marks=pytest.mark.xfail(reason=reason))
  322. def parametrize_with_checks(estimators):
  323. """Pytest specific decorator for parametrizing estimator checks.
  324. The `id` of each check is set to be a pprint version of the estimator
  325. and the name of the check with its keyword arguments.
  326. This allows to use `pytest -k` to specify which tests to run::
  327. pytest test_check_estimators.py -k check_estimators_fit_returns_self
  328. Parameters
  329. ----------
  330. estimators : list of estimators objects or classes
  331. Estimators to generated checks for.
  332. .. deprecated:: 0.23
  333. Passing a class is deprecated from version 0.23, and won't be
  334. supported in 0.24. Pass an instance instead.
  335. Returns
  336. -------
  337. decorator : `pytest.mark.parametrize`
  338. Examples
  339. --------
  340. >>> from sklearn.utils.estimator_checks import parametrize_with_checks
  341. >>> from sklearn.linear_model import LogisticRegression
  342. >>> from sklearn.tree import DecisionTreeRegressor
  343. >>> @parametrize_with_checks([LogisticRegression(),
  344. ... DecisionTreeRegressor()])
  345. ... def test_sklearn_compatible_estimator(estimator, check):
  346. ... check(estimator)
  347. """
  348. import pytest
  349. if any(isinstance(est, type) for est in estimators):
  350. # TODO: remove class support in 0.24 and update docstrings
  351. msg = ("Passing a class is deprecated since version 0.23 "
  352. "and won't be supported in 0.24."
  353. "Please pass an instance instead.")
  354. warnings.warn(msg, FutureWarning)
  355. checks_generator = chain.from_iterable(
  356. check_estimator(estimator, generate_only=True)
  357. for estimator in estimators)
  358. checks_with_marks = (
  359. _mark_xfail_checks(estimator, check, pytest)
  360. for estimator, check in checks_generator)
  361. return pytest.mark.parametrize("estimator, check", checks_with_marks,
  362. ids=_set_check_estimator_ids)
  363. def check_estimator(Estimator, generate_only=False):
  364. """Check if estimator adheres to scikit-learn conventions.
  365. This estimator will run an extensive test-suite for input validation,
  366. shapes, etc, making sure that the estimator complies with `scikit-learn`
  367. conventions as detailed in :ref:`rolling_your_own_estimator`.
  368. Additional tests for classifiers, regressors, clustering or transformers
  369. will be run if the Estimator class inherits from the corresponding mixin
  370. from sklearn.base.
  371. This test can be applied to classes or instances.
  372. Classes currently have some additional tests that related to construction,
  373. while passing instances allows the testing of multiple options. However,
  374. support for classes is deprecated since version 0.23 and will be removed
  375. in version 0.24 (class checks will still be run on the instances).
  376. Setting `generate_only=True` returns a generator that yields (estimator,
  377. check) tuples where the check can be called independently from each
  378. other, i.e. `check(estimator)`. This allows all checks to be run
  379. independently and report the checks that are failing.
  380. scikit-learn provides a pytest specific decorator,
  381. :func:`~sklearn.utils.parametrize_with_checks`, making it easier to test
  382. multiple estimators.
  383. Parameters
  384. ----------
  385. estimator : estimator object
  386. Estimator to check. Estimator is a class object or instance.
  387. .. deprecated:: 0.23
  388. Passing a class is deprecated from version 0.23, and won't be
  389. supported in 0.24. Pass an instance instead.
  390. generate_only : bool, optional (default=False)
  391. When `False`, checks are evaluated when `check_estimator` is called.
  392. When `True`, `check_estimator` returns a generator that yields
  393. (estimator, check) tuples. The check is run by calling
  394. `check(estimator)`.
  395. .. versionadded:: 0.22
  396. Returns
  397. -------
  398. checks_generator : generator
  399. Generator that yields (estimator, check) tuples. Returned when
  400. `generate_only=True`.
  401. """
  402. # TODO: remove class support in 0.24 and update docstrings
  403. if isinstance(Estimator, type):
  404. # got a class
  405. msg = ("Passing a class is deprecated since version 0.23 "
  406. "and won't be supported in 0.24."
  407. "Please pass an instance instead.")
  408. warnings.warn(msg, FutureWarning)
  409. checks_generator = _generate_class_checks(Estimator)
  410. else:
  411. # got an instance
  412. estimator = Estimator
  413. name = type(estimator).__name__
  414. checks_generator = _generate_instance_checks(name, estimator)
  415. if generate_only:
  416. return checks_generator
  417. for estimator, check in checks_generator:
  418. try:
  419. check(estimator)
  420. except SkipTest as exception:
  421. # the only SkipTest thrown currently results from not
  422. # being able to import pandas.
  423. warnings.warn(str(exception), SkipTestWarning)
  424. def _boston_subset(n_samples=200):
  425. global BOSTON
  426. if BOSTON is None:
  427. X, y = load_boston(return_X_y=True)
  428. X, y = shuffle(X, y, random_state=0)
  429. X, y = X[:n_samples], y[:n_samples]
  430. X = StandardScaler().fit_transform(X)
  431. BOSTON = X, y
  432. return BOSTON
# Deprecated public name: it only forwards to the private
# _set_checking_parameters, which modifies `estimator` in place.
@deprecated("set_checking_parameters is deprecated in version "
            "0.22 and will be removed in version 0.24.")
def set_checking_parameters(estimator):
    # Nothing is returned; the estimator is updated in place.
    _set_checking_parameters(estimator)
  437. def _set_checking_parameters(estimator):
  438. # set parameters to speed up some estimators and
  439. # avoid deprecated behaviour
  440. params = estimator.get_params()
  441. name = estimator.__class__.__name__
  442. if ("n_iter" in params and name != "TSNE"):
  443. estimator.set_params(n_iter=5)
  444. if "max_iter" in params:
  445. if estimator.max_iter is not None:
  446. estimator.set_params(max_iter=min(5, estimator.max_iter))
  447. # LinearSVR, LinearSVC
  448. if estimator.__class__.__name__ in ['LinearSVR', 'LinearSVC']:
  449. estimator.set_params(max_iter=20)
  450. # NMF
  451. if estimator.__class__.__name__ == 'NMF':
  452. estimator.set_params(max_iter=100)
  453. # MLP
  454. if estimator.__class__.__name__ in ['MLPClassifier', 'MLPRegressor']:
  455. estimator.set_params(max_iter=100)
  456. if "n_resampling" in params:
  457. # randomized lasso
  458. estimator.set_params(n_resampling=5)
  459. if "n_estimators" in params:
  460. estimator.set_params(n_estimators=min(5, estimator.n_estimators))
  461. if "max_trials" in params:
  462. # RANSAC
  463. estimator.set_params(max_trials=10)
  464. if "n_init" in params:
  465. # K-Means
  466. estimator.set_params(n_init=2)
  467. if name == 'TruncatedSVD':
  468. # TruncatedSVD doesn't run with n_components = n_features
  469. # This is ugly :-/
  470. estimator.n_components = 1
  471. if hasattr(estimator, "n_clusters"):
  472. estimator.n_clusters = min(estimator.n_clusters, 2)
  473. if hasattr(estimator, "n_best"):
  474. estimator.n_best = 1
  475. if name == "SelectFdr":
  476. # be tolerant of noisy datasets (not actually speed)
  477. estimator.set_params(alpha=.5)
  478. if name == "TheilSenRegressor":
  479. estimator.max_subpopulation = 100
  480. if isinstance(estimator, BaseRandomProjection):
  481. # Due to the jl lemma and often very few samples, the number
  482. # of components of the random matrix projection will be probably
  483. # greater than the number of features.
  484. # So we impose a smaller number (avoid "auto" mode)
  485. estimator.set_params(n_components=2)
  486. if isinstance(estimator, SelectKBest):
  487. # SelectKBest has a default of k=10
  488. # which is more feature than we have in most case.
  489. estimator.set_params(k=1)
  490. if name in ('HistGradientBoostingClassifier',
  491. 'HistGradientBoostingRegressor'):
  492. # The default min_samples_leaf (20) isn't appropriate for small
  493. # datasets (only very shallow trees are built) that the checks use.
  494. estimator.set_params(min_samples_leaf=5)
  495. # Speed-up by reducing the number of CV or splits for CV estimators
  496. loo_cv = ['RidgeCV']
  497. if name not in loo_cv and hasattr(estimator, 'cv'):
  498. estimator.set_params(cv=3)
  499. if hasattr(estimator, 'n_splits'):
  500. estimator.set_params(n_splits=3)
  501. if name == 'OneHotEncoder':
  502. estimator.set_params(handle_unknown='ignore')
  503. class _NotAnArray:
  504. """An object that is convertible to an array
  505. Parameters
  506. ----------
  507. data : array_like
  508. The data.
  509. """
  510. def __init__(self, data):
  511. self.data = np.asarray(data)
  512. def __array__(self, dtype=None):
  513. return self.data
  514. def __array_function__(self, func, types, args, kwargs):
  515. if func.__name__ == "may_share_memory":
  516. return True
  517. raise TypeError("Don't want to call array_function {}!".format(
  518. func.__name__))
# Deprecated public alias of _NotAnArray; it adds no behaviour of its own.
@deprecated("NotAnArray is deprecated in version "
            "0.22 and will be removed in version 0.24.")
class NotAnArray(_NotAnArray):
    # TODO: remove in 0.24
    pass
  524. def _is_pairwise(estimator):
  525. """Returns True if estimator has a _pairwise attribute set to True.
  526. Parameters
  527. ----------
  528. estimator : object
  529. Estimator object to test.
  530. Returns
  531. -------
  532. out : bool
  533. True if _pairwise is set to True and False otherwise.
  534. """
  535. return bool(getattr(estimator, "_pairwise", False))
  536. def _is_pairwise_metric(estimator):
  537. """Returns True if estimator accepts pairwise metric.
  538. Parameters
  539. ----------
  540. estimator : object
  541. Estimator object to test.
  542. Returns
  543. -------
  544. out : bool
  545. True if _pairwise is set to True and False otherwise.
  546. """
  547. metric = getattr(estimator, "metric", None)
  548. return bool(metric == 'precomputed')
# Deprecated public name: it only forwards its arguments to the private
# _pairwise_estimator_convert_X and returns that function's result.
@deprecated("pairwise_estimator_convert_X is deprecated in version "
            "0.22 and will be removed in version 0.24.")
def pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel):
    return _pairwise_estimator_convert_X(X, estimator, kernel)
  553. def _pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel):
  554. if _is_pairwise_metric(estimator):
  555. return pairwise_distances(X, metric='euclidean')
  556. if _is_pairwise(estimator):
  557. return kernel(X, X)
  558. return X
  559. def _generate_sparse_matrix(X_csr):
  560. """Generate sparse matrices with {32,64}bit indices of diverse format
  561. Parameters
  562. ----------
  563. X_csr: CSR Matrix
  564. Input matrix in CSR format
  565. Returns
  566. -------
  567. out: iter(Matrices)
  568. In format['dok', 'lil', 'dia', 'bsr', 'csr', 'csc', 'coo',
  569. 'coo_64', 'csc_64', 'csr_64']
  570. """
  571. assert X_csr.format == 'csr'
  572. yield 'csr', X_csr.copy()
  573. for sparse_format in ['dok', 'lil', 'dia', 'bsr', 'csc', 'coo']:
  574. yield sparse_format, X_csr.asformat(sparse_format)
  575. # Generate large indices matrix only if its supported by scipy
  576. X_coo = X_csr.asformat('coo')
  577. X_coo.row = X_coo.row.astype('int64')
  578. X_coo.col = X_coo.col.astype('int64')
  579. yield "coo_64", X_coo
  580. for sparse_format in ['csc', 'csr']:
  581. X = X_csr.asformat(sparse_format)
  582. X.indices = X.indices.astype('int64')
  583. X.indptr = X.indptr.astype('int64')
  584. yield sparse_format + "_64", X
  585. def check_estimator_sparse_data(name, estimator_orig):
  586. rng = np.random.RandomState(0)
  587. X = rng.rand(40, 10)
  588. X[X < .8] = 0
  589. X = _pairwise_estimator_convert_X(X, estimator_orig)
  590. X_csr = sparse.csr_matrix(X)
  591. tags = estimator_orig._get_tags()
  592. if tags['binary_only']:
  593. y = (2 * rng.rand(40)).astype(np.int)
  594. else:
  595. y = (4 * rng.rand(40)).astype(np.int)
  596. # catch deprecation warnings
  597. with ignore_warnings(category=FutureWarning):
  598. estimator = clone(estimator_orig)
  599. y = _enforce_estimator_tags_y(estimator, y)
  600. for matrix_format, X in _generate_sparse_matrix(X_csr):
  601. # catch deprecation warnings
  602. with ignore_warnings(category=FutureWarning):
  603. estimator = clone(estimator_orig)
  604. if name in ['Scaler', 'StandardScaler']:
  605. estimator.set_params(with_mean=False)
  606. # fit and predict
  607. try:
  608. with ignore_warnings(category=FutureWarning):
  609. estimator.fit(X, y)
  610. if hasattr(estimator, "predict"):
  611. pred = estimator.predict(X)
  612. if tags['multioutput_only']:
  613. assert pred.shape == (X.shape[0], 1)
  614. else:
  615. assert pred.shape == (X.shape[0],)
  616. if hasattr(estimator, 'predict_proba'):
  617. probs = estimator.predict_proba(X)
  618. if tags['binary_only']:
  619. expected_probs_shape = (X.shape[0], 2)
  620. else:
  621. expected_probs_shape = (X.shape[0], 4)
  622. assert probs.shape == expected_probs_shape
  623. except (TypeError, ValueError) as e:
  624. if 'sparse' not in repr(e).lower():
  625. if "64" in matrix_format:
  626. msg = ("Estimator %s doesn't seem to support %s matrix, "
  627. "and is not failing gracefully, e.g. by using "
  628. "check_array(X, accept_large_sparse=False)")
  629. raise AssertionError(msg % (name, matrix_format))
  630. else:
  631. print("Estimator %s doesn't seem to fail gracefully on "
  632. "sparse data: error message state explicitly that "
  633. "sparse input is not supported if this is not"
  634. " the case." % name)
  635. raise
  636. except Exception:
  637. print("Estimator %s doesn't seem to fail gracefully on "
  638. "sparse data: it should raise a TypeError if sparse input "
  639. "is explicitly not supported." % name)
  640. raise
  641. @ignore_warnings(category=FutureWarning)
  642. def check_sample_weights_pandas_series(name, estimator_orig):
  643. # check that estimators will accept a 'sample_weight' parameter of
  644. # type pandas.Series in the 'fit' function.
  645. estimator = clone(estimator_orig)
  646. if has_fit_parameter(estimator, "sample_weight"):
  647. try:
  648. import pandas as pd
  649. X = np.array([[1, 1], [1, 2], [1, 3], [1, 4],
  650. [2, 1], [2, 2], [2, 3], [2, 4],
  651. [3, 1], [3, 2], [3, 3], [3, 4]])
  652. X = pd.DataFrame(_pairwise_estimator_convert_X(X, estimator_orig))
  653. y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
  654. weights = pd.Series([1] * 12)
  655. if estimator._get_tags()["multioutput_only"]:
  656. y = pd.DataFrame(y)
  657. try:
  658. estimator.fit(X, y, sample_weight=weights)
  659. except ValueError:
  660. raise ValueError("Estimator {0} raises error if "
  661. "'sample_weight' parameter is of "
  662. "type pandas.Series".format(name))
  663. except ImportError:
  664. raise SkipTest("pandas is not installed: not testing for "
  665. "input of type pandas.Series to class weight.")
  666. @ignore_warnings(category=(FutureWarning))
  667. def check_sample_weights_not_an_array(name, estimator_orig):
  668. # check that estimators will accept a 'sample_weight' parameter of
  669. # type _NotAnArray in the 'fit' function.
  670. estimator = clone(estimator_orig)
  671. if has_fit_parameter(estimator, "sample_weight"):
  672. X = np.array([[1, 1], [1, 2], [1, 3], [1, 4],
  673. [2, 1], [2, 2], [2, 3], [2, 4],
  674. [3, 1], [3, 2], [3, 3], [3, 4]])
  675. X = _NotAnArray(pairwise_estimator_convert_X(X, estimator_orig))
  676. y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
  677. weights = _NotAnArray([1] * 12)
  678. if estimator._get_tags()["multioutput_only"]:
  679. y = _NotAnArray(y.data.reshape(-1, 1))
  680. estimator.fit(X, y, sample_weight=weights)
  681. @ignore_warnings(category=(FutureWarning))
  682. def check_sample_weights_list(name, estimator_orig):
  683. # check that estimators will accept a 'sample_weight' parameter of
  684. # type list in the 'fit' function.
  685. if has_fit_parameter(estimator_orig, "sample_weight"):
  686. estimator = clone(estimator_orig)
  687. rnd = np.random.RandomState(0)
  688. n_samples = 30
  689. X = _pairwise_estimator_convert_X(rnd.uniform(size=(n_samples, 3)),
  690. estimator_orig)
  691. if estimator._get_tags()['binary_only']:
  692. y = np.arange(n_samples) % 2
  693. else:
  694. y = np.arange(n_samples) % 3
  695. y = _enforce_estimator_tags_y(estimator, y)
  696. sample_weight = [3] * n_samples
  697. # Test that estimators don't raise any exception
  698. estimator.fit(X, y, sample_weight=sample_weight)
  699. @ignore_warnings(category=FutureWarning)
  700. def check_sample_weights_shape(name, estimator_orig):
  701. # check that estimators raise an error if sample_weight
  702. # shape mismatches the input
  703. if (has_fit_parameter(estimator_orig, "sample_weight") and
  704. not (hasattr(estimator_orig, "_pairwise")
  705. and estimator_orig._pairwise)):
  706. estimator = clone(estimator_orig)
  707. X = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
  708. [2, 1], [2, 1], [2, 1], [2, 1],
  709. [3, 3], [3, 3], [3, 3], [3, 3],
  710. [4, 1], [4, 1], [4, 1], [4, 1]])
  711. y = np.array([1, 1, 1, 1, 2, 2, 2, 2,
  712. 1, 1, 1, 1, 2, 2, 2, 2])
  713. y = _enforce_estimator_tags_y(estimator, y)
  714. estimator.fit(X, y, sample_weight=np.ones(len(y)))
  715. assert_raises(ValueError, estimator.fit, X, y,
  716. sample_weight=np.ones(2*len(y)))
  717. assert_raises(ValueError, estimator.fit, X, y,
  718. sample_weight=np.ones((len(y), 2)))
  719. @ignore_warnings(category=FutureWarning)
  720. def check_sample_weights_invariance(name, estimator_orig):
  721. # check that the estimators yield same results for
  722. # unit weights and no weights
  723. if (has_fit_parameter(estimator_orig, "sample_weight") and
  724. not (hasattr(estimator_orig, "_pairwise")
  725. and estimator_orig._pairwise)):
  726. # We skip pairwise because the data is not pairwise
  727. estimator1 = clone(estimator_orig)
  728. estimator2 = clone(estimator_orig)
  729. set_random_state(estimator1, random_state=0)
  730. set_random_state(estimator2, random_state=0)
  731. X = np.array([[1, 3], [1, 3], [1, 3], [1, 3],
  732. [2, 1], [2, 1], [2, 1], [2, 1],
  733. [3, 3], [3, 3], [3, 3], [3, 3],
  734. [4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.dtype('float'))
  735. y = np.array([1, 1, 1, 1, 2, 2, 2, 2,
  736. 1, 1, 1, 1, 2, 2, 2, 2], dtype=np.dtype('int'))
  737. y = _enforce_estimator_tags_y(estimator1, y)
  738. estimator1.fit(X, y=y, sample_weight=np.ones(shape=len(y)))
  739. estimator2.fit(X, y=y, sample_weight=None)
  740. for method in ["predict", "transform"]:
  741. if hasattr(estimator_orig, method):
  742. X_pred1 = getattr(estimator1, method)(X)
  743. X_pred2 = getattr(estimator2, method)(X)
  744. if sparse.issparse(X_pred1):
  745. X_pred1 = X_pred1.toarray()
  746. X_pred2 = X_pred2.toarray()
  747. assert_allclose(X_pred1, X_pred2,
  748. err_msg="For %s sample_weight=None is not"
  749. " equivalent to sample_weight=ones"
  750. % name)
  751. @ignore_warnings(category=(FutureWarning, UserWarning))
  752. def check_dtype_object(name, estimator_orig):
  753. # check that estimators treat dtype object as numeric if possible
  754. rng = np.random.RandomState(0)
  755. X = _pairwise_estimator_convert_X(rng.rand(40, 10), estimator_orig)
  756. X = X.astype(object)
  757. tags = estimator_orig._get_tags()
  758. if tags['binary_only']:
  759. y = (X[:, 0] * 2).astype(np.int)
  760. else:
  761. y = (X[:, 0] * 4).astype(np.int)
  762. estimator = clone(estimator_orig)
  763. y = _enforce_estimator_tags_y(estimator, y)
  764. estimator.fit(X, y)
  765. if hasattr(estimator, "predict"):
  766. estimator.predict(X)
  767. if hasattr(estimator, "transform"):
  768. estimator.transform(X)
  769. try:
  770. estimator.fit(X, y.astype(object))
  771. except Exception as e:
  772. if "Unknown label type" not in str(e):
  773. raise
  774. if 'string' not in tags['X_types']:
  775. X[0, 0] = {'foo': 'bar'}
  776. msg = "argument must be a string.* number"
  777. assert_raises_regex(TypeError, msg, estimator.fit, X, y)
  778. else:
  779. # Estimators supporting string will not call np.asarray to convert the
  780. # data to numeric and therefore, the error will not be raised.
  781. # Checking for each element dtype in the input array will be costly.
  782. # Refer to #11401 for full discussion.
  783. estimator.fit(X, y)
  784. def check_complex_data(name, estimator_orig):
  785. # check that estimators raise an exception on providing complex data
  786. X = np.random.sample(10) + 1j * np.random.sample(10)
  787. X = X.reshape(-1, 1)
  788. y = np.random.sample(10) + 1j * np.random.sample(10)
  789. estimator = clone(estimator_orig)
  790. assert_raises_regex(ValueError, "Complex data not supported",
  791. estimator.fit, X, y)
  792. @ignore_warnings
  793. def check_dict_unchanged(name, estimator_orig):
  794. # this estimator raises
  795. # ValueError: Found array with 0 feature(s) (shape=(23, 0))
  796. # while a minimum of 1 is required.
  797. # error
  798. if name in ['SpectralCoclustering']:
  799. return
  800. rnd = np.random.RandomState(0)
  801. if name in ['RANSACRegressor']:
  802. X = 3 * rnd.uniform(size=(20, 3))
  803. else:
  804. X = 2 * rnd.uniform(size=(20, 3))
  805. X = _pairwise_estimator_convert_X(X, estimator_orig)
  806. y = X[:, 0].astype(np.int)
  807. estimator = clone(estimator_orig)
  808. y = _enforce_estimator_tags_y(estimator, y)
  809. if hasattr(estimator, "n_components"):
  810. estimator.n_components = 1
  811. if hasattr(estimator, "n_clusters"):
  812. estimator.n_clusters = 1
  813. if hasattr(estimator, "n_best"):
  814. estimator.n_best = 1
  815. set_random_state(estimator, 1)
  816. estimator.fit(X, y)
  817. for method in ["predict", "transform", "decision_function",
  818. "predict_proba"]:
  819. if hasattr(estimator, method):
  820. dict_before = estimator.__dict__.copy()
  821. getattr(estimator, method)(X)
  822. assert estimator.__dict__ == dict_before, (
  823. 'Estimator changes __dict__ during %s' % method)
@deprecated("is_public_parameter is deprecated in version "
            "0.22 and will be removed in version 0.24.")
def is_public_parameter(attr):
    """Deprecated public name that forwards to ``_is_public_parameter``."""
    return _is_public_parameter(attr)
  828. def _is_public_parameter(attr):
  829. return not (attr.startswith('_') or attr.endswith('_'))
  830. @ignore_warnings(category=FutureWarning)
  831. def check_dont_overwrite_parameters(name, estimator_orig):
  832. # check that fit method only changes or sets private attributes
  833. if hasattr(estimator_orig.__init__, "deprecated_original"):
  834. # to not check deprecated classes
  835. return
  836. estimator = clone(estimator_orig)
  837. rnd = np.random.RandomState(0)
  838. X = 3 * rnd.uniform(size=(20, 3))
  839. X = _pairwise_estimator_convert_X(X, estimator_orig)
  840. y = X[:, 0].astype(np.int)
  841. if estimator._get_tags()['binary_only']:
  842. y[y == 2] = 1
  843. y = _enforce_estimator_tags_y(estimator, y)
  844. if hasattr(estimator, "n_components"):
  845. estimator.n_components = 1
  846. if hasattr(estimator, "n_clusters"):
  847. estimator.n_clusters = 1
  848. set_random_state(estimator, 1)
  849. dict_before_fit = estimator.__dict__.copy()
  850. estimator.fit(X, y)
  851. dict_after_fit = estimator.__dict__
  852. public_keys_after_fit = [key for key in dict_after_fit.keys()
  853. if _is_public_parameter(key)]
  854. attrs_added_by_fit = [key for key in public_keys_after_fit
  855. if key not in dict_before_fit.keys()]
  856. # check that fit doesn't add any public attribute
  857. assert not attrs_added_by_fit, (
  858. 'Estimator adds public attribute(s) during'
  859. ' the fit method.'
  860. ' Estimators are only allowed to add private attributes'
  861. ' either started with _ or ended'
  862. ' with _ but %s added'
  863. % ', '.join(attrs_added_by_fit))
  864. # check that fit doesn't change any public attribute
  865. attrs_changed_by_fit = [key for key in public_keys_after_fit
  866. if (dict_before_fit[key]
  867. is not dict_after_fit[key])]
  868. assert not attrs_changed_by_fit, (
  869. 'Estimator changes public attribute(s) during'
  870. ' the fit method. Estimators are only allowed'
  871. ' to change attributes started'
  872. ' or ended with _, but'
  873. ' %s changed'
  874. % ', '.join(attrs_changed_by_fit))
  875. @ignore_warnings(category=FutureWarning)
  876. def check_fit2d_predict1d(name, estimator_orig):
  877. # check by fitting a 2d array and predicting with a 1d array
  878. rnd = np.random.RandomState(0)
  879. X = 3 * rnd.uniform(size=(20, 3))
  880. X = _pairwise_estimator_convert_X(X, estimator_orig)
  881. y = X[:, 0].astype(np.int)
  882. tags = estimator_orig._get_tags()
  883. if tags['binary_only']:
  884. y[y == 2] = 1
  885. estimator = clone(estimator_orig)
  886. y = _enforce_estimator_tags_y(estimator, y)
  887. if hasattr(estimator, "n_components"):
  888. estimator.n_components = 1
  889. if hasattr(estimator, "n_clusters"):
  890. estimator.n_clusters = 1
  891. set_random_state(estimator, 1)
  892. estimator.fit(X, y)
  893. if tags["no_validation"]:
  894. # FIXME this is a bit loose
  895. return
  896. for method in ["predict", "transform", "decision_function",
  897. "predict_proba"]:
  898. if hasattr(estimator, method):
  899. assert_raise_message(ValueError, "Reshape your data",
  900. getattr(estimator, method), X[0])
  901. def _apply_on_subsets(func, X):
  902. # apply function on the whole set and on mini batches
  903. result_full = func(X)
  904. n_features = X.shape[1]
  905. result_by_batch = [func(batch.reshape(1, n_features))
  906. for batch in X]
  907. # func can output tuple (e.g. score_samples)
  908. if type(result_full) == tuple:
  909. result_full = result_full[0]
  910. result_by_batch = list(map(lambda x: x[0], result_by_batch))
  911. if sparse.issparse(result_full):
  912. result_full = result_full.A
  913. result_by_batch = [x.A for x in result_by_batch]
  914. return np.ravel(result_full), np.ravel(result_by_batch)
  915. @ignore_warnings(category=FutureWarning)
  916. def check_methods_subset_invariance(name, estimator_orig):
  917. # check that method gives invariant results if applied
  918. # on mini batches or the whole set
  919. rnd = np.random.RandomState(0)
  920. X = 3 * rnd.uniform(size=(20, 3))
  921. X = _pairwise_estimator_convert_X(X, estimator_orig)
  922. y = X[:, 0].astype(np.int)
  923. if estimator_orig._get_tags()['binary_only']:
  924. y[y == 2] = 1
  925. estimator = clone(estimator_orig)
  926. y = _enforce_estimator_tags_y(estimator, y)
  927. if hasattr(estimator, "n_components"):
  928. estimator.n_components = 1
  929. if hasattr(estimator, "n_clusters"):
  930. estimator.n_clusters = 1
  931. set_random_state(estimator, 1)
  932. estimator.fit(X, y)
  933. for method in ["predict", "transform", "decision_function",
  934. "score_samples", "predict_proba"]:
  935. msg = ("{method} of {name} is not invariant when applied "
  936. "to a subset.").format(method=method, name=name)
  937. if hasattr(estimator, method):
  938. result_full, result_by_batch = _apply_on_subsets(
  939. getattr(estimator, method), X)
  940. assert_allclose(result_full, result_by_batch,
  941. atol=1e-7, err_msg=msg)
  942. @ignore_warnings
  943. def check_fit2d_1sample(name, estimator_orig):
  944. # Check that fitting a 2d array with only one sample either works or
  945. # returns an informative message. The error message should either mention
  946. # the number of samples or the number of classes.
  947. rnd = np.random.RandomState(0)
  948. X = 3 * rnd.uniform(size=(1, 10))
  949. X = _pairwise_estimator_convert_X(X, estimator_orig)
  950. y = X[:, 0].astype(np.int)
  951. estimator = clone(estimator_orig)
  952. y = _enforce_estimator_tags_y(estimator, y)
  953. if hasattr(estimator, "n_components"):
  954. estimator.n_components = 1
  955. if hasattr(estimator, "n_clusters"):
  956. estimator.n_clusters = 1
  957. set_random_state(estimator, 1)
  958. # min_cluster_size cannot be less than the data size for OPTICS.
  959. if name == 'OPTICS':
  960. estimator.set_params(min_samples=1)
  961. msgs = ["1 sample", "n_samples = 1", "n_samples=1", "one sample",
  962. "1 class", "one class"]
  963. try:
  964. estimator.fit(X, y)
  965. except ValueError as e:
  966. if all(msg not in repr(e) for msg in msgs):
  967. raise e
  968. @ignore_warnings
  969. def check_fit2d_1feature(name, estimator_orig):
  970. # check fitting a 2d array with only 1 feature either works or returns
  971. # informative message
  972. rnd = np.random.RandomState(0)
  973. X = 3 * rnd.uniform(size=(10, 1))
  974. X = _pairwise_estimator_convert_X(X, estimator_orig)
  975. y = X[:, 0].astype(np.int)
  976. estimator = clone(estimator_orig)
  977. y = _enforce_estimator_tags_y(estimator, y)
  978. if hasattr(estimator, "n_components"):
  979. estimator.n_components = 1
  980. if hasattr(estimator, "n_clusters"):
  981. estimator.n_clusters = 1
  982. # ensure two labels in subsample for RandomizedLogisticRegression
  983. if name == 'RandomizedLogisticRegression':
  984. estimator.sample_fraction = 1
  985. # ensure non skipped trials for RANSACRegressor
  986. if name == 'RANSACRegressor':
  987. estimator.residual_threshold = 0.5
  988. y = _enforce_estimator_tags_y(estimator, y)
  989. set_random_state(estimator, 1)
  990. msgs = ["1 feature(s)", "n_features = 1", "n_features=1"]
  991. try:
  992. estimator.fit(X, y)
  993. except ValueError as e:
  994. if all(msg not in repr(e) for msg in msgs):
  995. raise e
  996. @ignore_warnings
  997. def check_fit1d(name, estimator_orig):
  998. # check fitting 1d X array raises a ValueError
  999. rnd = np.random.RandomState(0)
  1000. X = 3 * rnd.uniform(size=(20))
  1001. y = X.astype(np.int)
  1002. estimator = clone(estimator_orig)
  1003. tags = estimator._get_tags()
  1004. if tags["no_validation"]:
  1005. # FIXME this is a bit loose
  1006. return
  1007. y = _enforce_estimator_tags_y(estimator, y)
  1008. if hasattr(estimator, "n_components"):
  1009. estimator.n_components = 1
  1010. if hasattr(estimator, "n_clusters"):
  1011. estimator.n_clusters = 1
  1012. set_random_state(estimator, 1)
  1013. assert_raises(ValueError, estimator.fit, X, y)
  1014. @ignore_warnings(category=FutureWarning)
  1015. def check_transformer_general(name, transformer, readonly_memmap=False):
  1016. X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
  1017. random_state=0, n_features=2, cluster_std=0.1)
  1018. X = StandardScaler().fit_transform(X)
  1019. X -= X.min()
  1020. X = _pairwise_estimator_convert_X(X, transformer)
  1021. if readonly_memmap:
  1022. X, y = create_memmap_backed_data([X, y])
  1023. _check_transformer(name, transformer, X, y)
  1024. @ignore_warnings(category=FutureWarning)
  1025. def check_transformer_data_not_an_array(name, transformer):
  1026. X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
  1027. random_state=0, n_features=2, cluster_std=0.1)
  1028. X = StandardScaler().fit_transform(X)
  1029. # We need to make sure that we have non negative data, for things
  1030. # like NMF
  1031. X -= X.min() - .1
  1032. X = _pairwise_estimator_convert_X(X, transformer)
  1033. this_X = _NotAnArray(X)
  1034. this_y = _NotAnArray(np.asarray(y))
  1035. _check_transformer(name, transformer, this_X, this_y)
  1036. # try the same with some list
  1037. _check_transformer(name, transformer, X.tolist(), y.tolist())
  1038. @ignore_warnings(category=FutureWarning)
  1039. def check_transformers_unfitted(name, transformer):
  1040. X, y = _boston_subset()
  1041. transformer = clone(transformer)
  1042. with assert_raises((AttributeError, ValueError), msg="The unfitted "
  1043. "transformer {} does not raise an error when "
  1044. "transform is called. Perhaps use "
  1045. "check_is_fitted in transform.".format(name)):
  1046. transformer.transform(X)
  1047. def _check_transformer(name, transformer_orig, X, y):
  1048. n_samples, n_features = np.asarray(X).shape
  1049. transformer = clone(transformer_orig)
  1050. set_random_state(transformer)
  1051. # fit
  1052. if name in CROSS_DECOMPOSITION:
  1053. y_ = np.c_[np.asarray(y), np.asarray(y)]
  1054. y_[::2, 1] *= 2
  1055. if isinstance(X, _NotAnArray):
  1056. y_ = _NotAnArray(y_)
  1057. else:
  1058. y_ = y
  1059. transformer.fit(X, y_)
  1060. # fit_transform method should work on non fitted estimator
  1061. transformer_clone = clone(transformer)
  1062. X_pred = transformer_clone.fit_transform(X, y=y_)
  1063. if isinstance(X_pred, tuple):
  1064. for x_pred in X_pred:
  1065. assert x_pred.shape[0] == n_samples
  1066. else:
  1067. # check for consistent n_samples
  1068. assert X_pred.shape[0] == n_samples
  1069. if hasattr(transformer, 'transform'):
  1070. if name in CROSS_DECOMPOSITION:
  1071. X_pred2 = transformer.transform(X, y_)
  1072. X_pred3 = transformer.fit_transform(X, y=y_)
  1073. else:
  1074. X_pred2 = transformer.transform(X)
  1075. X_pred3 = transformer.fit_transform(X, y=y_)
  1076. if transformer_orig._get_tags()['non_deterministic']:
  1077. msg = name + ' is non deterministic'
  1078. raise SkipTest(msg)
  1079. if isinstance(X_pred, tuple) and isinstance(X_pred2, tuple):
  1080. for x_pred, x_pred2, x_pred3 in zip(X_pred, X_pred2, X_pred3):
  1081. assert_allclose_dense_sparse(
  1082. x_pred, x_pred2, atol=1e-2,
  1083. err_msg="fit_transform and transform outcomes "
  1084. "not consistent in %s"
  1085. % transformer)
  1086. assert_allclose_dense_sparse(
  1087. x_pred, x_pred3, atol=1e-2,
  1088. err_msg="consecutive fit_transform outcomes "
  1089. "not consistent in %s"
  1090. % transformer)
  1091. else:
  1092. assert_allclose_dense_sparse(
  1093. X_pred, X_pred2,
  1094. err_msg="fit_transform and transform outcomes "
  1095. "not consistent in %s"
  1096. % transformer, atol=1e-2)
  1097. assert_allclose_dense_sparse(
  1098. X_pred, X_pred3, atol=1e-2,
  1099. err_msg="consecutive fit_transform outcomes "
  1100. "not consistent in %s"
  1101. % transformer)
  1102. assert _num_samples(X_pred2) == n_samples
  1103. assert _num_samples(X_pred3) == n_samples
  1104. # raises error on malformed input for transform
  1105. if hasattr(X, 'shape') and \
  1106. not transformer._get_tags()["stateless"] and \
  1107. X.ndim == 2 and X.shape[1] > 1:
  1108. # If it's not an array, it does not have a 'T' property
  1109. with assert_raises(ValueError, msg="The transformer {} does "
  1110. "not raise an error when the number of "
  1111. "features in transform is different from"
  1112. " the number of features in "
  1113. "fit.".format(name)):
  1114. transformer.transform(X[:, :-1])
  1115. @ignore_warnings
  1116. def check_pipeline_consistency(name, estimator_orig):
  1117. if estimator_orig._get_tags()['non_deterministic']:
  1118. msg = name + ' is non deterministic'
  1119. raise SkipTest(msg)
  1120. # check that make_pipeline(est) gives same score as est
  1121. X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
  1122. random_state=0, n_features=2, cluster_std=0.1)
  1123. X -= X.min()
  1124. X = _pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
  1125. estimat

Large files are truncated, but you can click here to view the full file