scikit-learn-intelex 2025.1.0__py39-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (280)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-39-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-39-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +222 -0
  62. onedal/_onedal_py_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-39-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +564 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +125 -0
  83. onedal/common/tests/test_policy.py +76 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +154 -0
  91. onedal/datatypes/tests/common.py +126 -0
  92. onedal/datatypes/tests/test_data.py +414 -0
  93. onedal/decomposition/__init__.py +20 -0
  94. onedal/decomposition/incremental_pca.py +204 -0
  95. onedal/decomposition/pca.py +186 -0
  96. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  97. onedal/ensemble/__init__.py +29 -0
  98. onedal/ensemble/forest.py +727 -0
  99. onedal/ensemble/tests/test_random_forest.py +97 -0
  100. onedal/linear_model/__init__.py +27 -0
  101. onedal/linear_model/incremental_linear_model.py +258 -0
  102. onedal/linear_model/linear_model.py +329 -0
  103. onedal/linear_model/logistic_regression.py +249 -0
  104. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  105. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  106. onedal/linear_model/tests/test_linear_regression.py +250 -0
  107. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  108. onedal/linear_model/tests/test_ridge.py +95 -0
  109. onedal/neighbors/__init__.py +19 -0
  110. onedal/neighbors/neighbors.py +767 -0
  111. onedal/neighbors/tests/test_knn_classification.py +49 -0
  112. onedal/primitives/__init__.py +27 -0
  113. onedal/primitives/get_tree.py +25 -0
  114. onedal/primitives/kernel_functions.py +153 -0
  115. onedal/primitives/tests/test_kernel_functions.py +159 -0
  116. onedal/spmd/__init__.py +25 -0
  117. onedal/spmd/_base.py +30 -0
  118. onedal/spmd/basic_statistics/__init__.py +20 -0
  119. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  120. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  121. onedal/spmd/cluster/__init__.py +28 -0
  122. onedal/spmd/cluster/dbscan.py +23 -0
  123. onedal/spmd/cluster/kmeans.py +56 -0
  124. onedal/spmd/covariance/__init__.py +20 -0
  125. onedal/spmd/covariance/covariance.py +26 -0
  126. onedal/spmd/covariance/incremental_covariance.py +82 -0
  127. onedal/spmd/decomposition/__init__.py +20 -0
  128. onedal/spmd/decomposition/incremental_pca.py +117 -0
  129. onedal/spmd/decomposition/pca.py +26 -0
  130. onedal/spmd/ensemble/__init__.py +19 -0
  131. onedal/spmd/ensemble/forest.py +28 -0
  132. onedal/spmd/linear_model/__init__.py +21 -0
  133. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  134. onedal/spmd/linear_model/linear_model.py +30 -0
  135. onedal/spmd/linear_model/logistic_regression.py +38 -0
  136. onedal/spmd/neighbors/__init__.py +19 -0
  137. onedal/spmd/neighbors/neighbors.py +75 -0
  138. onedal/svm/__init__.py +19 -0
  139. onedal/svm/svm.py +556 -0
  140. onedal/svm/tests/test_csr_svm.py +351 -0
  141. onedal/svm/tests/test_nusvc.py +204 -0
  142. onedal/svm/tests/test_nusvr.py +210 -0
  143. onedal/svm/tests/test_svc.py +176 -0
  144. onedal/svm/tests/test_svr.py +243 -0
  145. onedal/tests/test_common.py +57 -0
  146. onedal/tests/utils/_dataframes_support.py +162 -0
  147. onedal/tests/utils/_device_selection.py +102 -0
  148. onedal/utils/__init__.py +49 -0
  149. onedal/utils/_array_api.py +81 -0
  150. onedal/utils/_dpep_helpers.py +56 -0
  151. onedal/utils/validation.py +440 -0
  152. scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
  153. scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
  154. scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
  155. scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
  156. scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
  157. sklearnex/__init__.py +66 -0
  158. sklearnex/__main__.py +58 -0
  159. sklearnex/_config.py +116 -0
  160. sklearnex/_device_offload.py +126 -0
  161. sklearnex/_utils.py +132 -0
  162. sklearnex/basic_statistics/__init__.py +20 -0
  163. sklearnex/basic_statistics/basic_statistics.py +230 -0
  164. sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  165. sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  166. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  167. sklearnex/cluster/__init__.py +20 -0
  168. sklearnex/cluster/dbscan.py +197 -0
  169. sklearnex/cluster/k_means.py +395 -0
  170. sklearnex/cluster/tests/test_dbscan.py +38 -0
  171. sklearnex/cluster/tests/test_kmeans.py +159 -0
  172. sklearnex/conftest.py +82 -0
  173. sklearnex/covariance/__init__.py +19 -0
  174. sklearnex/covariance/incremental_covariance.py +398 -0
  175. sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  176. sklearnex/decomposition/__init__.py +19 -0
  177. sklearnex/decomposition/pca.py +425 -0
  178. sklearnex/decomposition/tests/test_pca.py +58 -0
  179. sklearnex/dispatcher.py +543 -0
  180. sklearnex/doc/third-party-programs.txt +424 -0
  181. sklearnex/ensemble/__init__.py +29 -0
  182. sklearnex/ensemble/_forest.py +2029 -0
  183. sklearnex/ensemble/tests/test_forest.py +135 -0
  184. sklearnex/glob/__main__.py +72 -0
  185. sklearnex/glob/dispatcher.py +101 -0
  186. sklearnex/linear_model/__init__.py +32 -0
  187. sklearnex/linear_model/coordinate_descent.py +30 -0
  188. sklearnex/linear_model/incremental_linear.py +482 -0
  189. sklearnex/linear_model/incremental_ridge.py +425 -0
  190. sklearnex/linear_model/linear.py +341 -0
  191. sklearnex/linear_model/logistic_regression.py +413 -0
  192. sklearnex/linear_model/ridge.py +24 -0
  193. sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  194. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  195. sklearnex/linear_model/tests/test_linear.py +167 -0
  196. sklearnex/linear_model/tests/test_logreg.py +134 -0
  197. sklearnex/manifold/__init__.py +19 -0
  198. sklearnex/manifold/t_sne.py +21 -0
  199. sklearnex/manifold/tests/test_tsne.py +26 -0
  200. sklearnex/metrics/__init__.py +23 -0
  201. sklearnex/metrics/pairwise.py +22 -0
  202. sklearnex/metrics/ranking.py +20 -0
  203. sklearnex/metrics/tests/test_metrics.py +39 -0
  204. sklearnex/model_selection/__init__.py +21 -0
  205. sklearnex/model_selection/split.py +22 -0
  206. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  207. sklearnex/neighbors/__init__.py +27 -0
  208. sklearnex/neighbors/_lof.py +236 -0
  209. sklearnex/neighbors/common.py +310 -0
  210. sklearnex/neighbors/knn_classification.py +231 -0
  211. sklearnex/neighbors/knn_regression.py +207 -0
  212. sklearnex/neighbors/knn_unsupervised.py +178 -0
  213. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  214. sklearnex/preview/__init__.py +17 -0
  215. sklearnex/preview/covariance/__init__.py +19 -0
  216. sklearnex/preview/covariance/covariance.py +138 -0
  217. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  218. sklearnex/preview/decomposition/__init__.py +19 -0
  219. sklearnex/preview/decomposition/incremental_pca.py +233 -0
  220. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  221. sklearnex/preview/linear_model/__init__.py +19 -0
  222. sklearnex/preview/linear_model/ridge.py +424 -0
  223. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  224. sklearnex/spmd/__init__.py +25 -0
  225. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  226. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  227. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  228. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  229. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  230. sklearnex/spmd/cluster/__init__.py +30 -0
  231. sklearnex/spmd/cluster/dbscan.py +50 -0
  232. sklearnex/spmd/cluster/kmeans.py +21 -0
  233. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  234. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  235. sklearnex/spmd/covariance/__init__.py +20 -0
  236. sklearnex/spmd/covariance/covariance.py +21 -0
  237. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  238. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  239. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  240. sklearnex/spmd/decomposition/__init__.py +20 -0
  241. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  242. sklearnex/spmd/decomposition/pca.py +21 -0
  243. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  244. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  245. sklearnex/spmd/ensemble/__init__.py +19 -0
  246. sklearnex/spmd/ensemble/forest.py +71 -0
  247. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  248. sklearnex/spmd/linear_model/__init__.py +21 -0
  249. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  250. sklearnex/spmd/linear_model/linear_model.py +21 -0
  251. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  252. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  253. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  254. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  255. sklearnex/spmd/neighbors/__init__.py +19 -0
  256. sklearnex/spmd/neighbors/neighbors.py +25 -0
  257. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  258. sklearnex/svm/__init__.py +29 -0
  259. sklearnex/svm/_common.py +339 -0
  260. sklearnex/svm/nusvc.py +371 -0
  261. sklearnex/svm/nusvr.py +170 -0
  262. sklearnex/svm/svc.py +399 -0
  263. sklearnex/svm/svr.py +167 -0
  264. sklearnex/svm/tests/test_svm.py +93 -0
  265. sklearnex/tests/test_common.py +390 -0
  266. sklearnex/tests/test_config.py +123 -0
  267. sklearnex/tests/test_memory_usage.py +379 -0
  268. sklearnex/tests/test_monkeypatch.py +276 -0
  269. sklearnex/tests/test_n_jobs_support.py +108 -0
  270. sklearnex/tests/test_parallel.py +48 -0
  271. sklearnex/tests/test_patching.py +385 -0
  272. sklearnex/tests/test_run_to_run_stability.py +321 -0
  273. sklearnex/tests/utils/__init__.py +44 -0
  274. sklearnex/tests/utils/base.py +371 -0
  275. sklearnex/tests/utils/spmd.py +198 -0
  276. sklearnex/utils/__init__.py +19 -0
  277. sklearnex/utils/_array_api.py +82 -0
  278. sklearnex/utils/parallel.py +59 -0
  279. sklearnex/utils/tests/test_finite.py +89 -0
  280. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,341 @@
