PageRenderTime 27ms CodeModel.GetById 14ms RepoModel.GetById 0ms app.codeStats 0ms

/sklearn/preprocessing/tests/test_common.py

https://github.com/ogrisel/scikit-learn
Python | 165 lines | 123 code | 24 blank | 18 comment | 12 complexity | 77284fe61c17e1e48343ea476c153239 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. import warnings
  2. import pytest
  3. import numpy as np
  4. from scipy import sparse
  5. from sklearn.datasets import load_iris
  6. from sklearn.model_selection import train_test_split
  7. from sklearn.base import clone
  8. from sklearn.preprocessing import maxabs_scale
  9. from sklearn.preprocessing import minmax_scale
  10. from sklearn.preprocessing import scale
  11. from sklearn.preprocessing import power_transform
  12. from sklearn.preprocessing import quantile_transform
  13. from sklearn.preprocessing import robust_scale
  14. from sklearn.preprocessing import MaxAbsScaler
  15. from sklearn.preprocessing import MinMaxScaler
  16. from sklearn.preprocessing import StandardScaler
  17. from sklearn.preprocessing import PowerTransformer
  18. from sklearn.preprocessing import QuantileTransformer
  19. from sklearn.preprocessing import RobustScaler
  20. from sklearn.utils._testing import assert_array_equal
  21. from sklearn.utils._testing import assert_allclose
