scikit-learn-intelex 2024.1.0__py311-none-manylinux1_x86_64.whl → 2024.4.0__py311-none-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the registry's page for this release for more details.

Files changed (62):
  1. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
  2. scikit_learn_intelex-2024.4.0.dist-info/RECORD +101 -0
  3. sklearnex/__init__.py +9 -7
  4. sklearnex/_device_offload.py +31 -4
  5. sklearnex/basic_statistics/__init__.py +2 -1
  6. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  7. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
  8. sklearnex/cluster/dbscan.py +6 -4
  9. sklearnex/conftest.py +63 -0
  10. sklearnex/{preview/decomposition → covariance}/__init__.py +19 -19
  11. sklearnex/covariance/incremental_covariance.py +130 -0
  12. sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  13. sklearnex/decomposition/pca.py +319 -1
  14. sklearnex/decomposition/tests/test_pca.py +34 -5
  15. sklearnex/dispatcher.py +93 -61
  16. sklearnex/ensemble/_forest.py +81 -97
  17. sklearnex/ensemble/tests/test_forest.py +15 -19
  18. sklearnex/linear_model/__init__.py +1 -2
  19. sklearnex/linear_model/linear.py +275 -347
  20. sklearnex/{preview/linear_model → linear_model}/logistic_regression.py +83 -50
  21. sklearnex/linear_model/tests/test_linear.py +40 -5
  22. sklearnex/linear_model/tests/test_logreg.py +70 -7
  23. sklearnex/neighbors/__init__.py +1 -1
  24. sklearnex/neighbors/_lof.py +221 -0
  25. sklearnex/neighbors/common.py +4 -1
  26. sklearnex/neighbors/knn_classification.py +47 -137
  27. sklearnex/neighbors/knn_regression.py +20 -132
  28. sklearnex/neighbors/knn_unsupervised.py +16 -93
  29. sklearnex/neighbors/tests/test_neighbors.py +12 -16
  30. sklearnex/preview/__init__.py +1 -1
  31. sklearnex/preview/cluster/k_means.py +8 -81
  32. sklearnex/preview/covariance/covariance.py +51 -16
  33. sklearnex/preview/covariance/tests/test_covariance.py +18 -5
  34. sklearnex/spmd/__init__.py +1 -0
  35. sklearnex/{preview/linear_model → spmd/covariance}/__init__.py +5 -5
  36. sklearnex/spmd/covariance/covariance.py +21 -0
  37. sklearnex/spmd/ensemble/forest.py +4 -12
  38. sklearnex/spmd/linear_model/__init__.py +2 -1
  39. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  40. sklearnex/svm/_common.py +4 -7
  41. sklearnex/svm/nusvc.py +74 -55
  42. sklearnex/svm/nusvr.py +9 -56
  43. sklearnex/svm/svc.py +74 -56
  44. sklearnex/svm/svr.py +6 -53
  45. sklearnex/tests/_utils.py +164 -0
  46. sklearnex/tests/test_memory_usage.py +9 -7
  47. sklearnex/tests/test_monkeypatch.py +179 -138
  48. sklearnex/tests/test_n_jobs_support.py +77 -9
  49. sklearnex/tests/test_parallel.py +6 -8
  50. sklearnex/tests/test_patching.py +338 -89
  51. sklearnex/utils/__init__.py +2 -1
  52. sklearnex/utils/_namespace.py +97 -0
  53. scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
  54. sklearnex/neighbors/lof.py +0 -436
  55. sklearnex/preview/decomposition/pca.py +0 -376
  56. sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -42
  57. sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  58. sklearnex/tests/_models_info.py +0 -170
  59. sklearnex/tests/utils/_launch_algorithms.py +0 -118
  60. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
  61. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
  62. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
@@ -17,372 +17,300 @@
17
17
  import logging
18
18
  from abc import ABC
19
19
 