1
+ # ===============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import logging
18
+ from abc import ABC
19
+
20
+ import numpy as np
21
+ from sklearn.linear_model import LinearRegression as _sklearn_LinearRegression
22
+ from sklearn.metrics import r2_score
23
+ from sklearn.utils.validation import check_array, check_is_fitted
24
+
25
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
26
+ from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
27
+
28
+ from .._config import get_config
29
+ from .._device_offload import dispatch, wrap_output_data
30
+ from .._utils import PatchingConditionsChain, get_patch_message, register_hyperparameters
31
+
32
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
33
+ from sklearn.linear_model._base import _deprecate_normalize
34
+
35
+ from scipy.sparse import issparse
36
+ from sklearn.utils.validation import check_is_fitted, check_X_y
37
+
38
+ from onedal.common.hyperparameters import get_hyperparameters
39
+ from onedal.linear_model import LinearRegression as onedal_LinearRegression
40
+ from onedal.utils import _num_features, _num_samples
41
+
42
+ if sklearn_check_version("1.6"):
43
+ from sklearn.utils.validation import validate_data
44
+ else:
45
+ validate_data = _sklearn_LinearRegression._validate_data
46
+
47
+
@register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class LinearRegression(_sklearn_LinearRegression):
    __doc__ = _sklearn_LinearRegression.__doc__

    # Patching-chain reason reported when the training data has fewer rows
    # than unknowns to estimate (one coefficient per feature, plus one more
    # for the intercept when ``fit_intercept`` is true).
    _UNDERDETERMINED_MSG = (
        "The shape of X (fitting) does not satisfy oneDAL requirements: "
        "Number of samples must be at least number of features "
        "(+1 if fit_intercept=True)."
    )

    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **_sklearn_LinearRegression._parameter_constraints
        }

        def __init__(
            self,
            fit_intercept=True,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    else:
        # scikit-learn < 1.2 still accepts the ``normalize`` parameter.
        def __init__(
            self,
            fit_intercept=True,
            normalize="deprecated" if sklearn_check_version("1.0") else False,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                normalize=normalize,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    def fit(self, X, y, sample_weight=None):
        # Docstring is copied from scikit-learn at the end of the class body.
        if sklearn_check_version("1.2"):
            self._validate_params()

        # It is necessary to properly update coefs for predict if we
        # fallback to sklearn in dispatch: drop the stale oneDAL estimator so
        # that predict rebuilds it from the freshly fitted coefficients.
        if hasattr(self, "_onedal_estimator"):
            del self._onedal_estimator

        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": _sklearn_LinearRegression.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):
        # Docstring is copied from scikit-learn at the end of the class body.
        check_is_fitted(self)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": _sklearn_LinearRegression.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        # Docstring is copied from scikit-learn at the end of the class body.
        check_is_fitted(self)
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": _sklearn_LinearRegression.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    def _onedal_cpu_supported(self, method_name, *data):
        """Return the patching conditions for running ``method_name`` on CPU.

        Builds a ``PatchingConditionsChain`` named after the public method and
        delegates the actual checks to ``_onedal_supported``.
        """
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{self.__class__.__name__}.{method_name}"
        )
        return self._onedal_supported(patching_status, method_name, *data)

    def _onedal_gpu_supported(self, method_name, *data):
        """Return the patching conditions for running ``method_name`` on GPU.

        Same as the CPU path, except that fitting an underdetermined system is
        always rejected here, independently of the oneDAL version.
        """
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{self.__class__.__name__}.{method_name}"
        )

        if method_name == "fit":
            patching_status.and_conditions(
                [
                    (
                        not self._is_underdetermined(data[0]),
                        self._UNDERDETERMINED_MSG,
                    )
                ]
            )

        return self._onedal_supported(patching_status, method_name, *data)

    def _onedal_supported(self, patching_status, method_name, *data):
        """Route to the fit- or predict-specific condition checks.

        Raises
        ------
        RuntimeError
            If ``method_name`` is not one of ``fit``, ``predict``, ``score``.
        """
        if method_name == "fit":
            return self._onedal_fit_supported(patching_status, method_name, *data)
        if method_name in ["predict", "score"]:
            return self._onedal_predict_supported(patching_status, method_name, *data)
        raise RuntimeError(f"Unknown method {method_name} in {self.__class__.__name__}")

    def _is_underdetermined(self, X):
        """Return True if ``X`` has fewer rows than unknowns to estimate.

        The unknowns are one coefficient per feature, plus the intercept
        when ``fit_intercept`` is true.
        """
        n_samples = _num_samples(X)
        n_features = _num_features(X, fallback_1d=True)
        return n_samples < (n_features + int(self.fit_intercept))

    def _onedal_fit_supported(self, patching_status, method_name, *data):
        """Add the oneDAL fit requirements to ``patching_status``.

        ``data`` must be ``(X, y, sample_weight)``. Sample weights, sparse
        input, ``normalize`` and ``positive`` are never supported.
        Underdetermined and multi-output problems are supported only when
        ``daal_check_version((2025, "P", 1))`` holds.
        """
        assert method_name == "fit"
        assert len(data) == 3
        X, y, sample_weight = data

        normalize_is_set = (
            hasattr(self, "normalize")
            and self.normalize
            and self.normalize != "deprecated"
        )
        positive_is_set = hasattr(self, "positive") and self.positive

        # Note: support for some variants was either introduced in oneDAL 2025.1,
        # or had bugs in some uncommon cases in older versions.
        is_underdetermined = self._is_underdetermined(X)
        supports_all_variants = daal_check_version((2025, "P", 1))
        is_multi_output = _num_features(y, fallback_1d=True) > 1

        patching_status.and_conditions(
            [
                (sample_weight is None, "Sample weight is not supported."),
                (
                    not issparse(X) and not issparse(y),
                    "Sparse input is not supported.",
                ),
                (not normalize_is_set, "Normalization is not supported."),
                (
                    not positive_is_set,
                    "Forced positive coefficients are not supported.",
                ),
                (
                    not is_underdetermined or supports_all_variants,
                    self._UNDERDETERMINED_MSG,
                ),
                (
                    not is_multi_output or supports_all_variants,
                    "Multi-output regression is not supported.",
                ),
            ]
        )

        return patching_status

    def _onedal_predict_supported(self, patching_status, method_name, *data):
        """Add the oneDAL predict/score requirements to ``patching_status``.

        Requires at least one sample, dense input, and dense fitted
        coefficients (and a dense intercept when ``fit_intercept`` is true).
        """
        n_samples = _num_samples(data[0])
        model_is_sparse = issparse(self.coef_) or (
            self.fit_intercept and issparse(self.intercept_)
        )
        patching_status.and_conditions(
            [
                (n_samples > 0, "Number of samples is less than 1."),
                (not issparse(data[0]), "Sparse input is not supported."),
                (not model_is_sparse, "Sparse coefficients are not supported."),
            ]
        )

        return patching_status

    def _initialize_onedal_estimator(self):
        """Create a fresh oneDAL estimator from ``fit_intercept`` and ``copy_X``."""
        onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
        self._onedal_estimator = onedal_LinearRegression(**onedal_params)

    def _onedal_fit(self, X, y, sample_weight, queue=None):
        """Validate ``X``/``y`` and fit through oneDAL.

        When the ``allow_sklearn_after_onedal`` config flag is set and oneDAL
        raises ``RuntimeError``, this logs the event and refits with
        scikit-learn on the already validated data.
        """
        # Guaranteed by _onedal_fit_supported before dispatch gets here.
        assert sample_weight is None

        supports_multi_output = daal_check_version((2025, "P", 1))
        check_params = {
            "X": X,
            "y": y,
            "dtype": [np.float64, np.float32],
            "accept_sparse": ["csr", "csc", "coo"],
            "y_numeric": True,
            "multi_output": supports_multi_output,
        }
        if sklearn_check_version("1.0"):
            X, y = validate_data(self, **check_params)
        else:
            X, y = check_X_y(**check_params)

        if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
            self._normalize = _deprecate_normalize(
                self.normalize,
                default=False,
                estimator_name=self.__class__.__name__,
            )

        self._initialize_onedal_estimator()
        # TODO:
        # impl wrapper/primitive for this case.
        if get_config()["allow_sklearn_after_onedal"]:
            try:
                self._onedal_estimator.fit(X, y, queue=queue)
                self._save_attributes()

            except RuntimeError:
                logging.getLogger("sklearnex").info(
                    f"{self.__class__.__name__}.fit "
                    + get_patch_message("sklearn_after_onedal")
                )

                # Remove the failed oneDAL estimator before the scikit-learn
                # fit assigns coef_/intercept_ through the setters below.
                del self._onedal_estimator
                super().fit(X, y)
        else:
            self._onedal_estimator.fit(X, y, queue=queue)
            self._save_attributes()

    def _onedal_predict(self, X, queue=None):
        """Predict through oneDAL, creating the oneDAL estimator if needed.

        When the model was fitted by scikit-learn (no ``_onedal_estimator``),
        a new oneDAL estimator is created and seeded with ``coef_`` and
        ``intercept_``.
        """
        if sklearn_check_version("1.0"):
            X = validate_data(self, X, accept_sparse=False, reset=False)
        else:
            X = check_array(X, accept_sparse=False)

        if not hasattr(self, "_onedal_estimator"):
            self._initialize_onedal_estimator()
            self._onedal_estimator.coef_ = self.coef_
            self._onedal_estimator.intercept_ = self.intercept_

        res = self._onedal_estimator.predict(X, queue=queue)
        return res

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Return the R^2 score of the oneDAL predictions on ``X`` against ``y``."""
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    @property
    def coef_(self):
        return self._coef_

    @coef_.setter
    def coef_(self, value):
        self._coef_ = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.coef_ = value
            # Invalidate the cached oneDAL model so it is rebuilt from the new
            # coefficients. It may already be gone, for example after an
            # earlier coef_/intercept_ assignment deleted it, or when the
            # estimator was created in _onedal_predict without being fitted.
            # Deleting it unconditionally would raise AttributeError.
            if hasattr(self._onedal_estimator, "_onedal_model"):
                del self._onedal_estimator._onedal_model

    @coef_.deleter
    def coef_(self):
        del self._coef_

    @property
    def intercept_(self):
        return self._intercept_

    @intercept_.setter
    def intercept_(self, value):
        self._intercept_ = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.intercept_ = value
            # See coef_ setter: only drop the cached model if it still exists.
            if hasattr(self._onedal_estimator, "_onedal_model"):
                del self._onedal_estimator._onedal_model

    @intercept_.deleter
    def intercept_(self):
        del self._intercept_

    def _save_attributes(self):
        """Copy the fitted oneDAL attributes onto this estimator.

        The private ``_coef_``/``_intercept_`` slots are written directly so
        that the property setters do not discard the freshly fitted model.
        """
        self.n_features_in_ = self._onedal_estimator.n_features_in_
        self._sparse = False
        self._coef_ = self._onedal_estimator.coef_
        self._intercept_ = self._onedal_estimator.intercept_

    fit.__doc__ = _sklearn_LinearRegression.fit.__doc__
    predict.__doc__ = _sklearn_LinearRegression.predict.__doc__
    score.__doc__ = _sklearn_LinearRegression.score.__doc__