# Source file: /sklearn/preprocessing/tests/test_common.py
# Repository: https://github.com/scikit-learn/scikit-learn
# Python | 183 lines | 138 code | 27 blank | 18 comment | 12 complexity
# MD5: 246502bf66fca4287f75e1a7e23e392b
# Possible license(s): BSD-3-Clause
import warnings

import pytest
import numpy as np
from scipy import sparse

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.base import clone
from sklearn.preprocessing import maxabs_scale
from sklearn.preprocessing import minmax_scale
from sklearn.preprocessing import scale
from sklearn.preprocessing import power_transform
from sklearn.preprocessing import quantile_transform
from sklearn.preprocessing import robust_scale
from sklearn.preprocessing import MaxAbsScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PowerTransformer
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import RobustScaler
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_allclose
  22. iris = load_iris()
  23. def _get_valid_samples_by_column(X, col):
  24. """Get non NaN samples in column of X"""
  25. return X[:, [col]][~np.isnan(X[:, col])]
  26. @pytest.mark.parametrize(
  27. "est, func, support_sparse, strictly_positive, omit_kwargs",
  28. [
  29. (MaxAbsScaler(), maxabs_scale, True, False, []),
  30. (MinMaxScaler(), minmax_scale, False, False, ["clip"]),
  31. (StandardScaler(), scale, False, False, []),
  32. (StandardScaler(with_mean=False), scale, True, False, []),
  33. (PowerTransformer("yeo-johnson"), power_transform, False, False, []),
  34. (PowerTransformer("box-cox"), power_transform, False, True, []),
  35. (QuantileTransformer(n_quantiles=10), quantile_transform, True, False, []),
  36. (RobustScaler(), robust_scale, False, False, []),
  37. (RobustScaler(with_centering=False), robust_scale, True, False, []),
  38. ],
  39. )
  40. def test_missing_value_handling(
  41. est, func, support_sparse, strictly_positive, omit_kwargs
  42. ):
  43. # check that the preprocessing method let pass nan
  44. rng = np.random.RandomState(42)
  45. X = iris.data.copy()
  46. n_missing = 50
  47. X[
  48. rng.randint(X.shape[0], size=n_missing), rng.randint(X.shape[1], size=n_missing)
  49. ] = np.nan
  50. if strictly_positive:
  51. X += np.nanmin(X) + 0.1
  52. X_train, X_test = train_test_split(X, random_state=1)
  53. # sanity check
  54. assert not np.all(np.isnan(X_train), axis=0).any()
  55. assert np.any(np.isnan(X_train), axis=0).all()
  56. assert np.any(np.isnan(X_test), axis=0).all()
  57. X_test[:, 0] = np.nan # make sure this boundary case is tested
  58. with warnings.catch_warnings():
  59. warnings.simplefilter("error", RuntimeWarning)
  60. Xt = est.fit(X_train).transform(X_test)
  61. # ensure no warnings are raised
  62. # missing values should still be missing, and only them
  63. assert_array_equal(np.isnan(Xt), np.isnan(X_test))
  64. # check that the function leads to the same results as the class
  65. with warnings.catch_warnings():
  66. warnings.simplefilter("error", RuntimeWarning)
  67. Xt_class = est.transform(X_train)
  68. kwargs = est.get_params()
  69. # remove the parameters which should be omitted because they
  70. # are not defined in the counterpart function of the preprocessing class
  71. for kwarg in omit_kwargs:
  72. _ = kwargs.pop(kwarg)
  73. Xt_func = func(X_train, **kwargs)
  74. assert_array_equal(np.isnan(Xt_func), np.isnan(Xt_class))
  75. assert_allclose(Xt_func[~np.isnan(Xt_func)], Xt_class[~np.isnan(Xt_class)])
  76. # check that the inverse transform keep NaN
  77. Xt_inv = est.inverse_transform(Xt)
  78. assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test))
  79. # FIXME: we can introduce equal_nan=True in recent version of numpy.
  80. # For the moment which just check that non-NaN values are almost equal.
  81. assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)])
  82. for i in range(X.shape[1]):
  83. # train only on non-NaN
  84. est.fit(_get_valid_samples_by_column(X_train, i))
  85. # check transforming with NaN works even when training without NaN
  86. with warnings.catch_warnings():
  87. warnings.simplefilter("error", RuntimeWarning)
  88. Xt_col = est.transform(X_test[:, [i]])
  89. assert_allclose(Xt_col, Xt[:, [i]])
  90. # check non-NaN is handled as before - the 1st column is all nan
  91. if not np.isnan(X_test[:, i]).all():
  92. Xt_col_nonan = est.transform(_get_valid_samples_by_column(X_test, i))
  93. assert_array_equal(Xt_col_nonan, Xt_col[~np.isnan(Xt_col.squeeze())])
  94. if support_sparse:
  95. est_dense = clone(est)
  96. est_sparse = clone(est)
  97. with warnings.catch_warnings():
  98. warnings.simplefilter("error", RuntimeWarning)
  99. Xt_dense = est_dense.fit(X_train).transform(X_test)
  100. Xt_inv_dense = est_dense.inverse_transform(Xt_dense)
  101. for sparse_constructor in (
  102. sparse.csr_matrix,
  103. sparse.csc_matrix,
  104. sparse.bsr_matrix,
  105. sparse.coo_matrix,
  106. sparse.dia_matrix,
  107. sparse.dok_matrix,
  108. sparse.lil_matrix,
  109. ):
  110. # check that the dense and sparse inputs lead to the same results
  111. # precompute the matrix to avoid catching side warnings
  112. X_train_sp = sparse_constructor(X_train)
  113. X_test_sp = sparse_constructor(X_test)
  114. with warnings.catch_warnings():
  115. warnings.simplefilter("ignore", PendingDeprecationWarning)
  116. warnings.simplefilter("error", RuntimeWarning)
  117. Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp)
  118. assert_allclose(Xt_sp.A, Xt_dense)
  119. with warnings.catch_warnings():
  120. warnings.simplefilter("ignore", PendingDeprecationWarning)
  121. warnings.simplefilter("error", RuntimeWarning)
  122. Xt_inv_sp = est_sparse.inverse_transform(Xt_sp)
  123. assert_allclose(Xt_inv_sp.A, Xt_inv_dense)
  124. @pytest.mark.parametrize(
  125. "est, func",
  126. [
  127. (MaxAbsScaler(), maxabs_scale),
  128. (MinMaxScaler(), minmax_scale),
  129. (StandardScaler(), scale),
  130. (StandardScaler(with_mean=False), scale),
  131. (PowerTransformer("yeo-johnson"), power_transform),
  132. (
  133. PowerTransformer("box-cox"),
  134. power_transform,
  135. ),
  136. (QuantileTransformer(n_quantiles=3), quantile_transform),
  137. (RobustScaler(), robust_scale),
  138. (RobustScaler(with_centering=False), robust_scale),
  139. ],
  140. )
  141. def test_missing_value_pandas_na_support(est, func):
  142. # Test pandas IntegerArray with pd.NA
  143. pd = pytest.importorskip("pandas")
  144. X = np.array(
  145. [
  146. [1, 2, 3, np.nan, np.nan, 4, 5, 1],
  147. [np.nan, np.nan, 8, 4, 6, np.nan, np.nan, 8],
  148. [1, 2, 3, 4, 5, 6, 7, 8],
  149. ]
  150. ).T
  151. # Creates dataframe with IntegerArrays with pd.NA
  152. X_df = pd.DataFrame(X, dtype="Int16", columns=["a", "b", "c"])
  153. X_df["c"] = X_df["c"].astype("int")
  154. X_trans = est.fit_transform(X)
  155. X_df_trans = est.fit_transform(X_df)
  156. assert_allclose(X_trans, X_df_trans)