20
- from daal4py.sklearn._utils import daal_check_version
21
-
22
-
23
- def get_coef(self):
24
- return self._coef_
25
-
26
-
27
- def set_coef(self, value):
28
- self._coef_ = value
29
- if hasattr(self, "_onedal_estimator"):
30
- self._onedal_estimator.coef_ = value
31
- if not self._is_in_fit:
32
- del self._onedal_estimator._onedal_model
33
-
34
-
35
- def get_intercept(self):
36
- return self._intercept_
37
-
38
-
39
- def set_intercept(self, value):
40
- self._intercept_ = value
41
- if hasattr(self, "_onedal_estimator"):
42
- self._onedal_estimator.intercept_ = value
43
- if not self._is_in_fit:
44
- del self._onedal_estimator._onedal_model
45
-
46
-
47
- class BaseLinearRegression(ABC):
48
- def _save_attributes(self):
49
- self.n_features_in_ = self._onedal_estimator.n_features_in_
50
- self.fit_status_ = 0
51
- self._coef_ = self._onedal_estimator.coef_
52
- self._intercept_ = self._onedal_estimator.intercept_
53
- self._sparse = False
54
-
55
- self.coef_ = property(get_coef, set_coef)
56
- self.intercept_ = property(get_intercept, set_intercept)
57
-
58
- self._is_in_fit = True
59
- self.coef_ = self._coef_
60
- self.intercept_ = self._intercept_
61
- self._is_in_fit = False
62
-
63
-
64
- if daal_check_version((2023, "P", 100)):
65
- import numpy as np
66
- from sklearn.linear_model import LinearRegression as sklearn_LinearRegression
67
-
68
- from daal4py.sklearn._utils import (
69
- control_n_jobs,
70
- get_dtype,
71
- make2d,
72
- run_with_n_jobs,
73
- sklearn_check_version,
74
- )
75
-
76
- from .._device_offload import dispatch, wrap_output_data
77
- from .._utils import (
78
- PatchingConditionsChain,
79
- get_patch_message,
80
- register_hyperparameters,
81
- )
82
- from ..utils.validation import _assert_all_finite
20
+ import numpy as np
21
+ from sklearn.exceptions import NotFittedError
22
+ from sklearn.linear_model import LinearRegression as sklearn_LinearRegression
23
+
24
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
25
+ from daal4py.sklearn._utils import sklearn_check_version
26
+
27
+ from .._device_offload import dispatch, wrap_output_data
28
+ from .._utils import PatchingConditionsChain, get_patch_message, register_hyperparameters
29
+ from ..utils.validation import _assert_all_finite
30
+
31
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
32
+ from sklearn.linear_model._base import _deprecate_normalize
33
+
34
+ from scipy.sparse import issparse
35
+ from sklearn.utils.validation import check_X_y
36
+
37
+ from onedal.common.hyperparameters import get_hyperparameters
38
+ from onedal.linear_model import LinearRegression as onedal_LinearRegression
39
+ from onedal.utils import _num_features, _num_samples
40
+
41
+
42
+ @register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
43
+ @control_n_jobs(decorated_methods=["fit", "predict"])
44
+ class LinearRegression(sklearn_LinearRegression):
45
+ __doc__ = sklearn_LinearRegression.__doc__
46
+
47
+ if sklearn_check_version("1.2"):
48
+ _parameter_constraints: dict = {**sklearn_LinearRegression._parameter_constraints}
49
+
50
+ def __init__(
51
+ self,
52
+ fit_intercept=True,
53
+ copy_X=True,
54
+ n_jobs=None,
55
+ positive=False,
56
+ ):
57
+ super().__init__(
58
+ fit_intercept=fit_intercept,
59
+ copy_X=copy_X,
60
+ n_jobs=n_jobs,
61
+ positive=positive,
62
+ )
83
63
 
84
- if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
85
- from sklearn.linear_model._base import _deprecate_normalize
64
+ else:
65
+
66
+ def __init__(
67
+ self,
68
+ fit_intercept=True,
69
+ normalize="deprecated" if sklearn_check_version("1.0") else False,
70
+ copy_X=True,
71
+ n_jobs=None,
72
+ positive=False,
73
+ ):
74
+ super().__init__(
75
+ fit_intercept=fit_intercept,
76
+ normalize=normalize,
77
+ copy_X=copy_X,
78
+ n_jobs=n_jobs,
79
+ positive=positive,
80
+ )
86
81
 
