scikit-learn-intelex 2025.4.0__py313-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (282)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-313-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-313-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +696 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +204 -0
  62. onedal/_onedal_py_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-313-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +175 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +242 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
  70. onedal/basic_statistics/tests/utils.py +50 -0
  71. onedal/cluster/__init__.py +27 -0
  72. onedal/cluster/dbscan.py +105 -0
  73. onedal/cluster/kmeans.py +557 -0
  74. onedal/cluster/kmeans_init.py +112 -0
  75. onedal/cluster/tests/test_dbscan.py +125 -0
  76. onedal/cluster/tests/test_kmeans.py +88 -0
  77. onedal/cluster/tests/test_kmeans_init.py +93 -0
  78. onedal/common/_base.py +38 -0
  79. onedal/common/_estimator_checks.py +47 -0
  80. onedal/common/_mixin.py +62 -0
  81. onedal/common/_policy.py +55 -0
  82. onedal/common/_spmd_policy.py +30 -0
  83. onedal/common/hyperparameters.py +125 -0
  84. onedal/common/tests/test_policy.py +76 -0
  85. onedal/common/tests/test_sycl.py +128 -0
  86. onedal/covariance/__init__.py +20 -0
  87. onedal/covariance/covariance.py +122 -0
  88. onedal/covariance/incremental_covariance.py +161 -0
  89. onedal/covariance/tests/test_covariance.py +50 -0
  90. onedal/covariance/tests/test_incremental_covariance.py +190 -0
  91. onedal/datatypes/__init__.py +19 -0
  92. onedal/datatypes/_data_conversion.py +121 -0
  93. onedal/datatypes/tests/common.py +126 -0
  94. onedal/datatypes/tests/test_data.py +475 -0
  95. onedal/decomposition/__init__.py +20 -0
  96. onedal/decomposition/incremental_pca.py +214 -0
  97. onedal/decomposition/pca.py +186 -0
  98. onedal/decomposition/tests/test_incremental_pca.py +285 -0
  99. onedal/ensemble/__init__.py +29 -0
  100. onedal/ensemble/forest.py +736 -0
  101. onedal/ensemble/tests/test_random_forest.py +97 -0
  102. onedal/linear_model/__init__.py +27 -0
  103. onedal/linear_model/incremental_linear_model.py +292 -0
  104. onedal/linear_model/linear_model.py +325 -0
  105. onedal/linear_model/logistic_regression.py +247 -0
  106. onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
  107. onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
  108. onedal/linear_model/tests/test_linear_regression.py +259 -0
  109. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  110. onedal/linear_model/tests/test_ridge.py +95 -0
  111. onedal/neighbors/__init__.py +19 -0
  112. onedal/neighbors/neighbors.py +763 -0
  113. onedal/neighbors/tests/test_knn_classification.py +49 -0
  114. onedal/primitives/__init__.py +27 -0
  115. onedal/primitives/get_tree.py +25 -0
  116. onedal/primitives/kernel_functions.py +152 -0
  117. onedal/primitives/tests/test_kernel_functions.py +159 -0
  118. onedal/spmd/__init__.py +25 -0
  119. onedal/spmd/_base.py +30 -0
  120. onedal/spmd/basic_statistics/__init__.py +20 -0
  121. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  122. onedal/spmd/basic_statistics/incremental_basic_statistics.py +71 -0
  123. onedal/spmd/cluster/__init__.py +28 -0
  124. onedal/spmd/cluster/dbscan.py +23 -0
  125. onedal/spmd/cluster/kmeans.py +56 -0
  126. onedal/spmd/covariance/__init__.py +20 -0
  127. onedal/spmd/covariance/covariance.py +26 -0
  128. onedal/spmd/covariance/incremental_covariance.py +83 -0
  129. onedal/spmd/decomposition/__init__.py +20 -0
  130. onedal/spmd/decomposition/incremental_pca.py +124 -0
  131. onedal/spmd/decomposition/pca.py +26 -0
  132. onedal/spmd/ensemble/__init__.py +19 -0
  133. onedal/spmd/ensemble/forest.py +28 -0
  134. onedal/spmd/linear_model/__init__.py +21 -0
  135. onedal/spmd/linear_model/incremental_linear_model.py +101 -0
  136. onedal/spmd/linear_model/linear_model.py +30 -0
  137. onedal/spmd/linear_model/logistic_regression.py +38 -0
  138. onedal/spmd/neighbors/__init__.py +19 -0
  139. onedal/spmd/neighbors/neighbors.py +75 -0
  140. onedal/svm/__init__.py +19 -0
  141. onedal/svm/svm.py +556 -0
  142. onedal/svm/tests/test_csr_svm.py +351 -0
  143. onedal/svm/tests/test_nusvc.py +204 -0
  144. onedal/svm/tests/test_nusvr.py +210 -0
  145. onedal/svm/tests/test_svc.py +176 -0
  146. onedal/svm/tests/test_svr.py +243 -0
  147. onedal/tests/test_common.py +57 -0
  148. onedal/tests/utils/_dataframes_support.py +162 -0
  149. onedal/tests/utils/_device_selection.py +102 -0
  150. onedal/utils/__init__.py +49 -0
  151. onedal/utils/_array_api.py +81 -0
  152. onedal/utils/_dpep_helpers.py +56 -0
  153. onedal/utils/tests/test_validation.py +142 -0
  154. onedal/utils/validation.py +464 -0
  155. scikit_learn_intelex-2025.4.0.dist-info/LICENSE.txt +202 -0
  156. scikit_learn_intelex-2025.4.0.dist-info/METADATA +190 -0
  157. scikit_learn_intelex-2025.4.0.dist-info/RECORD +282 -0
  158. scikit_learn_intelex-2025.4.0.dist-info/WHEEL +5 -0
  159. scikit_learn_intelex-2025.4.0.dist-info/top_level.txt +3 -0
  160. sklearnex/__init__.py +66 -0
  161. sklearnex/__main__.py +58 -0
  162. sklearnex/_config.py +116 -0
  163. sklearnex/_device_offload.py +126 -0
  164. sklearnex/_utils.py +177 -0
  165. sklearnex/basic_statistics/__init__.py +20 -0
  166. sklearnex/basic_statistics/basic_statistics.py +261 -0
  167. sklearnex/basic_statistics/incremental_basic_statistics.py +352 -0
  168. sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
  169. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
  170. sklearnex/cluster/__init__.py +20 -0
  171. sklearnex/cluster/dbscan.py +197 -0
  172. sklearnex/cluster/k_means.py +397 -0
  173. sklearnex/cluster/tests/test_dbscan.py +38 -0
  174. sklearnex/cluster/tests/test_kmeans.py +157 -0
  175. sklearnex/conftest.py +82 -0
  176. sklearnex/covariance/__init__.py +19 -0
  177. sklearnex/covariance/incremental_covariance.py +405 -0
  178. sklearnex/covariance/tests/test_incremental_covariance.py +287 -0
  179. sklearnex/decomposition/__init__.py +19 -0
  180. sklearnex/decomposition/pca.py +427 -0
  181. sklearnex/decomposition/tests/test_pca.py +58 -0
  182. sklearnex/dispatcher.py +534 -0
  183. sklearnex/doc/third-party-programs.txt +424 -0
  184. sklearnex/ensemble/__init__.py +29 -0
  185. sklearnex/ensemble/_forest.py +2029 -0
  186. sklearnex/ensemble/tests/test_forest.py +140 -0
  187. sklearnex/glob/__main__.py +72 -0
  188. sklearnex/glob/dispatcher.py +101 -0
  189. sklearnex/linear_model/__init__.py +32 -0
  190. sklearnex/linear_model/coordinate_descent.py +30 -0
  191. sklearnex/linear_model/incremental_linear.py +495 -0
  192. sklearnex/linear_model/incremental_ridge.py +432 -0
  193. sklearnex/linear_model/linear.py +346 -0
  194. sklearnex/linear_model/logistic_regression.py +415 -0
  195. sklearnex/linear_model/ridge.py +390 -0
  196. sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
  197. sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
  198. sklearnex/linear_model/tests/test_linear.py +142 -0
  199. sklearnex/linear_model/tests/test_logreg.py +134 -0
  200. sklearnex/linear_model/tests/test_ridge.py +256 -0
  201. sklearnex/manifold/__init__.py +19 -0
  202. sklearnex/manifold/t_sne.py +26 -0
  203. sklearnex/manifold/tests/test_tsne.py +250 -0
  204. sklearnex/metrics/__init__.py +23 -0
  205. sklearnex/metrics/pairwise.py +22 -0
  206. sklearnex/metrics/ranking.py +20 -0
  207. sklearnex/metrics/tests/test_metrics.py +39 -0
  208. sklearnex/model_selection/__init__.py +21 -0
  209. sklearnex/model_selection/split.py +22 -0
  210. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  211. sklearnex/neighbors/__init__.py +27 -0
  212. sklearnex/neighbors/_lof.py +236 -0
  213. sklearnex/neighbors/common.py +310 -0
  214. sklearnex/neighbors/knn_classification.py +231 -0
  215. sklearnex/neighbors/knn_regression.py +207 -0
  216. sklearnex/neighbors/knn_unsupervised.py +178 -0
  217. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  218. sklearnex/preview/__init__.py +17 -0
  219. sklearnex/preview/covariance/__init__.py +19 -0
  220. sklearnex/preview/covariance/covariance.py +142 -0
  221. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  222. sklearnex/preview/decomposition/__init__.py +19 -0
  223. sklearnex/preview/decomposition/incremental_pca.py +244 -0
  224. sklearnex/preview/decomposition/tests/test_incremental_pca.py +336 -0
  225. sklearnex/spmd/__init__.py +25 -0
  226. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  227. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  228. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  229. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  230. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +306 -0
  231. sklearnex/spmd/cluster/__init__.py +30 -0
  232. sklearnex/spmd/cluster/dbscan.py +50 -0
  233. sklearnex/spmd/cluster/kmeans.py +21 -0
  234. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  235. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +173 -0
  236. sklearnex/spmd/covariance/__init__.py +20 -0
  237. sklearnex/spmd/covariance/covariance.py +21 -0
  238. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  239. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  240. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  241. sklearnex/spmd/decomposition/__init__.py +20 -0
  242. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  243. sklearnex/spmd/decomposition/pca.py +21 -0
  244. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  245. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  246. sklearnex/spmd/ensemble/__init__.py +19 -0
  247. sklearnex/spmd/ensemble/forest.py +71 -0
  248. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  249. sklearnex/spmd/linear_model/__init__.py +21 -0
  250. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  251. sklearnex/spmd/linear_model/linear_model.py +21 -0
  252. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  253. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +331 -0
  254. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  255. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  256. sklearnex/spmd/neighbors/__init__.py +19 -0
  257. sklearnex/spmd/neighbors/neighbors.py +25 -0
  258. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  259. sklearnex/svm/__init__.py +29 -0
  260. sklearnex/svm/_common.py +339 -0
  261. sklearnex/svm/nusvc.py +371 -0
  262. sklearnex/svm/nusvr.py +170 -0
  263. sklearnex/svm/svc.py +399 -0
  264. sklearnex/svm/svr.py +167 -0
  265. sklearnex/svm/tests/test_svm.py +93 -0
  266. sklearnex/tests/test_common.py +491 -0
  267. sklearnex/tests/test_config.py +123 -0
  268. sklearnex/tests/test_hyperparameters.py +43 -0
  269. sklearnex/tests/test_memory_usage.py +347 -0
  270. sklearnex/tests/test_monkeypatch.py +269 -0
  271. sklearnex/tests/test_n_jobs_support.py +108 -0
  272. sklearnex/tests/test_parallel.py +48 -0
  273. sklearnex/tests/test_patching.py +377 -0
  274. sklearnex/tests/test_run_to_run_stability.py +326 -0
  275. sklearnex/tests/utils/__init__.py +48 -0
  276. sklearnex/tests/utils/base.py +436 -0
  277. sklearnex/tests/utils/spmd.py +198 -0
  278. sklearnex/utils/__init__.py +19 -0
  279. sklearnex/utils/_array_api.py +82 -0
  280. sklearnex/utils/parallel.py +59 -0
  281. sklearnex/utils/tests/test_validation.py +238 -0
  282. sklearnex/utils/validation.py +208 -0
