chemotools 0.0.27__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (53)
  1. chemotools/augmentation/__init__.py +16 -0
  2. chemotools/augmentation/baseline_shift.py +119 -0
  3. chemotools/augmentation/exponential_noise.py +117 -0
  4. chemotools/augmentation/index_shift.py +120 -0
  5. chemotools/augmentation/normal_noise.py +118 -0
  6. chemotools/augmentation/spectrum_scale.py +120 -0
  7. chemotools/augmentation/uniform_noise.py +124 -0
  8. chemotools/baseline/__init__.py +20 -8
  9. chemotools/baseline/{air_pls.py → _air_pls.py} +20 -32
  10. chemotools/baseline/{ar_pls.py → _ar_pls.py} +18 -31
  11. chemotools/baseline/{constant_baseline_correction.py → _constant_baseline_correction.py} +22 -30
  12. chemotools/baseline/{cubic_spline_correction.py → _cubic_spline_correction.py} +26 -19
  13. chemotools/baseline/{linear_correction.py → _linear_correction.py} +19 -28
  14. chemotools/baseline/{non_negative.py → _non_negative.py} +15 -23
  15. chemotools/baseline/{polynomial_correction.py → _polynomial_correction.py} +29 -31
  16. chemotools/baseline/{subtract_reference.py → _subtract_reference.py} +23 -27
  17. chemotools/datasets/__init__.py +3 -0
  18. chemotools/datasets/_base.py +85 -15
  19. chemotools/datasets/data/coffee_labels.csv +61 -0
  20. chemotools/datasets/data/coffee_spectra.csv +61 -0
  21. chemotools/derivative/__init__.py +4 -2
  22. chemotools/derivative/{norris_william.py → _norris_william.py} +17 -24
  23. chemotools/derivative/{savitzky_golay.py → _savitzky_golay.py} +26 -36
  24. chemotools/feature_selection/__init__.py +4 -0
  25. chemotools/{variable_selection/select_features.py → feature_selection/_index_selector.py} +32 -56
  26. chemotools/{variable_selection/range_cut.py → feature_selection/_range_cut.py} +25 -50
  27. chemotools/scale/__init__.py +5 -3
  28. chemotools/scale/{min_max_scaler.py → _min_max_scaler.py} +20 -27
  29. chemotools/scale/{norm_scaler.py → _norm_scaler.py} +18 -25
  30. chemotools/scale/{point_scaler.py → _point_scaler.py} +27 -32
  31. chemotools/scatter/__init__.py +13 -4
  32. chemotools/scatter/{extended_multiplicative_scatter_correction.py → _extended_multiplicative_scatter_correction.py} +19 -28
  33. chemotools/scatter/{multiplicative_scatter_correction.py → _multiplicative_scatter_correction.py} +19 -17
  34. chemotools/scatter/{robust_normal_variate.py → _robust_normal_variate.py} +15 -23
  35. chemotools/scatter/{standard_normal_variate.py → _standard_normal_variate.py} +21 -26
  36. chemotools/smooth/__init__.py +6 -4
  37. chemotools/smooth/{mean_filter.py → _mean_filter.py} +18 -25
  38. chemotools/smooth/{median_filter.py → _median_filter.py} +32 -24
  39. chemotools/smooth/{savitzky_golay_filter.py → _savitzky_golay_filter.py} +22 -24
  40. chemotools/smooth/{whittaker_smooth.py → _whittaker_smooth.py} +24 -29
  41. {chemotools-0.0.27.dist-info → chemotools-0.1.6.dist-info}/METADATA +19 -16
  42. chemotools-0.1.6.dist-info/RECORD +51 -0
  43. {chemotools-0.0.27.dist-info → chemotools-0.1.6.dist-info}/WHEEL +1 -2
  44. chemotools/utils/check_inputs.py +0 -14
  45. chemotools/variable_selection/__init__.py +0 -2
  46. chemotools-0.0.27.dist-info/RECORD +0 -49
  47. chemotools-0.0.27.dist-info/top_level.txt +0 -2
  48. tests/__init__.py +0 -0
  49. tests/fixtures.py +0 -89
  50. tests/test_datasets.py +0 -30
  51. tests/test_functionality.py +0 -616
  52. tests/test_sklearn_compliance.py +0 -220
  53. {chemotools-0.0.27.dist-info → chemotools-0.1.6.dist-info}/LICENSE +0 -0