87
- from scipy.sparse import issparse
88
- from sklearn.exceptions import NotFittedError
89
- from sklearn.utils.validation import _deprecate_positional_args, check_X_y
82
+ def fit(self, X, y, sample_weight=None):
83
+ if sklearn_check_version("1.0"):
84
+ self._check_feature_names(X, reset=True)
85
+ if sklearn_check_version("1.2"):
86
+ self._validate_params()
87
+
88
+ # It is necessary to properly update coefs for predict if we
89
+ # fallback to sklearn in dispatch
90
+ if hasattr(self, "_onedal_estimator"):
91
+ del self._onedal_estimator
92
+
93
+ dispatch(
94
+ self,
95
+ "fit",
96
+ {
97
+ "onedal": self.__class__._onedal_fit,
98
+ "sklearn": sklearn_LinearRegression.fit,
99
+ },
100
+ X,
101
+ y,
102
+ sample_weight,
103
+ )
104
+ return self
105
+
106
+ @wrap_output_data
107
+ def predict(self, X):
108
+
109
+ if not hasattr(self, "coef_"):
110
+ msg = (
111
+ "This %(name)s instance is not fitted yet. Call 'fit' with "
112
+ "appropriate arguments before using this estimator."
113
+ )
114
+ raise NotFittedError(msg % {"name": self.__class__.__name__})
115
+
116
+ return dispatch(
117
+ self,
118
+ "predict",
119
+ {
120
+ "onedal": self.__class__._onedal_predict,
121
+ "sklearn": sklearn_LinearRegression.predict,
122
+ },
123
+ X,
124
+ )
125
+
126
+ def _test_type_and_finiteness(self, X_in):
127
+ X = X_in if isinstance(X_in, np.ndarray) else np.asarray(X_in)
128
+
129
+ dtype = X.dtype
130
+ if "complex" in str(type(dtype)):
131
+ return False
132
+
133
+ try:
134
+ _assert_all_finite(X)
135
+ except BaseException:
136
+ return False
137
+ return True
138
+
139
+ def _onedal_fit_supported(self, method_name, *data):
140
+ assert method_name == "fit"
141
+ assert len(data) == 3
142
+ X, y, sample_weight = data
143
+
144
+ class_name = self.__class__.__name__
145
+ patching_status = PatchingConditionsChain(
146
+ f"sklearn.linear_model.{class_name}.fit"
147
+ )
148
+
149
+ normalize_is_set = (
150
+ hasattr(self, "normalize")
151
+ and self.normalize
152
+ and self.normalize != "deprecated"
153
+ )
154
+ positive_is_set = hasattr(self, "positive") and self.positive
155
+
156
+ n_samples = _num_samples(X)
157
+ n_features = _num_features(X, fallback_1d=True)
158
+
159
+ # Check if equations are well defined
160
+ is_good_for_onedal = n_samples >= (n_features + int(self.fit_intercept))
161
+
162
+ dal_ready = patching_status.and_conditions(
163
+ [
164
+ (sample_weight is None, "Sample weight is not supported."),
165
+ (
166
+ not issparse(X) and not issparse(y),
167
+ "Sparse input is not supported.",
168
+ ),
169
+ (not normalize_is_set, "Normalization is not supported."),
170
+ (
171
+ not positive_is_set,
172
+ "Forced positive coefficients are not supported.",
173
+ ),
174
+ (
175
+ is_good_for_onedal,
176
+ "The shape of X (fitting) does not satisfy oneDAL requirements:"
177
+ "Number of features + 1 >= number of samples.",
178
+ ),
179
+ ]
180
+ )
181
+ if not dal_ready:
182
+ return patching_status
90
183
 
91
- from onedal.common.hyperparameters import get_hyperparameters
92
- from onedal.linear_model import LinearRegression as onedal_LinearRegression
93
- from onedal.utils import _num_features, _num_samples
184
+ if not patching_status.and_condition(
185
+ self._test_type_and_finiteness(X), "Input X is not supported."
186
+ ):
187
+ return patching_status
94
188
 
95
- @register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
96
- @control_n_jobs
97
- class LinearRegression(sklearn_LinearRegression, BaseLinearRegression):
98
- __doc__ = sklearn_LinearRegression.__doc__
99
- intercept_, coef_ = None, None
189
+ patching_status.and_condition(
190
+ self._test_type_and_finiteness(y), "Input y is not supported."
191
+ )
192
+
193
+ return patching_status
194
+
195
+ def _onedal_predict_supported(self, method_name, *data):
196
+ assert method_name == "predict"
197
+ assert len(data) == 1
198
+
199
+ class_name = self.__class__.__name__
200
+ patching_status = PatchingConditionsChain(
201
+ f"sklearn.linear_model.{class_name}.predict"
202
+ )
203
+
204
+ n_samples = _num_samples(*data)
205
+ model_is_sparse = issparse(self.coef_) or (
206
+ self.fit_intercept and issparse(self.intercept_)
207
+ )
208
+ dal_ready = patching_status.and_conditions(
209
+ [
210
+ (n_samples > 0, "Number of samples is less than 1."),
211
+ (not issparse(*data), "Sparse input is not supported."),
212
+ (not model_is_sparse, "Sparse coefficients are not supported."),
213
+ ]
214
+ )
215
+ if not dal_ready:
216
+ return patching_status
100
217
 
