
/sklearn/utils/tests/test_validation.py

https://github.com/scikit-learn/scikit-learn
Possible License(s): BSD-3-Clause
  1. """Tests for input validation functions"""
  2. import numbers
  3. import warnings
  4. import re
  5. from tempfile import NamedTemporaryFile
  6. from itertools import product
  7. from operator import itemgetter
  8. import pytest
  9. from pytest import importorskip
  10. import numpy as np
  11. import scipy.sparse as sp
  12. from sklearn.utils._testing import assert_no_warnings
  13. from sklearn.utils._testing import ignore_warnings
  14. from sklearn.utils._testing import SkipTest
  15. from sklearn.utils._testing import assert_array_equal
  16. from sklearn.utils._testing import assert_allclose_dense_sparse
  17. from sklearn.utils._testing import assert_allclose
  18. from sklearn.utils._testing import _convert_container
  19. from sklearn.utils import as_float_array, check_array, check_symmetric
  20. from sklearn.utils import check_X_y
  21. from sklearn.utils import deprecated
  22. from sklearn.utils._mocking import MockDataFrame
  23. from sklearn.utils.fixes import parse_version
  24. from sklearn.utils.estimator_checks import _NotAnArray
  25. from sklearn.random_projection import _sparse_random_matrix
  26. from sklearn.linear_model import ARDRegression
  27. from sklearn.neighbors import KNeighborsClassifier
  28. from sklearn.ensemble import RandomForestRegressor
  29. from sklearn.svm import SVR
  30. from sklearn.datasets import make_blobs
  31. from sklearn.utils import _safe_indexing
  32. from sklearn.utils.validation import (
  33. has_fit_parameter,
  34. check_is_fitted,
  35. check_consistent_length,
  36. assert_all_finite,
  37. check_memory,
  38. check_non_negative,
  39. _num_samples,
  40. check_scalar,
  41. _check_psd_eigenvalues,
  42. _check_y,
  43. _deprecate_positional_args,
  44. _check_sample_weight,
  45. _allclose_dense_sparse,
  46. _num_features,
  47. FLOAT_DTYPES,
  48. _get_feature_names,
  49. _check_feature_names_in,
  50. _check_fit_params,
  51. )
  52. from sklearn.base import BaseEstimator
  53. import sklearn
  54. from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
  55. from sklearn.utils._testing import TempMemmap
  56. # TODO: Remove np.matrix usage in 1.2
  57. @pytest.mark.filterwarnings("ignore:np.matrix usage is deprecated in 1.0:FutureWarning")
  58. @pytest.mark.filterwarnings("ignore:the matrix subclass:PendingDeprecationWarning")
  59. def test_as_float_array():
  60. # Test function for as_float_array
  61. X = np.ones((3, 10), dtype=np.int32)
  62. X = X + np.arange(10, dtype=np.int32)
  63. X2 = as_float_array(X, copy=False)
  64. assert X2.dtype == np.float32
  65. # Another test
  66. X = X.astype(np.int64)
  67. X2 = as_float_array(X, copy=True)
  68. # Checking that the array wasn't overwritten
  69. assert as_float_array(X, copy=False) is not X
  70. assert X2.dtype == np.float64
  71. # Test int dtypes <= 32bit
  72. tested_dtypes = [bool, np.int8, np.int16, np.int32, np.uint8, np.uint16, np.uint32]
  73. for dtype in tested_dtypes:
  74. X = X.astype(dtype)
  75. X2 = as_float_array(X)
  76. assert X2.dtype == np.float32
  77. # Test object dtype
  78. X = X.astype(object)
  79. X2 = as_float_array(X, copy=True)
  80. assert X2.dtype == np.float64
  81. # Here, X is of the right type, so it shouldn't be modified
  82. X = np.ones((3, 2), dtype=np.float32)
  83. assert as_float_array(X, copy=False) is X
  84. # Test that if X is fortran ordered it stays
  85. X = np.asfortranarray(X)
  86. assert np.isfortran(as_float_array(X, copy=True))
  87. # Test the copy parameter with some matrices
  88. matrices = [
  89. np.matrix(np.arange(5)),
  90. sp.csc_matrix(np.arange(5)).toarray(),
  91. _sparse_random_matrix(10, 10, density=0.10).toarray(),
  92. ]
  93. for M in matrices:
  94. N = as_float_array(M, copy=True)
  95. N[0, 0] = np.nan
  96. assert not np.isnan(M).any()
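# The following parametrized test checks that as_float_array leaves NaN values
# untouched when force_all_finite="allow-nan", for both dense and sparse inputs.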
  97. @pytest.mark.parametrize("X", [(np.random.random((10, 2))), (sp.rand(10, 2).tocsr())])
  98. def test_as_float_array_nan(X):
  99. X[5, 0] = np.nan
  100. X[6, 1] = np.nan
  101. X_converted = as_float_array(X, force_all_finite="allow-nan")
  102. assert_allclose_dense_sparse(X_converted, X)
  103. # TODO: Remove np.matrix usage in 1.2
  104. @pytest.mark.filterwarnings("ignore:np.matrix usage is deprecated in 1.0:FutureWarning")
  105. @pytest.mark.filterwarnings("ignore:the matrix subclass:PendingDeprecationWarning")
  106. def test_np_matrix():
  107. # Confirm that input validation code does not return np.matrix
  108. X = np.arange(12).reshape(3, 4)
  109. assert not isinstance(as_float_array(X), np.matrix)
  110. assert not isinstance(as_float_array(np.matrix(X)), np.matrix)
  111. assert not isinstance(as_float_array(sp.csc_matrix(X)), np.matrix)
  112. def test_memmap():
  113. # Confirm that input validation code doesn't copy memory mapped arrays
  114. asflt = lambda x: as_float_array(x, copy=False)
  115. with NamedTemporaryFile(prefix="sklearn-test") as tmp:
  116. M = np.memmap(tmp, shape=(10, 10), dtype=np.float32)
  117. M[:] = 0
  118. for f in (check_array, np.asarray, asflt):
  119. X = f(M)
  120. X[:] = 1
  121. assert_array_equal(X.ravel(), M.ravel())
  122. X[:] = 0
  123. def test_ordering():
  124. # Check that ordering is enforced correctly by validation utilities.
  125. # We need to check each validation utility, because a 'copy' without
  126. # 'order=K' will kill the ordering.
  127. X = np.ones((10, 5))
  128. for A in X, X.T:
  129. for copy in (True, False):
  130. B = check_array(A, order="C", copy=copy)
  131. assert B.flags["C_CONTIGUOUS"]
  132. B = check_array(A, order="F", copy=copy)
  133. assert B.flags["F_CONTIGUOUS"]
  134. if copy:
  135. assert A is not B
  136. X = sp.csr_matrix(X)
  137. X.data = X.data[::-1]
  138. assert not X.data.flags["C_CONTIGUOUS"]
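# The next two parametrized tests cover check_array's force_all_finite handling:
# inf/NaN values that are accepted when finiteness checks are relaxed, and the
# error messages raised when they are not.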
  139. @pytest.mark.parametrize(
  140. "value, force_all_finite", [(np.inf, False), (np.nan, "allow-nan"), (np.nan, False)]
  141. )
  142. @pytest.mark.parametrize("retype", [np.asarray, sp.csr_matrix])
  143. def test_check_array_force_all_finite_valid(value, force_all_finite, retype):
  144. X = retype(np.arange(4).reshape(2, 2).astype(float))
  145. X[0, 0] = value
  146. X_checked = check_array(X, force_all_finite=force_all_finite, accept_sparse=True)
  147. assert_allclose_dense_sparse(X, X_checked)
  148. @pytest.mark.parametrize(
  149. "value, input_name, force_all_finite, match_msg",
  150. [
  151. (np.inf, "", True, "Input contains infinity"),
  152. (np.inf, "X", True, "Input X contains infinity"),
  153. (np.inf, "sample_weight", True, "Input sample_weight contains infinity"),
  154. (np.inf, "X", "allow-nan", "Input X contains infinity"),
  155. (np.nan, "", True, "Input contains NaN"),
  156. (np.nan, "X", True, "Input X contains NaN"),
  157. (np.nan, "y", True, "Input y contains NaN"),
  158. (
  159. np.nan,
  160. "",
  161. "allow-inf",
  162. 'force_all_finite should be a bool or "allow-nan"',
  163. ),
  164. (np.nan, "", 1, "Input contains NaN"),
  165. ],
  166. )
  167. @pytest.mark.parametrize("retype", [np.asarray, sp.csr_matrix])
  168. def test_check_array_force_all_finite_invalid(
  169. value, input_name, force_all_finite, match_msg, retype
  170. ):
  171. X = retype(np.arange(4).reshape(2, 2).astype(np.float64))
  172. X[0, 0] = value
  173. with pytest.raises(ValueError, match=match_msg):
  174. check_array(
  175. X,
  176. input_name=input_name,
  177. force_all_finite=force_all_finite,
  178. accept_sparse=True,
  179. )
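# check_array should append the imputer documentation hint only when the offending
# input is X, and SVR.fit should surface the same extended message.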
  180. @pytest.mark.parametrize("input_name", ["X", "y", "sample_weight"])
  181. @pytest.mark.parametrize("retype", [np.asarray, sp.csr_matrix])
  182. def test_check_array_links_to_imputer_doc_only_for_X(input_name, retype):
  183. data = retype(np.arange(4).reshape(2, 2).astype(np.float64))
  184. data[0, 0] = np.nan
  185. estimator = SVR()
  186. extended_msg = (
  187. f"\n{estimator.__class__.__name__} does not accept missing values"
  188. " encoded as NaN natively. For supervised learning, you might want"
  189. " to consider sklearn.ensemble.HistGradientBoostingClassifier and Regressor"
  190. " which accept missing values encoded as NaNs natively."
  191. " Alternatively, it is possible to preprocess the"
  192. " data, for instance by using an imputer transformer in a pipeline"
  193. " or drop samples with missing values. See"
  194. " https://scikit-learn.org/stable/modules/impute.html"
  195. " You can find a list of all estimators that handle NaN values"
  196. " at the following page:"
  197. " https://scikit-learn.org/stable/modules/impute.html"
  198. "#estimators-that-handle-nan-values"
  199. )
  200. with pytest.raises(ValueError, match=f"Input {input_name} contains NaN") as ctx:
  201. check_array(
  202. data,
  203. estimator=estimator,
  204. input_name=input_name,
  205. accept_sparse=True,
  206. )
  207. if input_name == "X":
  208. assert extended_msg in ctx.value.args[0]
  209. else:
  210. assert extended_msg not in ctx.value.args[0]
  211. if input_name == "X":
  212. # Verify that _validate_data is automatically called with the right argument
  213. # to generate the same exception:
  214. with pytest.raises(ValueError, match=f"Input {input_name} contains NaN") as ctx:
  215. SVR().fit(data, np.ones(data.shape[0]))
  216. assert extended_msg in ctx.value.args[0]
  217. def test_check_array_force_all_finite_object():
  218. X = np.array([["a", "b", np.nan]], dtype=object).T
  219. X_checked = check_array(X, dtype=None, force_all_finite="allow-nan")
  220. assert X is X_checked
  221. X_checked = check_array(X, dtype=None, force_all_finite=False)
  222. assert X is X_checked
  223. with pytest.raises(ValueError, match="Input contains NaN"):
  224. check_array(X, dtype=None, force_all_finite=True)
  225. @pytest.mark.parametrize(
  226. "X, err_msg",
  227. [
  228. (
  229. np.array([[1, np.nan]]),
  230. "Input contains NaN.",
  231. ),
  232. (
  233. np.array([[1, np.nan]]),
  234. "Input contains NaN.",
  235. ),
  236. (
  237. np.array([[1, np.inf]]),
  238. "Input contains infinity or a value too large for.*int",
  239. ),
  240. (np.array([[1, np.nan]], dtype=object), "cannot convert float NaN to integer"),
  241. ],
  242. )
  243. @pytest.mark.parametrize("force_all_finite", [True, False])
  244. def test_check_array_force_all_finite_object_unsafe_casting(
  245. X, err_msg, force_all_finite
  246. ):
  247. # casting a float array containing NaN or inf to int dtype should
  248. # raise an error irrespective of the force_all_finite parameter.
  249. with pytest.raises(ValueError, match=err_msg):
  250. check_array(X, dtype=int, force_all_finite=force_all_finite)
  251. @ignore_warnings
  252. def test_check_array():
  253. # accept_sparse == False
  254. # raise error on sparse inputs
  255. X = [[1, 2], [3, 4]]
  256. X_csr = sp.csr_matrix(X)
  257. with pytest.raises(TypeError):
  258. check_array(X_csr)
  259. # ensure_2d=False
  260. X_array = check_array([0, 1, 2], ensure_2d=False)
  261. assert X_array.ndim == 1
  262. # ensure_2d=True with 1d array
  263. with pytest.raises(ValueError, match="Expected 2D array, got 1D array instead"):
  264. check_array([0, 1, 2], ensure_2d=True)
  265. # ensure_2d=True with scalar array
  266. with pytest.raises(ValueError, match="Expected 2D array, got scalar array instead"):
  267. check_array(10, ensure_2d=True)
  268. # don't allow ndim > 3
  269. X_ndim = np.arange(8).reshape(2, 2, 2)
  270. with pytest.raises(ValueError):
  271. check_array(X_ndim)
  272. check_array(X_ndim, allow_nd=True) # doesn't raise
  273. # dtype and order enforcement.
  274. X_C = np.arange(4).reshape(2, 2).copy("C")
  275. X_F = X_C.copy("F")
  276. X_int = X_C.astype(int)
  277. X_float = X_C.astype(float)
  278. Xs = [X_C, X_F, X_int, X_float]
  279. dtypes = [np.int32, int, float, np.float32, None, bool, object]
  280. orders = ["C", "F", None]
  281. copys = [True, False]
  282. for X, dtype, order, copy in product(Xs, dtypes, orders, copys):
  283. X_checked = check_array(X, dtype=dtype, order=order, copy=copy)
  284. if dtype is not None:
  285. assert X_checked.dtype == dtype
  286. else:
  287. assert X_checked.dtype == X.dtype
  288. if order == "C":
  289. assert X_checked.flags["C_CONTIGUOUS"]
  290. assert not X_checked.flags["F_CONTIGUOUS"]
  291. elif order == "F":
  292. assert X_checked.flags["F_CONTIGUOUS"]
  293. assert not X_checked.flags["C_CONTIGUOUS"]
  294. if copy:
  295. assert X is not X_checked
  296. else:
  297. # doesn't copy if it was already good
  298. if (
  299. X.dtype == X_checked.dtype
  300. and X_checked.flags["C_CONTIGUOUS"] == X.flags["C_CONTIGUOUS"]
  301. and X_checked.flags["F_CONTIGUOUS"] == X.flags["F_CONTIGUOUS"]
  302. ):
  303. assert X is X_checked
  304. # allowed sparse != None
  305. X_csc = sp.csc_matrix(X_C)
  306. X_coo = X_csc.tocoo()
  307. X_dok = X_csc.todok()
  308. X_int = X_csc.astype(int)
  309. X_float = X_csc.astype(float)
  310. Xs = [X_csc, X_coo, X_dok, X_int, X_float]
  311. accept_sparses = [["csr", "coo"], ["coo", "dok"]]
  312. # scipy sparse matrices do not support the object dtype so
  313. # this dtype is skipped in this loop
  314. non_object_dtypes = [dt for dt in dtypes if dt is not object]
  315. for X, dtype, accept_sparse, copy in product(
  316. Xs, non_object_dtypes, accept_sparses, copys
  317. ):
  318. X_checked = check_array(X, dtype=dtype, accept_sparse=accept_sparse, copy=copy)
  319. if dtype is not None:
  320. assert X_checked.dtype == dtype
  321. else:
  322. assert X_checked.dtype == X.dtype
  323. if X.format in accept_sparse:
  324. # no change if allowed
  325. assert X.format == X_checked.format
  326. else:
  327. # got converted
  328. assert X_checked.format == accept_sparse[0]
  329. if copy:
  330. assert X is not X_checked
  331. else:
  332. # doesn't copy if it was already good
  333. if X.dtype == X_checked.dtype and X.format == X_checked.format:
  334. assert X is X_checked
  335. # other input formats
  336. # convert lists to arrays
  337. X_dense = check_array([[1, 2], [3, 4]])
  338. assert isinstance(X_dense, np.ndarray)
  339. # raise on too deep lists
  340. with pytest.raises(ValueError):
  341. check_array(X_ndim.tolist())
  342. check_array(X_ndim.tolist(), allow_nd=True) # doesn't raise
  343. # convert weird stuff to arrays
  344. X_no_array = _NotAnArray(X_dense)
  345. result = check_array(X_no_array)
  346. assert isinstance(result, np.ndarray)
  347. @pytest.mark.parametrize(
  348. "X",
  349. [
  350. [["1", "2"], ["3", "4"]],
  351. np.array([["1", "2"], ["3", "4"]], dtype="U"),
  352. np.array([["1", "2"], ["3", "4"]], dtype="S"),
  353. [[b"1", b"2"], [b"3", b"4"]],
  354. np.array([[b"1", b"2"], [b"3", b"4"]], dtype="V1"),
  355. ],
  356. )
  357. def test_check_array_numeric_error(X):
  358. """Test that check_array errors when it receives an array of bytes/string
  359. while a numeric dtype is required."""
  360. expected_msg = r"dtype='numeric' is not compatible with arrays of bytes/strings"
  361. with pytest.raises(ValueError, match=expected_msg):
  362. check_array(X, dtype="numeric")
  363. @pytest.mark.parametrize(
  364. "pd_dtype", ["Int8", "Int16", "UInt8", "UInt16", "Float32", "Float64"]
  365. )
  366. @pytest.mark.parametrize(
  367. "dtype, expected_dtype",
  368. [
  369. ([np.float32, np.float64], np.float32),
  370. (np.float64, np.float64),
  371. ("numeric", np.float64),
  372. ],
  373. )
  374. def test_check_array_pandas_na_support(pd_dtype, dtype, expected_dtype):
  375. # Test pandas numerical extension arrays with pd.NA
  376. pd = pytest.importorskip("pandas")
  377. if pd_dtype in {"Float32", "Float64"}:
  378. # The Float32/Float64 extension dtypes were added in pandas 1.2
  379. pd = pytest.importorskip("pandas", minversion="1.2")
  380. X_np = np.array(
  381. [[1, 2, 3, np.nan, np.nan], [np.nan, np.nan, 8, 4, 6], [1, 2, 3, 4, 5]]
  382. ).T
  383. # Create a dataframe with numerical extension arrays containing pd.NA
  384. X = pd.DataFrame(X_np, dtype=pd_dtype, columns=["a", "b", "c"])
  385. # column c has no nans
  386. X["c"] = X["c"].astype("float")
  387. X_checked = check_array(X, force_all_finite="allow-nan", dtype=dtype)
  388. assert_allclose(X_checked, X_np)
  389. assert X_checked.dtype == expected_dtype
  390. X_checked = check_array(X, force_all_finite=False, dtype=dtype)
  391. assert_allclose(X_checked, X_np)
  392. assert X_checked.dtype == expected_dtype
  393. msg = "Input contains NaN"
  394. with pytest.raises(ValueError, match=msg):
  395. check_array(X, force_all_finite=True)
  396. def test_check_array_pandas_dtype_casting():
  397. # test that data-frames with homogeneous dtype are not upcast
  398. pd = pytest.importorskip("pandas")
  399. X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
  400. X_df = pd.DataFrame(X)
  401. assert check_array(X_df).dtype == np.float32
  402. assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
  403. X_df = X_df.astype({0: np.float16})
  404. assert_array_equal(X_df.dtypes, (np.float16, np.float32, np.float32))
  405. assert check_array(X_df).dtype == np.float32
  406. assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
  407. X_df = X_df.astype({0: np.int16})
  408. # float16, int16, float32 casts to float32
  409. assert check_array(X_df).dtype == np.float32
  410. assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
  411. X_df = X_df.astype({2: np.float16})
  412. # float16, int16, float16 casts to float32
  413. assert check_array(X_df).dtype == np.float32
  414. assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float32
  415. X_df = X_df.astype(np.int16)
  416. assert check_array(X_df).dtype == np.int16
  417. # we're not using upcasting rules for determining
  418. # the target type yet, so we cast to the default of float64
  419. assert check_array(X_df, dtype=FLOAT_DTYPES).dtype == np.float64
  420. # check that we handle pandas dtypes in a semi-reasonable way
  421. # this is actually tricky because we can't really know that this
  422. # should be integer ahead of converting it.
  423. cat_df = pd.DataFrame({"cat_col": pd.Categorical([1, 2, 3])})
  424. assert check_array(cat_df).dtype == np.int64
  425. assert check_array(cat_df, dtype=FLOAT_DTYPES).dtype == np.float64
  426. def test_check_array_on_mock_dataframe():
  427. arr = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]])
  428. mock_df = MockDataFrame(arr)
  429. checked_arr = check_array(mock_df)
  430. assert checked_arr.dtype == arr.dtype
  431. checked_arr = check_array(mock_df, dtype=np.float32)
  432. assert checked_arr.dtype == np.dtype(np.float32)
  433. def test_check_array_dtype_stability():
  434. # test that lists with ints don't get converted to floats
  435. X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  436. assert check_array(X).dtype.kind == "i"
  437. assert check_array(X, ensure_2d=False).dtype.kind == "i"
  438. def test_check_array_dtype_warning():
  439. X_int_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  440. X_float32 = np.asarray(X_int_list, dtype=np.float32)
  441. X_int64 = np.asarray(X_int_list, dtype=np.int64)
  442. X_csr_float32 = sp.csr_matrix(X_float32)
  443. X_csc_float32 = sp.csc_matrix(X_float32)
  444. X_csc_int32 = sp.csc_matrix(X_int64, dtype=np.int32)
  445. integer_data = [X_int64, X_csc_int32]
  446. float32_data = [X_float32, X_csr_float32, X_csc_float32]
  447. for X in integer_data:
  448. X_checked = assert_no_warnings(
  449. check_array, X, dtype=np.float64, accept_sparse=True
  450. )
  451. assert X_checked.dtype == np.float64
  452. for X in float32_data:
  453. X_checked = assert_no_warnings(
  454. check_array, X, dtype=[np.float64, np.float32], accept_sparse=True
  455. )
  456. assert X_checked.dtype == np.float32
  457. assert X_checked is X
  458. X_checked = assert_no_warnings(
  459. check_array,
  460. X,
  461. dtype=[np.float64, np.float32],
  462. accept_sparse=["csr", "dok"],
  463. copy=True,
  464. )
  465. assert X_checked.dtype == np.float32
  466. assert X_checked is not X
  467. X_checked = assert_no_warnings(
  468. check_array,
  469. X_csc_float32,
  470. dtype=[np.float64, np.float32],
  471. accept_sparse=["csr", "dok"],
  472. copy=False,
  473. )
  474. assert X_checked.dtype == np.float32
  475. assert X_checked is not X_csc_float32
  476. assert X_checked.format == "csr"
  477. def test_check_array_accept_sparse_type_exception():
  478. X = [[1, 2], [3, 4]]
  479. X_csr = sp.csr_matrix(X)
  480. invalid_type = SVR()
  481. msg = (
  482. "A sparse matrix was passed, but dense data is required. "
  483. r"Use X.toarray\(\) to convert to a dense numpy array."
  484. )
  485. with pytest.raises(TypeError, match=msg):
  486. check_array(X_csr, accept_sparse=False)
  487. msg = (
  488. "Parameter 'accept_sparse' should be a string, "
  489. "boolean or list of strings. You provided 'accept_sparse=.*'."
  490. )
  491. with pytest.raises(ValueError, match=msg):
  492. check_array(X_csr, accept_sparse=invalid_type)
  493. msg = (
  494. "When providing 'accept_sparse' as a tuple or list, "
  495. "it must contain at least one string value."
  496. )
  497. with pytest.raises(ValueError, match=msg):
  498. check_array(X_csr, accept_sparse=[])
  499. with pytest.raises(ValueError, match=msg):
  500. check_array(X_csr, accept_sparse=())
  501. with pytest.raises(TypeError, match="SVR"):
  502. check_array(X_csr, accept_sparse=[invalid_type])
  503. def test_check_array_accept_sparse_no_exception():
  504. X = [[1, 2], [3, 4]]
  505. X_csr = sp.csr_matrix(X)
  506. check_array(X_csr, accept_sparse=True)
  507. check_array(X_csr, accept_sparse="csr")
  508. check_array(X_csr, accept_sparse=["csr"])
  509. check_array(X_csr, accept_sparse=("csr",))
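# Fixture producing sparse matrices whose index arrays are cast to int64, used by
# the accept_large_sparse tests below.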
  510. @pytest.fixture(params=["csr", "csc", "coo", "bsr"])
  511. def X_64bit(request):
  512. X = sp.rand(20, 10, format=request.param)
  513. for attr in ["indices", "indptr", "row", "col"]:
  514. if hasattr(X, attr):
  515. setattr(X, attr, getattr(X, attr).astype("int64"))
  516. yield X
  517. def test_check_array_accept_large_sparse_no_exception(X_64bit):
  518. # When large sparse matrices are allowed
  519. check_array(X_64bit, accept_large_sparse=True, accept_sparse=True)
  520. def test_check_array_accept_large_sparse_raise_exception(X_64bit):
  521. # When large sparse matrices are not allowed
  522. msg = (
  523. "Only sparse matrices with 32-bit integer indices "
  524. "are accepted. Got int64 indices."
  525. )
  526. with pytest.raises(ValueError, match=msg):
  527. check_array(X_64bit, accept_sparse=True, accept_large_sparse=False)
  528. def test_check_array_min_samples_and_features_messages():
  529. # empty list is considered 2D by default:
  530. msg = r"0 feature\(s\) \(shape=\(1, 0\)\) while a minimum of 1 is" " required."
  531. with pytest.raises(ValueError, match=msg):
  532. check_array([[]])
  533. # If considered a 1D collection when ensure_2d=False, then the minimum
  534. # number of samples will break:
  535. msg = r"0 sample\(s\) \(shape=\(0,\)\) while a minimum of 1 is required."
  536. with pytest.raises(ValueError, match=msg):
  537. check_array([], ensure_2d=False)
  538. # Invalid edge case when checking the default minimum sample of a scalar
  539. msg = r"Singleton array array\(42\) cannot be considered a valid" " collection."
  540. with pytest.raises(TypeError, match=msg):
  541. check_array(42, ensure_2d=False)
  542. # Simulate a model that would need at least 2 samples to be well defined
  543. X = np.ones((1, 10))
  544. y = np.ones(1)
  545. msg = r"1 sample\(s\) \(shape=\(1, 10\)\) while a minimum of 2 is" " required."
  546. with pytest.raises(ValueError, match=msg):
  547. check_X_y(X, y, ensure_min_samples=2)
  548. # The same message is raised if the data has 2 dimensions even if this is
  549. # not mandatory
  550. with pytest.raises(ValueError, match=msg):
  551. check_X_y(X, y, ensure_min_samples=2, ensure_2d=False)
  552. # Simulate a model that would require at least 3 features (e.g. SelectKBest
  553. # with k=3)
  554. X = np.ones((10, 2))
  555. y = np.ones(2)
  556. msg = r"2 feature\(s\) \(shape=\(10, 2\)\) while a minimum of 3 is" " required."
  557. with pytest.raises(ValueError, match=msg):
  558. check_X_y(X, y, ensure_min_features=3)
  559. # Only the feature check is enabled whenever the number of dimensions is 2
  560. # even if allow_nd is enabled:
  561. with pytest.raises(ValueError, match=msg):
  562. check_X_y(X, y, ensure_min_features=3, allow_nd=True)
  563. # Simulate a case where a pipeline stage has trimmed all the features of a
  564. # 2D dataset.
  565. X = np.empty(0).reshape(10, 0)
  566. y = np.ones(10)
  567. msg = r"0 feature\(s\) \(shape=\(10, 0\)\) while a minimum of 1 is" " required."
  568. with pytest.raises(ValueError, match=msg):
  569. check_X_y(X, y)
  570. # nd-data is not checked for any minimum number of features by default:
  571. X = np.ones((10, 0, 28, 28))
  572. y = np.ones(10)
  573. X_checked, y_checked = check_X_y(X, y, allow_nd=True)
  574. assert_array_equal(X, X_checked)
  575. assert_array_equal(y, y_checked)
  576. def test_check_array_complex_data_error():
  577. X = np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]])
  578. with pytest.raises(ValueError, match="Complex data not supported"):
  579. check_array(X)
  580. # list of lists
  581. X = [[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]]
  582. with pytest.raises(ValueError, match="Complex data not supported"):
  583. check_array(X)
  584. # tuple of tuples
  585. X = ((1 + 2j, 3 + 4j, 5 + 7j), (2 + 3j, 4 + 5j, 6 + 7j))
  586. with pytest.raises(ValueError, match="Complex data not supported"):
  587. check_array(X)
  588. # list of np arrays
  589. X = [np.array([1 + 2j, 3 + 4j, 5 + 7j]), np.array([2 + 3j, 4 + 5j, 6 + 7j])]
  590. with pytest.raises(ValueError, match="Complex data not supported"):
  591. check_array(X)
  592. # tuple of np arrays
  593. X = (np.array([1 + 2j, 3 + 4j, 5 + 7j]), np.array([2 + 3j, 4 + 5j, 6 + 7j]))
  594. with pytest.raises(ValueError, match="Complex data not supported"):
  595. check_array(X)
  596. # dataframe
  597. X = MockDataFrame(np.array([[1 + 2j, 3 + 4j, 5 + 7j], [2 + 3j, 4 + 5j, 6 + 7j]]))
  598. with pytest.raises(ValueError, match="Complex data not supported"):
  599. check_array(X)
  600. # sparse matrix
  601. X = sp.coo_matrix([[0, 1 + 2j], [0, 0]])
  602. with pytest.raises(ValueError, match="Complex data not supported"):
  603. check_array(X)
  604. # target variable does not always go through check_array but should
  605. # never accept complex data either.
  606. y = np.array([1 + 2j, 3 + 4j, 5 + 7j, 2 + 3j, 4 + 5j, 6 + 7j])
  607. with pytest.raises(ValueError, match="Complex data not supported"):
  608. _check_y(y)
  609. def test_has_fit_parameter():
  610. assert not has_fit_parameter(KNeighborsClassifier, "sample_weight")
  611. assert has_fit_parameter(RandomForestRegressor, "sample_weight")
  612. assert has_fit_parameter(SVR, "sample_weight")
  613. assert has_fit_parameter(SVR(), "sample_weight")
  614. class TestClassWithDeprecatedFitMethod:
  615. @deprecated("Deprecated for the purpose of testing has_fit_parameter")
  616. def fit(self, X, y, sample_weight=None):
  617. pass
  618. assert has_fit_parameter(
  619. TestClassWithDeprecatedFitMethod, "sample_weight"
  620. ), "has_fit_parameter fails for class with deprecated fit method."
  621. def test_check_symmetric():
  622. arr_sym = np.array([[0, 1], [1, 2]])
  623. arr_bad = np.ones(2)
  624. arr_asym = np.array([[0, 2], [0, 2]])
  625. test_arrays = {
  626. "dense": arr_asym,
  627. "dok": sp.dok_matrix(arr_asym),
  628. "csr": sp.csr_matrix(arr_asym),
  629. "csc": sp.csc_matrix(arr_asym),
  630. "coo": sp.coo_matrix(arr_asym),
  631. "lil": sp.lil_matrix(arr_asym),
  632. "bsr": sp.bsr_matrix(arr_asym),
  633. }
  634. # check error for bad inputs
  635. with pytest.raises(ValueError):
  636. check_symmetric(arr_bad)
  637. # check that asymmetric arrays are properly symmetrized
  638. for arr_format, arr in test_arrays.items():
  639. # Check for warnings and errors
  640. with pytest.warns(UserWarning):
  641. check_symmetric(arr)
  642. with pytest.raises(ValueError):
  643. check_symmetric(arr, raise_exception=True)
  644. output = check_symmetric(arr, raise_warning=False)
  645. if sp.issparse(output):
  646. assert output.format == arr_format
  647. assert_array_equal(output.toarray(), arr_sym)
  648. else:
  649. assert_array_equal(output, arr_sym)
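# check_is_fitted should defer to an estimator's __sklearn_is_fitted__ method when
# it is defined, as exercised below.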
  650. def test_check_is_fitted_with_is_fitted():
  651. class Estimator(BaseEstimator):
  652. def fit(self, **kwargs):
  653. self._is_fitted = True
  654. return self
  655. def __sklearn_is_fitted__(self):
  656. return hasattr(self, "_is_fitted") and self._is_fitted
  657. with pytest.raises(NotFittedError):
  658. check_is_fitted(Estimator())
  659. check_is_fitted(Estimator().fit())
  660. def test_check_is_fitted():
  661. # Check that a TypeError is raised when a non-estimator instance is passed
  662. with pytest.raises(TypeError):
  663. check_is_fitted(ARDRegression)
  664. with pytest.raises(TypeError):
  665. check_is_fitted("SVR")
  666. ard = ARDRegression()
  667. svr = SVR()
  668. try:
  669. with pytest.raises(NotFittedError):
  670. check_is_fitted(ard)
  671. with pytest.raises(NotFittedError):
  672. check_is_fitted(svr)
  673. except ValueError:
  674. assert False, "check_is_fitted failed with ValueError"
  675. # NotFittedError is a subclass of both ValueError and AttributeError
  676. try:
  677. check_is_fitted(ard, msg="Random message %(name)s, %(name)s")
  678. except ValueError as e:
  679. assert str(e) == "Random message ARDRegression, ARDRegression"
  680. try:
  681. check_is_fitted(svr, msg="Another message %(name)s, %(name)s")
  682. except AttributeError as e:
  683. assert str(e) == "Another message SVR, SVR"
  684. ard.fit(*make_blobs())
  685. svr.fit(*make_blobs())
  686. assert check_is_fitted(ard) is None
  687. assert check_is_fitted(svr) is None
  688. def test_check_is_fitted_attributes():
  689. class MyEstimator:
  690. def fit(self, X, y):
  691. return self
  692. msg = "not fitted"
  693. est = MyEstimator()
  694. with pytest.raises(NotFittedError, match=msg):
  695. check_is_fitted(est, attributes=["a_", "b_"])
  696. with pytest.raises(NotFittedError, match=msg):
  697. check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
  698. with pytest.raises(NotFittedError, match=msg):
  699. check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)
  700. est.a_ = "a"
  701. with pytest.raises(NotFittedError, match=msg):
  702. check_is_fitted(est, attributes=["a_", "b_"])
  703. with pytest.raises(NotFittedError, match=msg):
  704. check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
  705. check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)
  706. est.b_ = "b"
  707. check_is_fitted(est, attributes=["a_", "b_"])
  708. check_is_fitted(est, attributes=["a_", "b_"], all_or_any=all)
  709. check_is_fitted(est, attributes=["a_", "b_"], all_or_any=any)
  710. @pytest.mark.parametrize(
  711. "wrap", [itemgetter(0), list, tuple], ids=["single", "list", "tuple"]
  712. )
  713. def test_check_is_fitted_with_attributes(wrap):
  714. ard = ARDRegression()
  715. with pytest.raises(NotFittedError, match="is not fitted yet"):
  716. check_is_fitted(ard, wrap(["coef_"]))
  717. ard.fit(*make_blobs())
  718. # Does not raise
  719. check_is_fitted(ard, wrap(["coef_"]))
  720. # Raises when using attribute that is not defined
  721. with pytest.raises(NotFittedError, match="is not fitted yet"):
  722. check_is_fitted(ard, wrap(["coef_bad_"]))
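# check_consistent_length accepts any mix of sequences and array-likes with matching
# first dimensions and raises TypeError for scalars and non-sequence objects.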
  723. def test_check_consistent_length():
  724. check_consistent_length([1], [2], [3], [4], [5])
  725. check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ["a", "b"])
  726. check_consistent_length([1], (2,), np.array([3]), sp.csr_matrix((1, 2)))
  727. with pytest.raises(ValueError, match="inconsistent numbers of samples"):
  728. check_consistent_length([1, 2], [1])
  729. with pytest.raises(TypeError, match=r"got <\w+ 'int'>"):
  730. check_consistent_length([1, 2], 1)
  731. with pytest.raises(TypeError, match=r"got <\w+ 'object'>"):
  732. check_consistent_length([1, 2], object())
  733. with pytest.raises(TypeError):
  734. check_consistent_length([1, 2], np.array(1))
  735. # Despite ensembles having __len__ they must raise TypeError
  736. with pytest.raises(TypeError, match="Expected sequence or array-like"):
  737. check_consistent_length([1, 2], RandomForestRegressor())
  738. # XXX: We should have a test with a string, but what is correct behaviour?
  739. def test_check_dataframe_fit_attribute():
  740. # check that a pandas dataframe with a 'fit' column does not raise an error
  741. # https://github.com/scikit-learn/scikit-learn/issues/8415
  742. try:
  743. import pandas as pd
  744. X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  745. X_df = pd.DataFrame(X, columns=["a", "b", "fit"])
  746. check_consistent_length(X_df)
  747. except ImportError:
  748. raise SkipTest("Pandas not found")
  749. def test_suppress_validation():
  750. X = np.array([0, np.inf])
  751. with pytest.raises(ValueError):
  752. assert_all_finite(X)
  753. sklearn.set_config(assume_finite=True)
  754. assert_all_finite(X)
  755. sklearn.set_config(assume_finite=False)
  756. with pytest.raises(ValueError):
  757. assert_all_finite(X)
  758. def test_check_array_series():
  759. # regression test that check_array works on pandas Series
  760. pd = importorskip("pandas")
  761. res = check_array(pd.Series([1, 2, 3]), ensure_2d=False)
  762. assert_array_equal(res, np.array([1, 2, 3]))
  763. # with categorical dtype (not a numpy dtype) (GH12699)
  764. s = pd.Series(["a", "b", "c"]).astype("category")
  765. res = check_array(s, dtype=None, ensure_2d=False)
  766. assert_array_equal(res, np.array(["a", "b", "c"], dtype=object))
  767. @pytest.mark.parametrize(
  768. "dtype", ((np.float64, np.float32), np.float64, None, "numeric")
  769. )
  770. @pytest.mark.parametrize("bool_dtype", ("bool", "boolean"))
  771. def test_check_dataframe_mixed_float_dtypes(dtype, bool_dtype):
  772. # a pandas DataFrame will coerce a boolean into an object; this is a mismatch
  773. # with np.result_type, which will return a float
  774. # check_array needs to explicitly check for bool dtype in a dataframe for
  775. # this situation
  776. # https://github.com/scikit-learn/scikit-learn/issues/15787
  777. if bool_dtype == "boolean":
  778. # boolean extension arrays were introduced in pandas 1.0
  779. pd = importorskip("pandas", minversion="1.0")
  780. else:
  781. pd = importorskip("pandas")
  782. df = pd.DataFrame(
  783. {
  784. "int": [1, 2, 3],
  785. "float": [0, 0.1, 2.1],
  786. "bool": pd.Series([True, False, True], dtype=bool_dtype),
  787. },
  788. columns=["int", "float", "bool"],
  789. )
  790. array = check_array(df, dtype=dtype)
  791. assert array.dtype == np.float64
  792. expected_array = np.array(
  793. [[1.0, 0.0, 1.0], [2.0, 0.1, 0.0], [3.0, 2.1, 1.0]], dtype=float
  794. )
  795. assert_allclose_dense_sparse(array, expected_array)
  796. def test_check_dataframe_with_only_bool():
  797. """Check that dataframe with bool return a boolean arrays."""
  798. pd = importorskip("pandas")
  799. df = pd.DataFrame({"bool": [True, False, True]})
  800. array = check_array(df, dtype=None)
  801. assert array.dtype == np.bool_
  802. assert_array_equal(array, [[True], [False], [True]])
  803. # common dtype is int for bool + int
  804. df = pd.DataFrame(
  805. {"bool": [True, False, True], "int": [1, 2, 3]},
  806. columns=["bool", "int"],
  807. )
  808. array = check_array(df, dtype="numeric")
  809. assert array.dtype == np.int64
  810. assert_array_equal(array, [[1, 1], [0, 2], [1, 3]])
  811. def test_check_dataframe_with_only_boolean():
  812. """Check that dataframe with boolean return a float array with dtype=None"""
  813. pd = importorskip("pandas", minversion="1.0")
  814. df = pd.DataFrame({"bool": pd.Series([True, False, True], dtype="boolean")})
  815. array = check_array(df, dtype=None)
  816. assert array.dtype == np.float64
  817. assert_array_equal(array, [[True], [False], [True]])
  818. class DummyMemory:
  819. def cache(self, func):
  820. return func
  821. class WrongDummyMemory:
  822. pass
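# check_memory should accept a cache directory path, None, or any object exposing a
# joblib.Memory-like interface, and reject everything else.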
  823. def test_check_memory():
  824. memory = check_memory("cache_directory")
  825. assert memory.location == "cache_directory"
  826. memory = check_memory(None)
  827. assert memory.location is None
  828. dummy = DummyMemory()
  829. memory = check_memory(dummy)
  830. assert memory is dummy
  831. msg = (
  832. "'memory' should be None, a string or have the same interface as"
  833. " joblib.Memory. Got memory='1' instead."
  834. )
  835. with pytest.raises(ValueError, match=msg):
  836. check_memory(1)
  837. dummy = WrongDummyMemory()
  838. msg = (
  839. "'memory' should be None, a string or have the same interface as"
  840. " joblib.Memory. Got memory='{}' instead.".format(dummy)
  841. )
  842. with pytest.raises(ValueError, match=msg):
  843. check_memory(dummy)
  844. @pytest.mark.parametrize("copy", [True, False])
  845. def test_check_array_memmap(copy):
  846. X = np.ones((4, 4))
  847. with TempMemmap(X, mmap_mode="r") as X_memmap:
  848. X_checked = check_array(X_memmap, copy=copy)
  849. assert np.may_share_memory(X_memmap, X_checked) == (not copy)
  850. assert X_checked.flags["WRITEABLE"] == copy
  851. @pytest.mark.parametrize(
  852. "retype",
  853. [
  854. np.asarray,
  855. sp.csr_matrix,
  856. sp.csc_matrix,
  857. sp.coo_matrix,
  858. sp.lil_matrix,
  859. sp.bsr_matrix,
  860. sp.dok_matrix,
  861. sp.dia_matrix,
  862. ],
  863. )
  864. def test_check_non_negative(retype):
  865. A = np.array([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
  866. X = retype(A)
  867. check_non_negative(X, "")
  868. X = retype([[0, 0], [0, 0]])
  869. check_non_negative(X, "")
  870. A[0, 0] = -1
  871. X = retype(A)
  872. with pytest.raises(ValueError, match="Negative "):
  873. check_non_negative(X, "")
  874. def test_check_X_y_informative_error():
  875. X = np.ones((2, 2))
  876. y = None
  877. msg = "estimator requires y to be passed, but the target y is None"
  878. with pytest.raises(ValueError, match=msg):
  879. check_X_y(X, y)
  880. msg = "RandomForestRegressor requires y to be passed, but the target y is None"
  881. with pytest.raises(ValueError, match=msg):
  882. check_X_y(X, y, estimator=RandomForestRegressor())
  883. def test_retrieve_samples_from_non_standard_shape():
  884. class TestNonNumericShape:
  885. def __init__(self):
  886. self.shape = ("not numeric",)
  887. def __len__(self):
  888. return len([1, 2, 3])
  889. X = TestNonNumericShape()
  890. assert _num_samples(X) == len(X)
  891. # check that it gives a good error if there's no __len__
  892. class TestNoLenWeirdShape:
  893. def __init__(self):
  894. self.shape = ("not numeric",)
  895. with pytest.raises(TypeError, match="Expected sequence or array-like"):
  896. _num_samples(TestNoLenWeirdShape())
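# check_scalar validates a scalar's type and range; the cases below cover both the
# accepted values and the exact error messages for invalid ones.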
  897. @pytest.mark.parametrize("x", [2, 3, 2.5, 5])
  898. def test_check_scalar_valid(x):
  899. """Test that check_scalar returns no error/warning if valid inputs are
  900. provided"""
  901. with warnings.catch_warnings():
  902. warnings.simplefilter("error")
  903. scalar = check_scalar(
  904. x,
  905. "test_name",
  906. target_type=numbers.Real,
  907. min_val=2,
  908. max_val=5,
  909. include_boundaries="both",
  910. )
  911. assert scalar == x
  912. @pytest.mark.parametrize(
  913. "x, target_name, target_type, min_val, max_val, include_boundaries, err_msg",
  914. [
  915. (
  916. 1,
  917. "test_name1",
  918. float,
  919. 2,
  920. 4,
  921. "neither",
  922. TypeError("test_name1 must be an instance of float, not int."),
  923. ),
  924. (
  925. None,
  926. "test_name1",
  927. numbers.Real,
  928. 2,
  929. 4,
  930. "neither",
  931. TypeError("test_name1 must be an instance of float, not NoneType."),
  932. ),
  933. (
  934. None,
  935. "test_name1",
  936. numbers.Integral,
  937. 2,
  938. 4,
  939. "neither",
  940. TypeError("test_name1 must be an instance of int, not NoneType."),
  941. ),
  942. (
  943. 1,
  944. "test_name1",
  945. (float, bool),
  946. 2,
  947. 4,
  948. "neither",
  949. TypeError("test_name1 must be an instance of {float, bool}, not int."),
  950. ),
  951. (
  952. 1,
  953. "test_name2",
  954. int,
  955. 2,
  956. 4,
  957. "neither",
  958. ValueError("test_name2 == 1, must be > 2."),
  959. ),
  960. (
  961. 5,
  962. "test_name3",
  963. int,
  964. 2,
  965. 4,
  966. "neither",
  967. ValueError("test_name3 == 5, must be < 4."),
  968. ),
  969. (
  970. 2,
  971. "test_name4",
  972. int,
  973. 2,
  974. 4,
  975. "right",
  976. ValueError("test_name4 == 2, must be > 2."),
  977. ),
  978. (
  979. 4,
  980. "test_name5",
  981. int,
  982. 2,
  983. 4,
  984. "left",
  985. ValueError("test_name5 == 4, must be < 4."),
  986. ),
  987. (
  988. 4,
  989. "test_name6",
  990. int,
  991. 2,
  992. 4,
  993. "bad parameter value",
  994. ValueError(
  995. "Unknown value for `include_boundaries`: 'bad parameter value'. "
  996. "Possible values are: ('left', 'right', 'both', 'neither')."
  997. ),
  998. ),
  999. (
  1000. 4,
  1001. "test_name7",
  1002. int,
  1003. None,
  1004. 4,
  1005. "left",
  1006. ValueError(
  1007. "`include_boundaries`='left' without specifying explicitly `min_val` "
  1008. "is inconsistent."
  1009. ),
  1010. ),
  1011. (
  1012. 4,
  1013. "test_name8",
  1014. int,
  1015. 2,
  1016. None,
  1017. "right",
  1018. ValueError(
  1019. "`include_boundaries`='right' without specifying explicitly `max_val` "
  1020. "is inconsistent."
  1021. ),
  1022. ),
  1023. ],
  1024. )
  1025. def test_check_scalar_invalid(
  1026. x, target_name, target_type, min_val, max_val, include_boundaries, err_msg
  1027. ):
  1028. """Test that check_scalar returns the right error if a wrong input is
  1029. given"""
  1030. with pytest.raises(Exception) as raised_error:
  1031. check_scalar(
  1032. x,
  1033. target_name,
  1034. target_type=target_type,
  1035. min_val=min_val,
  1036. max_val=max_val,
  1037. include_boundaries=include_boundaries,
  1038. )
  1039. assert str(raised_error.value) == str(err_msg)
  1040. assert type(raised_error.value) == type(err_msg)
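# Each entry below maps a case name to (input eigenvalues, expected fixed
# eigenvalues, expected warning type, expected warning message regex).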
  1041. _psd_cases_valid = {
  1042. "nominal": ((1, 2), np.array([1, 2]), None, ""),
  1043. "nominal_np_array": (np.array([1, 2]), np.array([1, 2]), None, ""),
  1044. "insignificant_imag": (
  1045. (5, 5e-5j),
  1046. np.array([5, 0]),
  1047. PositiveSpectrumWarning,
  1048. "There are imaginary parts in eigenvalues \\(1e\\-05 of the maximum real part",
  1049. ),
  1050. "insignificant neg": ((5, -5e-5), np.array([5, 0]), PositiveSpectrumWarning, ""),
  1051. "insignificant neg float32": (
  1052. np.array([1, -1e-6], dtype=np.float32),
  1053. np.array([1, 0], dtype=np.float32),
  1054. PositiveSpectrumWarning,
  1055. "There are negative eigenvalues \\(1e\\-06 of the maximum positive",
  1056. ),
  1057. "insignificant neg float64": (
  1058. np.array([1, -1e-10], dtype=np.float64),
  1059. np.array([1, 0], dtype=np.float64),
  1060. PositiveSpectrumWarning,
  1061. "There are negative eigenvalues \\(1e\\-10 of the maximum positive",
  1062. ),
  1063. "insignificant pos": (
  1064. (5, 4e-12),
  1065. np.array([5, 0]),
  1066. PositiveSpectrumWarning,
  1067. "the largest eigenvalue is more than 1e\\+12 times the smallest",
  1068. ),
  1069. }
  1070. @pytest.mark.parametrize(
  1071. "lambdas, expected_lambdas, w_type, w_msg",
  1072. list(_psd_cases_valid.values()),
  1073. ids=list(_psd_cases_valid.keys()),
  1074. )
  1075. @pytest.mark.parametrize("enable_warnings", [True, False])
  1076. def test_check_psd_eigenvalues_valid(
  1077. lambdas, expected_lambdas, w_type, w_msg, enable_warnings
  1078. ):
  1079. # Test that ``_check_psd_eigenvalues`` returns the right output for valid
  1080. # input, possibly raising the right warning
  1081. if not enable_warnings:
  1082. w_type = None
  1083. if w_type is None:
  1084. with warnings.catch_warnings():
  1085. warnings.simplefilter("error", PositiveSpectrumWarning)
  1086. lambdas_fixed = _check_psd_eigenvalues(
  1087. lambdas, enable_warnings=enable_warnings
  1088. )
  1089. else:
  1090. with pytest.warns(w_type, match=w_msg):
  1091. lambdas_fixed = _check_psd_eigenvalues(
  1092. lambdas, enable_warnings=enable_warnings
  1093. )
  1094. assert_allclose(expected_lambdas, lambdas_fixed)
  1095. _psd_cases_invalid = {
  1096. "significant_imag": (
  1097. (5, 5j),
  1098. ValueError,
  1099. "There are significant imaginary parts in eigenv",
  1100. ),
  1101. "all negative": (
  1102. (-5, -1),
  1103. ValueError,
  1104. "All eigenvalues are negative \\(maximum is -1",
  1105. ),
  1106. "significant neg": (
  1107. (5, -1),
  1108. ValueError,
  1109. "There are significant negative eigenvalues",
  1110. ),
  1111. "significant neg float32": (
  1112. np.array([3e-4, -2e-6], dtype=np.float32),
  1113. ValueError,
  1114. "There are significant negative eigenvalues",
  1115. ),
  1116. "significant neg float64": (
  1117. np.array([1e-5, -2e-10], dtype=np.float64),
  1118. ValueError,
  1119. "There are significant negative eigenvalues",
  1120. ),
  1121. }
  1122. @pytest.mark.parametrize(
  1123. "lambdas, err_type, err_msg",
  1124. list(_psd_cases_invalid.values()),
  1125. ids=list(_psd_cases_invalid.keys()),
  1126. )
  1127. def test_check_psd_eigenvalues_invalid(lambdas, err_type, err_msg):
  1128. # Test that ``_check_psd_eigenvalues`` raises the right error for invalid
  1129. # input
  1130. with pytest.raises(err_type, match=err_msg):
  1131. _check_psd_eigenvalues(lambdas)
  1132. def test_check_sample_weight():
  1133. # check array order
  1134. sample_weight = np.ones(10)[::2]
  1135. assert not sample_weight.flags["C_CONTIGUOUS"]
  1136. sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
  1137. assert sample_weight.flags["C_CONTIGUOUS"]
  1138. # check None input
  1139. sample_weight = _check_sample_weight(None, X=np.ones((5, 2)))
  1140. assert_allclose(sample_weight, np.ones(5))
  1141. # check numbers input
  1142. sample_weight = _check_sample_weight(2.0, X=np.ones((5, 2)))
  1143. assert_allclose(sample_weight, 2 * np.ones(5))
  1144. # check wrong number of dimensions
  1145. with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
  1146. _check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2)))
  1147. # check incorrect n_samples
  1148. msg = r"sample_weight.shape == \(4,\), expected \(2,\)!"
  1149. with pytest.raises(ValueError, match=msg):
  1150. _check_sample_weight(np.ones(4), X=np.ones((2, 2)))
  1151. # float32 dtype is preserved
  1152. X = np.ones((5, 2))
  1153. sample_weight = np.ones(5, dtype=np.float32)
  1154. sample_weight = _check_sample_weight(sample_weight, X)
  1155. assert sample_weight.dtype == np.float32
  1156. # int dtype will be converted to float64 instead
  1157. X = np.ones((5, 2), dtype=int)
  1158. sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
  1159. assert sample_weight.dtype == np.float64
  1160. # check negative weight when only_non_negative=True
  1161. X = np.ones((5, 2))
  1162. sample_weight = np.ones(_num_samples(X))
  1163. sample_weight[-1] = -10
  1164. err_msg = "Negative values in data passed to `sample_weight`"
  1165. with pytest.raises(ValueError, match=err_msg):
  1166. _check_sample_weight(sample_weight, X, only_non_negative=True)
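# _allclose_dense_sparse compares two dense or two sparse containers element-wise;
# comparing a dense array with a sparse matrix is an error, as tested below.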
  1167. @pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])
  1168. def test_allclose_dense_sparse_equals(toarray):
  1169. base = np.arange(9).reshape(3, 3)
  1170. x, y = toarray(base), toarray(base)
  1171. assert _allclose_dense_sparse(x, y)
  1172. @pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])
  1173. def test_allclose_dense_sparse_not_equals(toarray):
  1174. base = np.arange(9).reshape(3, 3)
  1175. x, y = toarray(base), toarray(base + 1)
  1176. assert not _allclose_dense_sparse(x, y)
  1177. @pytest.mark.parametrize("toarray", [sp.csr_matrix, sp.csc_matrix])
  1178. def test_allclose_dense_sparse_raise(toarray):
  1179. x = np.arange(9).reshape(3, 3)
  1180. y = toarray(x + 1)
  1181. msg = "Can only compare two sparse matrices, not a sparse matrix and an array"
  1182. with pytest.raises(ValueError, match=msg):
  1183. _allclose_dense_sparse(x, y)
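# _deprecate_positional_args warns with FutureWarning when arguments declared as
# keyword-only are still passed positionally, for plain functions and for __init__.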
  1184. def test_deprecate_positional_args_warns_for_function():
  1185. @_deprecate_positional_args
  1186. def f1(a, b, *, c=1, d=1):
  1187. pass
  1188. with pytest.warns(FutureWarning, match=r"Pass c=3 as keyword args"):
  1189. f1(1, 2, 3)
  1190. with pytest.warns(FutureWarning, match=r"Pass c=3, d=4 as keyword args"):
  1191. f1(1, 2, 3, 4)
  1192. @_deprecate_positional_args
  1193. def f2(a=1, *, b=1, c=1, d=1):
  1194. pass
  1195. with pytest.warns(FutureWarning, match=r"Pass b=2 as keyword args"):
  1196. f2(1, 2)
  1197. # The * is placed before a keyword-only argument without a default value
  1198. @_deprecate_positional_args
  1199. def f3(a, *, b, c=1, d=1):
  1200. pass
  1201. with pytest.warns(FutureWarning, match=r"Pass b=2 as keyword args"):
  1202. f3(1, 2)
  1203. def test_deprecate_positional_args_warns_for_function_version():
  1204. @_deprecate_positional_args(version="1.1")
  1205. def f1(a, *, b):
  1206. pass
  1207. with pytest.warns(
  1208. FutureWarning, match=r"From version 1.1 passing these as positional"
  1209. ):
  1210. f1(1, 2)
  1211. def test_deprecate_positional_args_warns_for_class():
  1212. class A1:
  1213. @_deprecate_positional_args
  1214. def __init__(self, a, b, *, c=1, d=1):
  1215. pass
  1216. with pytest.warns(FutureWarning, match=r"Pass c=3 as keyword args"):
  1217. A1(1, 2, 3)
  1218. with pytest.warns(FutureWarning, match=r"Pass c=3, d=4 as keyword args"):
  1219. A1(1, 2, 3, 4)
  1220. class A2:
  1221. @_deprecate_positional_args
  1222. def __init__(self, a=1, b=1, *, c=1, d=1):
  1223. pass
  1224. with pytest.warns(FutureWarning, match=r"Pass c=3 as keyword args"):
  1225. A2(1, 2, 3)
  1226. with pytest.warns(FutureWarning, match=r"Pass c=3, d=4 as keyword args"):
  1227. A2(1, 2, 3, 4)
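# _check_fit_params should index array-like fit parameters that align with the
# samples axis and pass everything else (scalars, None, row vectors) through as-is.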
  1228. @pytest.mark.parametrize("indices", [None, [1, 3]])
  1229. def test_check_fit_params(indices):
  1230. X = np.random.randn(4, 2)
  1231. fit_params = {
  1232. "list": [1, 2, 3, 4],
  1233. "array": np.array([1, 2, 3, 4]),
  1234. "sparse-col": sp.csc_matrix([1, 2, 3, 4]).T,
  1235. "sparse-row": sp.csc_matrix([1, 2, 3, 4]),
  1236. "scalar-int": 1,
  1237. "scalar-str": "xxx",
  1238. "None": None,
  1239. }
  1240. result = _check_fit_params(X, fit_params, indices)
  1241. indices_ = indices if indices is not None else list(range(X.shape[0]))
  1242. for key in ["sparse-row", "scalar-int", "scalar-str", "None"]:
  1243. assert result[key] is fit_params[key]
  1244. assert result["list"] == _safe_indexing(fit_params["list"], indices_)
  1245. assert_array_equal(result["array"], _safe_indexing(fit_params["array"], indices_))
  1246. assert_allclose_dense_sparse(
  1247. result["sparse-col"], _safe_indexing(fit_params["sparse-col"], indices_)
  1248. )
  1249. @pytest.mark.parametrize("sp_format", [True, "csr", "csc", "coo", "bsr"])
  1250. def test_c