# Module-level dataset shared by the tests in this file. Tests that inject
# NaNs (test_missing_value_handling) work on ``iris.data.copy()`` so this
# shared object is never mutated.
iris = load_iris()
  23. def _get_valid_samples_by_column(X, col):
  24. """Get non NaN samples in column of X"""
  25. return X[:, [col]][~np.isnan(X[:, col])]
  26. @pytest.mark.parametrize(
  27. "est, func, support_sparse, strictly_positive, omit_kwargs",
  28. [(MaxAbsScaler(), maxabs_scale, True, False, []),
  29. (MinMaxScaler(), minmax_scale, False, False, ['clip']),
  30. (StandardScaler(), scale, False, False, []),
  31. (StandardScaler(with_mean=False), scale, True, False, []),
  32. (PowerTransformer('yeo-johnson'), power_transform, False, False, []),
  33. (PowerTransformer('box-cox'), power_transform, False, True, []),
  34. (QuantileTransformer(n_quantiles=10), quantile_transform, True, False,
  35. []),
  36. (RobustScaler(), robust_scale, False, False, []),
  37. (RobustScaler(with_centering=False), robust_scale, True, False, [])]
  38. )
  39. def test_missing_value_handling(est, func, support_sparse, strictly_positive,
  40. omit_kwargs):
  41. # check that the preprocessing method let pass nan
  42. rng = np.random.RandomState(42)
  43. X = iris.data.copy()
  44. n_missing = 50
  45. X[rng.randint(X.shape[0], size=n_missing),
  46. rng.randint(X.shape[1], size=n_missing)] = np.nan
  47. if strictly_positive:
  48. X += np.nanmin(X) + 0.1
  49. X_train, X_test = train_test_split(X, random_state=1)
  50. # sanity check
  51. assert not np.all(np.isnan(X_train), axis=0).any()
  52. assert np.any(np.isnan(X_train), axis=0).all()
  53. assert np.any(np.isnan(X_test), axis=0).all()
  54. X_test[:, 0] = np.nan # make sure this boundary case is tested
  55. with pytest.warns(None) as records:
  56. Xt = est.fit(X_train).transform(X_test)
  57. # ensure no warnings are raised
  58. assert len(records) == 0
  59. # missing values should still be missing, and only them
  60. assert_array_equal(np.isnan(Xt), np.isnan(X_test))
  61. # check that the function leads to the same results as the class
  62. with pytest.warns(None) as records:
  63. Xt_class = est.transform(X_train)
  64. assert len(records) == 0
  65. kwargs = est.get_params()
  66. # remove the parameters which should be omitted because they
  67. # are not defined in the sister function of the preprocessing class
  68. for kwarg in omit_kwargs:
  69. _ = kwargs.pop(kwarg)
  70. Xt_func = func(X_train, **kwargs)
  71. assert_array_equal(np.isnan(Xt_func), np.isnan(Xt_class))
  72. assert_allclose(Xt_func[~np.isnan(Xt_func)], Xt_class[~np.isnan(Xt_class)])
  73. # check that the inverse transform keep NaN
  74. Xt_inv = est.inverse_transform(Xt)
  75. assert_array_equal(np.isnan(Xt_inv), np.isnan(X_test))
  76. # FIXME: we can introduce equal_nan=True in recent version of numpy.
  77. # For the moment which just check that non-NaN values are almost equal.
  78. assert_allclose(Xt_inv[~np.isnan(Xt_inv)], X_test[~np.isnan(X_test)])
  79. for i in range(X.shape[1]):
  80. # train only on non-NaN
  81. est.fit(_get_valid_samples_by_column(X_train, i))
  82. # check transforming with NaN works even when training without NaN
  83. with pytest.warns(None) as records:
  84. Xt_col = est.transform(X_test[:, [i]])
  85. assert len(records) == 0
  86. assert_allclose(Xt_col, Xt[:, [i]])
  87. # check non-NaN is handled as before - the 1st column is all nan
  88. if not np.isnan(X_test[:, i]).all():
  89. Xt_col_nonan = est.transform(
  90. _get_valid_samples_by_column(X_test, i))
  91. assert_array_equal(Xt_col_nonan,
  92. Xt_col[~np.isnan(Xt_col.squeeze())])
  93. if support_sparse:
  94. est_dense = clone(est)
  95. est_sparse = clone(est)
  96. with pytest.warns(None) as records:
  97. Xt_dense = est_dense.fit(X_train).transform(X_test)
  98. Xt_inv_dense = est_dense.inverse_transform(Xt_dense)
  99. assert len(records) == 0
  100. for sparse_constructor in (sparse.csr_matrix, sparse.csc_matrix,
  101. sparse.bsr_matrix, sparse.coo_matrix,
  102. sparse.dia_matrix, sparse.dok_matrix,
  103. sparse.lil_matrix):
  104. # check that the dense and sparse inputs lead to the same results
  105. # precompute the matrix to avoid catching side warnings
  106. X_train_sp = sparse_constructor(X_train)
  107. X_test_sp = sparse_constructor(X_test)
  108. with pytest.warns(None) as records:
  109. warnings.simplefilter('ignore', PendingDeprecationWarning)
  110. Xt_sp = est_sparse.fit(X_train_sp).transform(X_test_sp)
  111. assert len(records) == 0
  112. assert_allclose(Xt_sp.A, Xt_dense)
  113. with pytest.warns(None) as records:
  114. warnings.simplefilter('ignore', PendingDeprecationWarning)
  115. Xt_inv_sp = est_sparse.inverse_transform(Xt_sp)
  116. assert len(records) == 0
  117. assert_allclose(Xt_inv_sp.A, Xt_inv_dense)
  118. @pytest.mark.parametrize(
  119. "est, func",
  120. [(MaxAbsScaler(), maxabs_scale),
  121. (MinMaxScaler(), minmax_scale),
  122. (StandardScaler(), scale),
  123. (StandardScaler(with_mean=False), scale),
  124. (PowerTransformer('yeo-johnson'), power_transform),
  125. (PowerTransformer('box-cox'), power_transform,),
  126. (QuantileTransformer(n_quantiles=3), quantile_transform),
  127. (RobustScaler(), robust_scale),
  128. (RobustScaler(with_centering=False), robust_scale)]
  129. )
  130. def test_missing_value_pandas_na_support(est, func):
  131. # Test pandas IntegerArray with pd.NA
  132. pd = pytest.importorskip('pandas', minversion="1.0")
  133. X = np.array([[1, 2, 3, np.nan, np.nan, 4, 5, 1],
  134. [np.nan, np.nan, 8, 4, 6, np.nan, np.nan, 8],
  135. [1, 2, 3, 4, 5, 6, 7, 8]]).T
  136. # Creates dataframe with IntegerArrays with pd.NA
  137. X_df = pd.DataFrame(X, dtype="Int16", columns=['a', 'b', 'c'])
  138. X_df['c'] = X_df['c'].astype('int')
  139. X_trans = est.fit_transform(X)
  140. X_df_trans = est.fit_transform(X_df)
  141. assert_allclose(X_trans, X_df_trans)