218
+ patching_status.and_condition(
219
+ self._test_type_and_finiteness(*data), "Input X is not supported."
220
+ )
221
+
222
+ return patching_status
223
+
224
+ def _onedal_supported(self, method_name, *data):
225
+ if method_name == "fit":
226
+ return self._onedal_fit_supported(method_name, *data)
227
+ if method_name == "predict":
228
+ return self._onedal_predict_supported(method_name, *data)
229
+ raise RuntimeError(f"Unknown method {method_name} in {self.__class__.__name__}")
230
+
231
+ _onedal_gpu_supported = _onedal_supported
232
+ _onedal_cpu_supported = _onedal_supported
233
+
234
+ def _initialize_onedal_estimator(self):
235
+ onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
236
+ self._onedal_estimator = onedal_LinearRegression(**onedal_params)
237
+
238
+ def _onedal_fit(self, X, y, sample_weight, queue=None):
239
+ assert sample_weight is None
240
+
241
+ check_params = {
242
+ "X": X,
243
+ "y": y,
244
+ "dtype": [np.float64, np.float32],
245
+ "accept_sparse": ["csr", "csc", "coo"],
246
+ "y_numeric": True,
247
+ "multi_output": True,
248
+ "force_all_finite": False,
249
+ }
101
250
  if sklearn_check_version("1.2"):
102
- _parameter_constraints: dict = {
103
- **sklearn_LinearRegression._parameter_constraints
104
- }
105
-
106
- def __init__(
107
- self,
108
- fit_intercept=True,
109
- copy_X=True,
110
- n_jobs=None,
111
- positive=False,
112
- ):
113
- super().__init__(
114
- fit_intercept=fit_intercept,
115
- copy_X=copy_X,
116
- n_jobs=n_jobs,
117
- positive=positive,
118
- )
119
-
120
- elif sklearn_check_version("0.24"):
121
-
122
- def __init__(
123
- self,
124
- fit_intercept=True,
125
- normalize="deprecated" if sklearn_check_version("1.0") else False,
126
- copy_X=True,
127
- n_jobs=None,
128
- positive=False,
129
- ):
130
- super().__init__(
131
- fit_intercept=fit_intercept,
132
- normalize=normalize,
133
- copy_X=copy_X,
134
- n_jobs=n_jobs,
135
- positive=positive,
136
- )
137
-
251
+ X, y = self._validate_data(**check_params)
138
252
  else:
253
+ X, y = check_X_y(**check_params)
139
254
 