@@ -1,18 +1,18 @@
1
+ from typing import Optional
2
+
1
3
  import numpy as np
2
4
  from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
3
- from sklearn.utils.validation import check_is_fitted
4
-
5
- from chemotools.utils.check_inputs import check_input
5
+ from sklearn.utils.validation import check_is_fitted, validate_data
6
6
 
7
7
 
8
- class PointScaler(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
8
+ class PointScaler(TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
9
9
  """
10
- A transformer that scales the input data by the intensity value at a given point.
10
+ A transformer that scales the input data by the intensity value at a given point.
11
11
  The point can be specified by an index or by a wavenumber.
12
12
 
13
13
  Parameters
14
14
  ----------
15
- point : int,
15
+ point : int,
16
16
  The point to scale the data by. It can be an index or a wavenumber.
17
17
 
18
18
  wavenumber : array-like, optional
@@ -25,12 +25,6 @@ class PointScaler(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
25
25
  point_index_ : int
26
26
  The index of the point to scale the data by. It is 0 if the wavenumbers are not provided.
27
27
 
28
- n_features_in_ : int
29
- The number of features in the input data.
30
-
31
- _is_fitted : bool
32
- Whether the transformer has been fitted to data.
33
-
34
28
  Methods
35
29
  -------
36
30
  fit(X, y=None)
@@ -39,11 +33,11 @@ class PointScaler(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
39
33
  transform(X, y=0, copy=True)
40
34
  Transform the input data by scaling by the value at a given Point.
41
35
  """
42
- def __init__(self, point: int = 0, wavenumbers: np.ndarray = None):
36
+
37
+ def __init__(self, point: int = 0, wavenumbers: Optional[np.ndarray] = None):
43
38
  self.point = point
44
39
  self.wavenumbers = wavenumbers
45
40
 
46
-
47
41
  def fit(self, X: np.ndarray, y=None) -> "PointScaler":
48
42
  """
49
43
  Fit the transformer to the input data.
@@ -62,21 +56,15 @@ class PointScaler(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
62
56
  The fitted transformer.
63
57
  """
64
58
  # Check that X is a 2D array and has only finite values
65
- X = check_input(X)
66
-
67
- # Set the number of features
68
- self.n_features_in_ = X.shape[1]
69
-
70
- # Set the fitted attribute to True
71
- self._is_fitted = True
72
-
59
+ X = validate_data(
60
+ self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
61
+ )
73
62
  # Set the point index
74
63
  if self.wavenumbers is None:
75
64
  self.point_index_ = self.point
76
65
  else:
77
66
  self.point_index_ = self._find_index(self.point)
78
67
 
79
-
80
68
  return self
81
69
 
82
70
  def transform(self, X: np.ndarray, y=None) -> np.ndarray:
@@ -97,24 +85,31 @@ class PointScaler(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
97
85
  The transformed data.
98
86
  """
99
87
  # Check that the estimator is fitted
100
- check_is_fitted(self, "_is_fitted")
88
+ check_is_fitted(self, "point_index_")
101
89
 
102
90
  # Check that X is a 2D array and has only finite values
103
- X = check_input(X)
104
- X_ = X.copy()
91
+ X_ = validate_data(
92
+ self,
93
+ X,
94
+ y="no_validation",
95
+ ensure_2d=True,
96
+ copy=True,
97
+ reset=False,
98
+ dtype=np.float64,
99
+ )
105
100
 
106
101
  # Check that the number of features is the same as the fitted data
107
102
  if X_.shape[1] != self.n_features_in_:
108
- raise ValueError(f"Expected {self.n_features_in_} features but got {X_.shape[1]}")
103
+ raise ValueError(
104
+ f"Expected {self.n_features_in_} features but got {X_.shape[1]}"
105
+ )
109
106
 
110
107
  # Scale the data by Point
111
108
  for i, x in enumerate(X_):
112
109
  X_[i] = x / x[self.point_index_]
113
-
110
+
114
111
  return X_.reshape(-1, 1) if X_.ndim == 1 else X_
115
-
112
+
116
113
  def _find_index(self, target: float) -> int:
117
- if self.wavenumbers is None:
118
- return target
119
114
  wavenumbers = np.array(self.wavenumbers)
120
- return np.argmin(np.abs(wavenumbers - target))
115
+ return int(np.argmin(np.abs(wavenumbers - target)))
@@ -1,4 +1,13 @@
1
- from .extended_multiplicative_scatter_correction import ExtendedMultiplicativeScatterCorrection
2
- from .multiplicative_scatter_correction import MultiplicativeScatterCorrection
3
- from .robust_normal_variate import RobustNormalVariate
4
- from .standard_normal_variate import StandardNormalVariate
1
+ from ._extended_multiplicative_scatter_correction import (
2
+ ExtendedMultiplicativeScatterCorrection,
3
+ )
4
+ from ._multiplicative_scatter_correction import MultiplicativeScatterCorrection
5
+ from ._robust_normal_variate import RobustNormalVariate
6
+ from ._standard_normal_variate import StandardNormalVariate
7
+
8
+ __all__ = [
9
+ "ExtendedMultiplicativeScatterCorrection",
10
+ "MultiplicativeScatterCorrection",
11
+ "RobustNormalVariate",
12
+ "StandardNormalVariate",
13
+ ]
@@ -1,13 +1,12 @@
1
+ from typing import Optional
2
+
1
3
  import numpy as np
2
4
  from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
3
- from sklearn.preprocessing import StandardScaler
4
- from sklearn.utils.validation import check_is_fitted
5
-
6
- from chemotools.utils.check_inputs import check_input
5
+ from sklearn.utils.validation import check_is_fitted, validate_data
7
6
 
8
7
 
9
8
  class ExtendedMultiplicativeScatterCorrection(
10
- OneToOneFeatureMixin, BaseEstimator, TransformerMixin
9
+ TransformerMixin, OneToOneFeatureMixin, BaseEstimator
11
10
  ):
12
11
  """Extended multiplicative scatter correction (EMSC) is a preprocessing technique for
13
12
  removing non linear scatter effects from spectra. It is based on fitting a polynomial
@@ -37,8 +36,6 @@ class ExtendedMultiplicativeScatterCorrection(
37
36
  ----------
38
37
  reference_ : np.ndarray
39
38
  The reference spectrum used for the correction.
40
- n_features_in_ : int
41
- The number of features in the training data.
42
39
 
43
40
  References
44
41
  ----------
@@ -51,11 +48,11 @@ class ExtendedMultiplicativeScatterCorrection(
51
48
 
52
49
  def __init__(
53
50
  self,
54
- reference: np.ndarray = None,
51
+ reference: Optional[np.ndarray] = None,
55
52
  use_mean: bool = True,
56
53
  use_median: bool = False,
57
54
  order: int = 2,
58
- weights: np.ndarray = None,
55
+ weights: Optional[np.ndarray] = None,
59
56
  ):
60
57
  self.reference = reference
61
58
  self.use_mean = use_mean
@@ -82,13 +79,9 @@ class ExtendedMultiplicativeScatterCorrection(
82
79
  The fitted transformer.
83
80
  """
84
81
  # Check that X is a 2D array and has only finite values
85
- X = check_input(X)
86
-
87
- # Set the number of features
88
- self.n_features_in_ = X.shape[1]
89
-
90
- # Set the fitted attribute to True
91
- self._is_fitted = True
82
+ X = validate_data(
83
+ self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
84
+ )
92
85
 
93
86
  # Check that the length of the reference is the same as the number of features
94
87
  if self.reference is not None:
@@ -146,20 +139,18 @@ class ExtendedMultiplicativeScatterCorrection(
146
139
  The transformed data.
147
140
  """
148
141
  # Check that the estimator is fitted
149
- check_is_fitted(self, "_is_fitted")
142
+ check_is_fitted(self, "n_features_in_")
150
143
 
151
144
  # Check that X is a 2D array and has only finite values
152
- X = check_input(X)
153
- X_ = X.copy()
154
-
155
- # Check that the number of features is the same as the fitted data
156
- if X_.shape[1] != self.n_features_in_:
157
- raise ValueError(
158
- f"Expected {self.n_features_in_} features but got {X_.shape[1]}"
159
- )
160
-
161
- # Calculate the extended multiplicative scatter correction
162
- X_ = X.copy()
145
+ X_ = validate_data(
146
+ self,
147
+ X,
148
+ y="no_validation",
149
+ ensure_2d=True,
150
+ copy=True,
151
+ reset=False,
152
+ dtype=np.float64,
153
+ )
163
154
 
164
155
  if self.weights is None:
165
156
  for i, x in enumerate(X_):
@@ -1,12 +1,12 @@
1
+ from typing import Optional
2
+
1
3
  import numpy as np
2
4
  from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
3
- from sklearn.utils.validation import check_is_fitted
4
-
5
- from chemotools.utils.check_inputs import check_input
5
+ from sklearn.utils.validation import check_is_fitted, validate_data
6
6
 
7
7
 
8
8
  class MultiplicativeScatterCorrection(
9
- OneToOneFeatureMixin, BaseEstimator, TransformerMixin
9
+ TransformerMixin, OneToOneFeatureMixin, BaseEstimator
10
10
  ):
11
11
  """Multiplicative scatter correction (MSC) is a preprocessing technique for
12
12
  removing scatter effects from spectra. It is based on fitting a linear
@@ -39,10 +39,10 @@ class MultiplicativeScatterCorrection(
39
39
 
40
40
  def __init__(
41
41
  self,
42
- reference: np.ndarray = None,
42
+ reference: Optional[np.ndarray] = None,
43
43
  use_mean: bool = True,
44
44
  use_median: bool = False,
45
- weights: np.ndarray = None,
45
+ weights: Optional[np.ndarray] = None,
46
46
  ):
47
47
  self.reference = reference
48
48
  self.use_mean = use_mean
@@ -68,14 +68,9 @@ class MultiplicativeScatterCorrection(
68
68
  The fitted transformer.
69
69
  """
70
70
  # Check that X is a 2D array and has only finite values
71
- X = check_input(X)
72
-
73
- # Set the number of features
74
- self.n_features_in_ = X.shape[1]
75
-
76
- # Set the fitted attribute to True
77
- self._is_fitted = True
78
-
71
+ X = validate_data(
72
+ self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
73
+ )
79
74
  # Check that the length of the reference is the same as the number of features
80
75
  if self.reference is not None:
81
76
  if len(self.reference) != self.n_features_in_:
@@ -129,11 +124,18 @@ class MultiplicativeScatterCorrection(
129
124
  The transformed data.
130
125
  """
131
126
  # Check that the estimator is fitted
132
- check_is_fitted(self, "_is_fitted")
127
+ check_is_fitted(self, "n_features_in_")
133
128
 
134
129
  # Check that X is a 2D array and has only finite values
135
- X = check_input(X)
136
- X_ = X.copy()
130
+ X_ = validate_data(
131
+ self,
132
+ X,
133
+ y="no_validation",
134
+ ensure_2d=True,
135
+ copy=True,
136
+ reset=False,
137
+ dtype=np.float64,
138
+ )
137
139
 
138
140
  # Check that the number of features is the same as the fitted data
139
141
  if X_.shape[1] != self.n_features_in_:
@@ -1,11 +1,9 @@
1
1
  import numpy as np
2
2
  from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
3
- from sklearn.utils.validation import check_is_fitted
3
+ from sklearn.utils.validation import check_is_fitted, validate_data
4
4
 
5
- from chemotools.utils.check_inputs import check_input
6
5
 
7
-
8
- class RobustNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
6
+ class RobustNormalVariate(TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
9
7
  """
10
8
  A transformer that calculates the robust normal variate of the input data.
11
9
 
@@ -15,14 +13,6 @@ class RobustNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixin)
15
13
  The percentile to use for the robust normal variate. The value should be
16
14
  between 0 and 100. The default is 25.
17
15
 
18
- Attributes
19
- ----------
20
- n_features_in_ : int
21
- The number of features in the input data.
22
-
23
- _is_fitted : bool
24
- Whether the transformer has been fitted to data.
25
-
26
16
  Methods
27
17
  -------
28
18
  fit(X, y=None)
@@ -58,14 +48,9 @@ class RobustNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixin)
58
48
  The fitted transformer.
59
49
  """
60
50
  # Check that X is a 2D array and has only finite values
61
- X = check_input(X)
62
-
63
- # Set the number of features
64
- self.n_features_in_ = X.shape[1]
65
-
66
- # Set the fitted attribute to True
67
- self._is_fitted = True
68
-
51
+ X = validate_data(
52
+ self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
53
+ )
69
54
  return self
70
55
 
71
56
  def transform(self, X: np.ndarray, y=None) -> np.ndarray:
@@ -86,11 +71,18 @@ class RobustNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixin)
86
71
  The transformed data.
87
72
  """
88
73
  # Check that the estimator is fitted
89
- check_is_fitted(self, "_is_fitted")
74
+ check_is_fitted(self, "n_features_in_")
90
75
 
91
76
  # Check that X is a 2D array and has only finite values
92
- X = check_input(X)
93
- X_ = X.copy()
77
+ X_ = validate_data(
78
+ self,
79
+ X,
80
+ y="no_validation",
81
+ ensure_2d=True,
82
+ copy=True,
83
+ reset=False,
84
+ dtype=np.float64,
85
+ )
94
86
 
95
87
  # Check that the number of features is the same as the fitted data
96
88
  if X_.shape[1] != self.n_features_in_:
@@ -1,22 +1,12 @@
1
1
  import numpy as np
2
2
  from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
3
- from sklearn.utils.validation import check_is_fitted
3
+ from sklearn.utils.validation import check_is_fitted, validate_data
4
4
 
5
- from chemotools.utils.check_inputs import check_input
6
5
 
7
-
8
- class StandardNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
6
+ class StandardNormalVariate(TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
9
7
  """
10
8
  A transformer that calculates the standard normal variate of the input data.
11
9
 
12
- Attributes
13
- ----------
14
- n_features_in_ : int
15
- The number of features in the input data.
16
-
17
- _is_fitted : bool
18
- Whether the transformer has been fitted to data.
19
-
20
10
  Methods
21
11
  -------
22
12
  fit(X, y=None)
@@ -25,10 +15,11 @@ class StandardNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixi
25
15
  transform(X, y=0, copy=True)
26
16
  Transform the input data by calculating the standard normal variate.
27
17
  """
18
+
28
19
  def fit(self, X: np.ndarray, y=None) -> "StandardNormalVariate":
29
20
  """
30
21
  Fit the transformer to the input data.
31
-
22
+
32
23
  Parameters
33
24
  ----------
34
25
  X : np.ndarray of shape (n_samples, n_features)
@@ -43,14 +34,9 @@ class StandardNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixi
43
34
  The fitted transformer.
44
35
  """
45
36
  # Check that X is a 2D array and has only finite values
46
- X = check_input(X)
47
-
48
- # Set the number of features
49
- self.n_features_in_ = X.shape[1]
50
-
51
- # Set the fitted attribute to True
52
- self._is_fitted = True
53
-
37
+ X = validate_data(
38
+ self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
39
+ )
54
40
  return self
55
41
 
56
42
  def transform(self, X: np.ndarray, y=None) -> np.ndarray:
@@ -71,15 +57,24 @@ class StandardNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixi
71
57
  The transformed data.
72
58
  """
73
59
  # Check that the estimator is fitted
74
- check_is_fitted(self, "_is_fitted")
60
+ check_is_fitted(self, "n_features_in_")
75
61
 
76
62
  # Check that X is a 2D array and has only finite values
77
- X = check_input(X)
78
- X_ = X.copy()
63
+ X_ = validate_data(
64
+ self,
65
+ X,
66
+ y="no_validation",
67
+ ensure_2d=True,
68
+ copy=True,
69
+ reset=False,
70
+ dtype=np.float64,
71
+ )
79
72
 
80
73
  # Check that the number of features is the same as the fitted data
81
74
  if X_.shape[1] != self.n_features_in_:
82
- raise ValueError(f"Expected {self.n_features_in_} features but got {X_.shape[1]}")
75
+ raise ValueError(
76
+ f"Expected {self.n_features_in_} features but got {X_.shape[1]}"
77
+ )
83
78
 
84
79
  # Calculate the standard normal variate
85
80
  for i, x in enumerate(X_):
@@ -88,4 +83,4 @@ class StandardNormalVariate(OneToOneFeatureMixin, BaseEstimator, TransformerMixi
88
83
  return X_.reshape(-1, 1) if X_.ndim == 1 else X_
89
84
 
90
85
  def _calculate_standard_normal_variate(self, x) -> np.ndarray:
91
- return (x - x.mean()) / x.std()
86
+ return (x - x.mean()) / x.std()
@@ -1,4 +1,6 @@
1
- from .mean_filter import MeanFilter
2
- from .median_filter import MedianFilter
3
- from .savitzky_golay_filter import SavitzkyGolayFilter
4
- from .whittaker_smooth import WhittakerSmooth
1
+ from ._mean_filter import MeanFilter
2
+ from ._median_filter import MedianFilter
3
+ from ._savitzky_golay_filter import SavitzkyGolayFilter
4
+ from ._whittaker_smooth import WhittakerSmooth
5
+
6
+ __all__ = ["MeanFilter", "MedianFilter", "SavitzkyGolayFilter", "WhittakerSmooth"]
@@ -1,12 +1,10 @@
1
1
  import numpy as np
2
2
  from scipy.ndimage import uniform_filter1d
3
3
  from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
4
- from sklearn.utils.validation import check_is_fitted
4
+ from sklearn.utils.validation import check_is_fitted, validate_data
5
5
 
6
- from chemotools.utils.check_inputs import check_input
7
6
 
8
-
9
- class MeanFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
7
+ class MeanFilter(TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
10
8
  """
11
9
  A transformer that calculates the mean filter of the input data.
12
10
 
@@ -14,19 +12,11 @@ class MeanFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
14
12
  ----------
15
13
  window_size : int, optional
16
14
  The size of the window to use for the mean filter. Must be odd. Default is 3.
17
-
15
+
18
16
  mode : str, optional
19
17
  The mode to use for the mean filter. Can be "nearest", "constant", "reflect",
20
18
  "wrap", "mirror" or "interp". Default is "nearest".
21
19
 
22
- Attributes
23
- ----------
24
- n_features_in_ : int
25
- The number of features in the input data.
26
-
27
- _is_fitted : bool
28
- Whether the transformer has been fitted to data.
29
-
30
20
  Methods
31
21
  -------
32
22
  fit(X, y=None)
@@ -35,7 +25,8 @@ class MeanFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
35
25
  transform(X, y=0, copy=True)
36
26
  Transform the input data by calculating the mean filter.
37
27
  """
38
- def __init__(self, window_size: int = 3, mode='nearest') -> None:
28
+
29
+ def __init__(self, window_size: int = 3, mode="nearest") -> None:
39
30
  self.window_size = window_size
40
31
  self.mode = mode
41
32
 
@@ -57,14 +48,9 @@ class MeanFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
57
48
  The fitted transformer.
58
49
  """
59
50
  # Check that X is a 2D array and has only finite values
60
- X = check_input(X)
61
-
62
- # Set the number of features
63
- self.n_features_in_ = X.shape[1]
64
-
65
- # Set the fitted attribute to True
66
- self._is_fitted = True
67
-
51
+ X = validate_data(
52
+ self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
53
+ )
68
54
  return self
69
55
 
70
56
  def transform(self, X: np.ndarray, y=None) -> np.ndarray:
@@ -85,11 +71,18 @@ class MeanFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
85
71
  The transformed data.
86
72
  """
87
73
  # Check that the estimator is fitted
88
- check_is_fitted(self, "_is_fitted")
74
+ check_is_fitted(self, "n_features_in_")
89
75
 
90
76
  # Check that X is a 2D array and has only finite values
91
- X = check_input(X)
92
- X_ = X.copy()
77
+ X_ = validate_data(
78
+ self,
79
+ X,
80
+ y="no_validation",
81
+ ensure_2d=True,
82
+ copy=True,
83
+ reset=False,
84
+ dtype=np.float64,
85
+ )
93
86
 
94
87
  if X_.shape[1] != self.n_features_in_:
95
88
  raise ValueError(
@@ -1,12 +1,12 @@
1
+ from typing import Literal
2
+
1
3
  import numpy as np
2
4
  from scipy.ndimage import median_filter
3
5
  from sklearn.base import BaseEstimator, TransformerMixin, OneToOneFeatureMixin
4
- from sklearn.utils.validation import check_is_fitted
5
-
6
- from chemotools.utils.check_inputs import check_input
6
+ from sklearn.utils.validation import check_is_fitted, validate_data
7
7
 
8
8
 
9
- class MedianFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
9
+ class MedianFilter(TransformerMixin, OneToOneFeatureMixin, BaseEstimator):
10
10
  """
11
11
  A transformer that calculates the median filter of the input data.
12
12
 
@@ -19,14 +19,6 @@ class MedianFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
19
19
  The mode to use for the median filter. Can be "nearest", "constant", "reflect",
20
20
  "wrap", "mirror" or "interp". Default is "nearest".
21
21
 
22
- Attributes
23
- ----------
24
- n_features_in_ : int
25
- The number of features in the input data.
26
-
27
- _is_fitted : bool
28
- Whether the transformer has been fitted to data.
29
-
30
22
  Methods
31
23
  -------
32
24
  fit(X, y=None)
@@ -35,7 +27,21 @@ class MedianFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
35
27
  transform(X, y=0, copy=True)
36
28
  Transform the input data by calculating the median filter.
37
29
  """
38
- def __init__(self, window_size: int = 3, mode: str = 'nearest') -> None:
30
+
31
+ def __init__(
32
+ self,
33
+ window_size: int = 3,
34
+ mode: Literal[
35
+ "reflect",
36
+ "constant",
37
+ "nearest",
38
+ "mirror",
39
+ "wrap",
40
+ "grid-constant",
41
+ "grid-mirror",
42
+ "grid-wrap",
43
+ ] = "nearest",
44
+ ) -> None:
39
45
  self.window_size = window_size
40
46
  self.mode = mode
41
47
 
@@ -57,14 +63,9 @@ class MedianFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
57
63
  The fitted transformer.
58
64
  """
59
65
  # Check that X is a 2D array and has only finite values
60
- X = check_input(X)
61
-
62
- # Set the number of features
63
- self.n_features_in_ = X.shape[1]
64
-
65
- # Set the fitted attribute to True
66
- self._is_fitted = True
67
-
66
+ X = validate_data(
67
+ self, X, y="no_validation", ensure_2d=True, reset=True, dtype=np.float64
68
+ )
68
69
  return self
69
70
 
70
71
  def transform(self, X: np.ndarray, y=None) -> np.ndarray:
@@ -85,11 +86,18 @@ class MedianFilter(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
85
86
  The transformed data.
86
87
  """
87
88
  # Check that the estimator is fitted
88
- check_is_fitted(self, "_is_fitted")
89
+ check_is_fitted(self, "n_features_in_")
89
90
 
90
91
  # Check that X is a 2D array and has only finite values
91
- X = check_input(X)
92
- X_ = X.copy()
92
+ X_ = validate_data(
93
+ self,
94
+ X,
95
+ y="no_validation",
96
+ ensure_2d=True,
97
+ copy=True,
98
+ reset=False,
99
+ dtype=np.float64,
100
+ )
93
101
 
94
102
  if X_.shape[1] != self.n_features_in_:
95
103
  raise ValueError(