@@ -0,0 +1,346 @@
1
+ # ===============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import logging
18
+ from abc import ABC
19
+
20
+ import numpy as np
21
+ from sklearn.linear_model import LinearRegression as _sklearn_LinearRegression
22
+ from sklearn.metrics import r2_score
23
+ from sklearn.utils.validation import check_array, check_is_fitted
24
+
25
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
26
+ from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
27
+
28
+ from .._config import get_config
29
+ from .._device_offload import dispatch, wrap_output_data
30
+ from .._utils import (
31
+ PatchableEstimator,
32
+ PatchingConditionsChain,
33
+ get_patch_message,
34
+ register_hyperparameters,
35
+ )
36
+
37
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
38
+ from sklearn.linear_model._base import _deprecate_normalize
39
+
40
+ from scipy.sparse import issparse
41
+ from sklearn.utils.validation import check_is_fitted, check_X_y
42
+
43
+ from onedal.common.hyperparameters import get_hyperparameters
44
+ from onedal.linear_model import LinearRegression as onedal_LinearRegression
45
+ from onedal.utils import _num_features, _num_samples
46
+
47
+ if sklearn_check_version("1.6"):
48
+ from sklearn.utils.validation import validate_data
49
+ else:
50
+ validate_data = _sklearn_LinearRegression._validate_data
51
+
52
+
53
@register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class LinearRegression(PatchableEstimator, _sklearn_LinearRegression):
    # The public docstring is inherited verbatim from scikit-learn; fit/predict/
    # score docstrings are likewise copied at the bottom of the class body.
    __doc__ = _sklearn_LinearRegression.__doc__

    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **_sklearn_LinearRegression._parameter_constraints
        }

        def __init__(
            self,
            fit_intercept=True,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    else:
        # scikit-learn < 1.2 still exposes the ``normalize`` argument
        # ("deprecated" sentinel from 1.0 on, plain ``False`` before that).
        def __init__(
            self,
            fit_intercept=True,
            normalize="deprecated" if sklearn_check_version("1.0") else False,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                normalize=normalize,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    def fit(self, X, y, sample_weight=None):
        if sklearn_check_version("1.2"):
            self._validate_params()

        # Drop the estimator from any previous fit: if dispatch falls back to
        # scikit-learn this time, predict() must rebuild the oneDAL estimator
        # from the fresh coef_/intercept_ rather than reuse stale ones.
        if hasattr(self, "_onedal_estimator"):
            del self._onedal_estimator

        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": _sklearn_LinearRegression.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):
        check_is_fitted(self)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": _sklearn_LinearRegression.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        check_is_fitted(self)
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": _sklearn_LinearRegression.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    def _onedal_cpu_supported(self, method_name, *data):
        """Return the patching-conditions chain for running ``method_name`` on CPU."""
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{self.__class__.__name__}.{method_name}"
        )
        return self._onedal_supported(patching_status, method_name, *data)

    def _onedal_gpu_supported(self, method_name, *data):
        """Return the patching-conditions chain for running ``method_name`` on GPU.

        Before oneDAL 2025.2 (version check ``(2025, "P", 200)``), GPU fitting
        additionally rejects underdetermined problems, i.e. fewer samples than
        features (+1 when an intercept is fitted).
        """
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{self.__class__.__name__}.{method_name}"
        )

        if method_name == "fit" and not daal_check_version((2025, "P", 200)):
            n_samples = _num_samples(data[0])
            n_features = _num_features(data[0], fallback_1d=True)
            is_underdetermined = n_samples < (n_features + int(self.fit_intercept))
            patching_status.and_conditions(
                [
                    (
                        not is_underdetermined,
                        "The shape of X (fitting) does not satisfy oneDAL requirements: "
                        "Number of features + 1 >= number of samples.",
                    )
                ]
            )

        return self._onedal_supported(patching_status, method_name, *data)

    def _onedal_supported(self, patching_status, method_name, *data):
        """Route to the fit- or predict-specific support checks.

        Raises
        ------
        RuntimeError
            If ``method_name`` is not one of "fit", "predict" or "score".
        """
        if method_name == "fit":
            return self._onedal_fit_supported(patching_status, method_name, *data)
        if method_name in ["predict", "score"]:
            return self._onedal_predict_supported(patching_status, method_name, *data)
        raise RuntimeError(f"Unknown method {method_name} in {self.__class__.__name__}")

    def _onedal_fit_supported(self, patching_status, method_name, *data):
        """Add the conditions under which oneDAL can perform ``fit``.

        ``data`` is expected to be ``(X, y, sample_weight)``.
        """
        assert method_name == "fit"
        assert len(data) == 3
        X, y, sample_weight = data

        # ``normalize`` only exists on scikit-learn < 1.2; the "deprecated"
        # sentinel means the user did not set it.
        normalize_is_set = (
            hasattr(self, "normalize")
            and self.normalize
            and self.normalize != "deprecated"
        )
        positive_is_set = hasattr(self, "positive") and self.positive

        n_samples = _num_samples(X)
        n_features = _num_features(X, fallback_1d=True)

        # Note: support for some variants was either introduced in oneDAL 2025.1,
        # or had bugs in some uncommon cases in older versions.
        is_underdetermined = n_samples < (n_features + int(self.fit_intercept))
        supports_all_variants = daal_check_version((2025, "P", 1))
        is_multi_output = _num_features(y, fallback_1d=True) > 1

        patching_status.and_conditions(
            [
                (sample_weight is None, "Sample weight is not supported."),
                (
                    not issparse(X) and not issparse(y),
                    "Sparse input is not supported.",
                ),
                (not normalize_is_set, "Normalization is not supported."),
                (
                    not positive_is_set,
                    "Forced positive coefficients are not supported.",
                ),
                (
                    not is_underdetermined or supports_all_variants,
                    "The shape of X (fitting) does not satisfy oneDAL requirements: "
                    "Number of features + 1 >= number of samples.",
                ),
                (
                    not is_multi_output or supports_all_variants,
                    "Multi-output regression is not supported.",
                ),
            ]
        )

        return patching_status

    def _onedal_predict_supported(self, patching_status, method_name, *data):
        """Add the conditions under which oneDAL can perform predict/score.

        ``data[0]`` is the input ``X``; the fitted model must be dense.
        """
        n_samples = _num_samples(data[0])
        model_is_sparse = issparse(self.coef_) or (
            self.fit_intercept and issparse(self.intercept_)
        )
        patching_status.and_conditions(
            [
                (n_samples > 0, "Number of samples is less than 1."),
                (not issparse(data[0]), "Sparse input is not supported."),
                (not model_is_sparse, "Sparse coefficients are not supported."),
            ]
        )

        return patching_status

    def _initialize_onedal_estimator(self):
        """Create a fresh oneDAL estimator from the current hyperparameters."""
        onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
        self._onedal_estimator = onedal_LinearRegression(**onedal_params)

    def _onedal_fit(self, X, y, sample_weight, queue=None):
        """Fit with oneDAL, optionally falling back to scikit-learn on failure.

        When the ``allow_sklearn_after_onedal`` config flag is set, a
        ``RuntimeError`` from the oneDAL fit is logged and the estimator is
        refitted with scikit-learn instead; otherwise the error propagates.
        """
        assert sample_weight is None

        supports_multi_output = daal_check_version((2025, "P", 1))
        check_params = {
            "X": X,
            "y": y,
            "dtype": [np.float64, np.float32],
            "accept_sparse": ["csr", "csc", "coo"],
            "y_numeric": True,
            "multi_output": supports_multi_output,
        }
        if sklearn_check_version("1.0"):
            X, y = validate_data(self, **check_params)
        else:
            X, y = check_X_y(**check_params)

        if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
            self._normalize = _deprecate_normalize(
                self.normalize,
                default=False,
                estimator_name=self.__class__.__name__,
            )

        self._initialize_onedal_estimator()
        # TODO:
        # impl wrapper/primitive for this case.
        if get_config()["allow_sklearn_after_onedal"]:
            try:
                self._onedal_estimator.fit(X, y, queue=queue)
                self._save_attributes()
            except RuntimeError:
                logging.getLogger("sklearnex").info(
                    f"{self.__class__.__name__}.fit "
                    + get_patch_message("sklearn_after_onedal")
                )

                # Discard the failed oneDAL estimator so later predictions are
                # rebuilt from the coefficients scikit-learn computes.
                del self._onedal_estimator
                super().fit(X, y)
        else:
            self._onedal_estimator.fit(X, y, queue=queue)
            self._save_attributes()

    def _onedal_predict(self, X, queue=None):
        """Predict with oneDAL, lazily building the estimator if needed.

        If the model was fitted by scikit-learn (no ``_onedal_estimator``), a
        oneDAL estimator is created and seeded with the current coef_/intercept_.
        """
        if sklearn_check_version("1.0"):
            X = validate_data(self, X, accept_sparse=False, reset=False)
        else:
            X = check_array(X, accept_sparse=False)

        if not hasattr(self, "_onedal_estimator"):
            self._initialize_onedal_estimator()
            self._onedal_estimator.coef_ = self.coef_
            self._onedal_estimator.intercept_ = self.intercept_

        res = self._onedal_estimator.predict(X, queue=queue)
        return res

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Return the R^2 score of oneDAL predictions on ``X`` against ``y``."""
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    @property
    def coef_(self):
        return self._coef_

    @coef_.setter
    def coef_(self, value):
        self._coef_ = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.coef_ = value
            # Invalidate the cached oneDAL model so it is rebuilt from the new
            # coefficients. It may already be absent (e.g. a preceding
            # intercept_ assignment deletes it), so only delete when present.
            if hasattr(self._onedal_estimator, "_onedal_model"):
                del self._onedal_estimator._onedal_model

    @coef_.deleter
    def coef_(self):
        del self._coef_

    @property
    def intercept_(self):
        return self._intercept_

    @intercept_.setter
    def intercept_(self, value):
        self._intercept_ = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.intercept_ = value
            # Invalidate the cached oneDAL model so it is rebuilt from the new
            # intercept. It may already be absent (e.g. a preceding coef_
            # assignment deletes it), so only delete when present.
            if hasattr(self._onedal_estimator, "_onedal_model"):
                del self._onedal_estimator._onedal_model

    @intercept_.deleter
    def intercept_(self):
        del self._intercept_

    def _save_attributes(self):
        """Copy fitted attributes from the oneDAL estimator onto ``self``.

        Writes the private ``_coef_``/``_intercept_`` slots directly so the
        property setters do not invalidate the freshly fitted oneDAL model.
        """
        self.n_features_in_ = self._onedal_estimator.n_features_in_
        self._sparse = False
        self._coef_ = self._onedal_estimator.coef_
        self._intercept_ = self._onedal_estimator.intercept_

    fit.__doc__ = _sklearn_LinearRegression.fit.__doc__
    predict.__doc__ = _sklearn_LinearRegression.predict.__doc__
    score.__doc__ = _sklearn_LinearRegression.score.__doc__