140
- def __init__(
141
- self,
142
- fit_intercept=True,
143
- normalize=False,
144
- copy_X=True,
145
- n_jobs=None,
146
- ):
147
- super().__init__(
148
- fit_intercept=fit_intercept,
149
- normalize=normalize,
150
- copy_X=copy_X,
151
- n_jobs=n_jobs,
152
- )
153
-
154
- def fit(self, X, y, sample_weight=None):
155
- """
156
- Fit linear model.
157
- Parameters
158
- ----------
159
- X : {array-like, sparse matrix} of shape (n_samples, n_features)
160
- Training data.
161
- y : array-like of shape (n_samples,) or (n_samples, n_targets)
162
- Target values. Will be cast to X's dtype if necessary.
163
- sample_weight : array-like of shape (n_samples,), default=None
164
- Individual weights for each sample.
165
- .. versionadded:: 0.17
166
- parameter *sample_weight* support to LinearRegression.
167
- Returns
168
- -------
169
- self : object
170
- Fitted Estimator.
171
- """
172
- if sklearn_check_version("1.0"):
173
- self._check_feature_names(X, reset=True)
174
- if sklearn_check_version("1.2"):
175
- self._validate_params()
176
-
177
- dispatch(
178
- self,
179
- "fit",
180
- {
181
- "onedal": self.__class__._onedal_fit,
182
- "sklearn": sklearn_LinearRegression.fit,
183
- },
184
- X,
185
- y,
186
- sample_weight,
187
- )
188
- return self
189
-
190
- @wrap_output_data
191
- def predict(self, X):
192
- """
193
- Predict using the linear model.
194
- Parameters
195
- ----------
196
- X : array-like or sparse matrix, shape (n_samples, n_features)
197
- Samples.
198
- Returns
199
- -------
200
- C : array, shape (n_samples, n_targets)
201
- Returns predicted values.
202
- """
203
- if sklearn_check_version("1.0"):
204
- self._check_feature_names(X, reset=False)
205
- return dispatch(
206
- self,
207
- "predict",
208
- {
209
- "onedal": self.__class__._onedal_predict,
210
- "sklearn": sklearn_LinearRegression.predict,
211
- },
212
- X,
255
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
256
+ self._normalize = _deprecate_normalize(
257
+ self.normalize,
258
+ default=False,
259
+ estimator_name=self.__class__.__name__,
213
260
  )
214
261
 
215
- def _test_type_and_finiteness(self, X_in):
216
- X = X_in if isinstance(X_in, np.ndarray) else np.asarray(X_in)
217
-
218
- dtype = X.dtype
219
- if "complex" in str(type(dtype)):
220
- return False
221
-
222
- try:
223
- _assert_all_finite(X)
224
- except BaseException:
225
- return False
226
- return True
227
-
228
- def _onedal_fit_supported(self, method_name, *data):
229
- assert method_name == "fit"
230
- assert len(data) == 3
231
- X, y, sample_weight = data
262
+ self._initialize_onedal_estimator()
263
+ try:
264
+ self._onedal_estimator.fit(X, y, queue=queue)
265
+ self._save_attributes()
232
266
 
233
- class_name = self.__class__.__name__
234
- patching_status = PatchingConditionsChain(
235
- f"sklearn.linear_model.{class_name}.fit"
267
+ except RuntimeError:
268
+ logging.getLogger("sklearnex").info(
269
+ f"{self.__class__.__name__}.fit "
270
+ + get_patch_message("sklearn_after_onedal")
236
271
  )
237
272
 
238
- normalize_is_set = (
239
- hasattr(self, "normalize")
240
- and self.normalize
241
- and self.normalize != "deprecated"
242
- )
243
- positive_is_set = hasattr(self, "positive") and self.positive
244
-
245
- n_samples = _num_samples(X)
246
- n_features = _num_features(X, fallback_1d=True)
247
-
248
- # Check if equations are well defined
249
- is_good_for_onedal = n_samples > (n_features + int(self.fit_intercept))
250
-
251
- dal_ready = patching_status.and_conditions(
252
- [
253
- (sample_weight is None, "Sample weight is not supported."),
254
- (
255
- not issparse(X) and not issparse(y),
256
- "Sparse input is not supported.",
257
- ),
258
- (not normalize_is_set, "Normalization is not supported."),
259
- (
260
- not positive_is_set,
261
- "Forced positive coefficients are not supported.",
262
- ),
263
- (
264
- is_good_for_onedal,
265
- "The shape of X (fitting) does not satisfy oneDAL requirements:."
266
- "Number of features + 1 >= number of samples.",
267
- ),
268
- ]
269
- )
270
- if not dal_ready:
271
- return patching_status
272
-
273
- if not patching_status.and_condition(
274
- self._test_type_and_finiteness(X), "Input X is not supported."
275
- ):
276
- return patching_status
273
+ del self._onedal_estimator
274
+ super().fit(X, y)
277
275
 
278
- patching_status.and_condition(
279
- self._test_type_and_finiteness(y), "Input y is not supported."
280
- )
281
-
282
- return patching_status
276
+ def _onedal_predict(self, X, queue=None):
277
+ if sklearn_check_version("1.0"):
278
+ self._check_feature_names(X, reset=False)
283
279
 
284
- def _onedal_predict_supported(self, method_name, *data):
285
- assert method_name == "predict"
286
- assert len(data) == 1
280
+ X = self._validate_data(X, accept_sparse=False, reset=False)
281
+ if not hasattr(self, "_onedal_estimator"):
282
+ self._initialize_onedal_estimator()
283
+ self._onedal_estimator.coef_ = self.coef_
284
+ self._onedal_estimator.intercept_ = self.intercept_
287
285
 
288
- class_name = self.__class__.__name__
289
- patching_status = PatchingConditionsChain(
290
- f"sklearn.linear_model.{class_name}.predict"
291
- )
286
+ res = self._onedal_estimator.predict(X, queue=queue)
287
+ return res
292
288
 
293
- n_samples = _num_samples(*data)
294
- model_is_sparse = issparse(self.coef_) or (
295
- self.fit_intercept and issparse(self.intercept_)
296
- )
297
- dal_ready = patching_status.and_conditions(
298
- [
299
- (n_samples > 0, "Number of samples is less than 1."),
300
- (not issparse(*data), "Sparse input is not supported."),
301
- (not model_is_sparse, "Sparse coefficients are not supported."),
302
- (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
303
- ]
304
- )
305
- if not dal_ready:
306
- return patching_status
289
+ def get_coef_(self):
290
+ return self.coef_
307
291
 
308
- patching_status.and_condition(
309
- self._test_type_and_finiteness(*data), "Input X is not supported."
310
- )
292
+ def set_coef_(self, value):
293
+ self.__dict__["coef_"] = value
294
+ if hasattr(self, "_onedal_estimator"):
295
+ self._onedal_estimator.coef_ = value
296
+ del self._onedal_estimator._onedal_model
311
297
 
312
- return patching_status
298
+ def get_intercept_(self):
299
+ return self.intercept_
313
300
 
314
- def _onedal_supported(self, method_name, *data):
315
- if method_name == "fit":
316
- return self._onedal_fit_supported(method_name, *data)
317
- if method_name == "predict":
318
- return self._onedal_predict_supported(method_name, *data)
319
- raise RuntimeError(
320
- f"Unknown method {method_name} in {self.__class__.__name__}"
321
- )
301
+ def set_intercept_(self, value):
302
+ self.__dict__["intercept_"] = value
303
+ if hasattr(self, "_onedal_estimator"):
304
+ self._onedal_estimator.intercept_ = value
305
+ del self._onedal_estimator._onedal_model
322
306
 
323
- def _onedal_gpu_supported(self, method_name, *data):
324
- return self._onedal_supported(method_name, *data)
325
-
326
- def _onedal_cpu_supported(self, method_name, *data):
327
- return self._onedal_supported(method_name, *data)
328
-
329
- def _initialize_onedal_estimator(self):
330
- onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
331
- self._onedal_estimator = onedal_LinearRegression(**onedal_params)
332
-
333
- @run_with_n_jobs
334
- def _onedal_fit(self, X, y, sample_weight, queue=None):
335
- assert sample_weight is None
336
-
337
- check_params = {
338
- "X": X,
339
- "y": y,
340
- "dtype": [np.float64, np.float32],
341
- "accept_sparse": ["csr", "csc", "coo"],
342
- "y_numeric": True,
343
- "multi_output": True,
344
- "force_all_finite": False,
345
- }
346
- if sklearn_check_version("1.2"):
347
- X, y = self._validate_data(**check_params)
348
- else:
349
- X, y = check_X_y(**check_params)
350
-
351
- if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
352
- self._normalize = _deprecate_normalize(
353
- self.normalize,
354
- default=False,
355
- estimator_name=self.__class__.__name__,
356
- )
307
+ def _save_attributes(self):
308
+ self.coef_ = property(self.get_coef_, self.set_coef_)
309
+ self.intercept_ = property(self.get_intercept_, self.set_intercept_)
310
+ self.n_features_in_ = self._onedal_estimator.n_features_in_
311
+ self._sparse = False
312
+ self.__dict__["coef_"] = self._onedal_estimator.coef_
313
+ self.__dict__["intercept_"] = self._onedal_estimator.intercept_
357
314
 
358
- self._initialize_onedal_estimator()
359
- try:
360
- self._onedal_estimator.fit(X, y, queue=queue)
361
- self._save_attributes()
362
-
363
- except RuntimeError:
364
- logging.getLogger("sklearnex").info(
365
- f"{self.__class__.__name__}.fit "
366
- + get_patch_message("sklearn_after_onedal")
367
- )
368
-
369
- del self._onedal_estimator
370
- super().fit(X, y)
371
-
372
- @run_with_n_jobs
373
- def _onedal_predict(self, X, queue=None):
374
- X = self._validate_data(X, accept_sparse=False, reset=False)
375
- if not hasattr(self, "_onedal_estimator"):
376
- self._initialize_onedal_estimator()
377
- self._onedal_estimator.coef_ = self.coef_
378
- self._onedal_estimator.intercept_ = self.intercept_
379
-
380
- return self._onedal_estimator.predict(X, queue=queue)
381
-
382
- else:
383
- from daal4py.sklearn.linear_model import LinearRegression
384
-
385
- logging.warning(
386
- "Sklearnex LinearRegression requires oneDAL version >= 2023.1 "
387
- "but it was not found"
388
- )
315
+ fit.__doc__ = sklearn_LinearRegression.fit.__doc__
316
+ predict.__doc__ = sklearn_LinearRegression.predict.__doc__