scikit-learn-intelex 2025.0.0__py39-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (278) hide show
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-39-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-39-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +242 -0
  10. daal4py/sklearn/_utils.py +241 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +155 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +53 -0
  61. onedal/_device_offload.py +229 -0
  62. onedal/_onedal_py_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-39-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +560 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +116 -0
  83. onedal/common/tests/test_policy.py +75 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +95 -0
  91. onedal/datatypes/tests/test_data.py +235 -0
  92. onedal/decomposition/__init__.py +20 -0
  93. onedal/decomposition/incremental_pca.py +204 -0
  94. onedal/decomposition/pca.py +186 -0
  95. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  96. onedal/ensemble/__init__.py +29 -0
  97. onedal/ensemble/forest.py +720 -0
  98. onedal/ensemble/tests/test_random_forest.py +97 -0
  99. onedal/linear_model/__init__.py +27 -0
  100. onedal/linear_model/incremental_linear_model.py +258 -0
  101. onedal/linear_model/linear_model.py +329 -0
  102. onedal/linear_model/logistic_regression.py +249 -0
  103. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  104. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  105. onedal/linear_model/tests/test_linear_regression.py +149 -0
  106. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  107. onedal/linear_model/tests/test_ridge.py +95 -0
  108. onedal/neighbors/__init__.py +19 -0
  109. onedal/neighbors/neighbors.py +778 -0
  110. onedal/neighbors/tests/test_knn_classification.py +49 -0
  111. onedal/primitives/__init__.py +27 -0
  112. onedal/primitives/get_tree.py +25 -0
  113. onedal/primitives/kernel_functions.py +153 -0
  114. onedal/primitives/tests/test_kernel_functions.py +159 -0
  115. onedal/spmd/__init__.py +25 -0
  116. onedal/spmd/_base.py +30 -0
  117. onedal/spmd/basic_statistics/__init__.py +20 -0
  118. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  119. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  120. onedal/spmd/cluster/__init__.py +28 -0
  121. onedal/spmd/cluster/dbscan.py +23 -0
  122. onedal/spmd/cluster/kmeans.py +56 -0
  123. onedal/spmd/covariance/__init__.py +20 -0
  124. onedal/spmd/covariance/covariance.py +26 -0
  125. onedal/spmd/covariance/incremental_covariance.py +82 -0
  126. onedal/spmd/decomposition/__init__.py +20 -0
  127. onedal/spmd/decomposition/incremental_pca.py +117 -0
  128. onedal/spmd/decomposition/pca.py +26 -0
  129. onedal/spmd/ensemble/__init__.py +19 -0
  130. onedal/spmd/ensemble/forest.py +28 -0
  131. onedal/spmd/linear_model/__init__.py +21 -0
  132. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  133. onedal/spmd/linear_model/linear_model.py +30 -0
  134. onedal/spmd/linear_model/logistic_regression.py +38 -0
  135. onedal/spmd/neighbors/__init__.py +19 -0
  136. onedal/spmd/neighbors/neighbors.py +75 -0
  137. onedal/svm/__init__.py +19 -0
  138. onedal/svm/svm.py +556 -0
  139. onedal/svm/tests/test_csr_svm.py +351 -0
  140. onedal/svm/tests/test_nusvc.py +204 -0
  141. onedal/svm/tests/test_nusvr.py +210 -0
  142. onedal/svm/tests/test_svc.py +168 -0
  143. onedal/svm/tests/test_svr.py +243 -0
  144. onedal/tests/test_common.py +41 -0
  145. onedal/tests/utils/_dataframes_support.py +168 -0
  146. onedal/tests/utils/_device_selection.py +107 -0
  147. onedal/utils/__init__.py +49 -0
  148. onedal/utils/_array_api.py +91 -0
  149. onedal/utils/validation.py +432 -0
  150. scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
  151. scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
  152. scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
  153. scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
  154. scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
  155. sklearnex/__init__.py +65 -0
  156. sklearnex/__main__.py +58 -0
  157. sklearnex/_config.py +98 -0
  158. sklearnex/_device_offload.py +121 -0
  159. sklearnex/_utils.py +109 -0
  160. sklearnex/basic_statistics/__init__.py +20 -0
  161. sklearnex/basic_statistics/basic_statistics.py +140 -0
  162. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  163. sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
  164. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
  165. sklearnex/cluster/__init__.py +20 -0
  166. sklearnex/cluster/dbscan.py +192 -0
  167. sklearnex/cluster/k_means.py +383 -0
  168. sklearnex/cluster/tests/test_dbscan.py +38 -0
  169. sklearnex/cluster/tests/test_kmeans.py +153 -0
  170. sklearnex/conftest.py +73 -0
  171. sklearnex/covariance/__init__.py +19 -0
  172. sklearnex/covariance/incremental_covariance.py +368 -0
  173. sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
  174. sklearnex/decomposition/__init__.py +19 -0
  175. sklearnex/decomposition/pca.py +414 -0
  176. sklearnex/decomposition/tests/test_pca.py +58 -0
  177. sklearnex/dispatcher.py +543 -0
  178. sklearnex/doc/third-party-programs.txt +424 -0
  179. sklearnex/ensemble/__init__.py +29 -0
  180. sklearnex/ensemble/_forest.py +2016 -0
  181. sklearnex/ensemble/tests/test_forest.py +120 -0
  182. sklearnex/glob/__main__.py +72 -0
  183. sklearnex/glob/dispatcher.py +101 -0
  184. sklearnex/linear_model/__init__.py +32 -0
  185. sklearnex/linear_model/coordinate_descent.py +30 -0
  186. sklearnex/linear_model/incremental_linear.py +463 -0
  187. sklearnex/linear_model/incremental_ridge.py +418 -0
  188. sklearnex/linear_model/linear.py +302 -0
  189. sklearnex/linear_model/logistic_path.py +17 -0
  190. sklearnex/linear_model/logistic_regression.py +403 -0
  191. sklearnex/linear_model/ridge.py +24 -0
  192. sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
  193. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  194. sklearnex/linear_model/tests/test_linear.py +142 -0
  195. sklearnex/linear_model/tests/test_logreg.py +134 -0
  196. sklearnex/manifold/__init__.py +19 -0
  197. sklearnex/manifold/t_sne.py +21 -0
  198. sklearnex/manifold/tests/test_tsne.py +26 -0
  199. sklearnex/metrics/__init__.py +23 -0
  200. sklearnex/metrics/pairwise.py +22 -0
  201. sklearnex/metrics/ranking.py +20 -0
  202. sklearnex/metrics/tests/test_metrics.py +39 -0
  203. sklearnex/model_selection/__init__.py +21 -0
  204. sklearnex/model_selection/split.py +22 -0
  205. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  206. sklearnex/neighbors/__init__.py +27 -0
  207. sklearnex/neighbors/_lof.py +231 -0
  208. sklearnex/neighbors/common.py +310 -0
  209. sklearnex/neighbors/knn_classification.py +226 -0
  210. sklearnex/neighbors/knn_regression.py +203 -0
  211. sklearnex/neighbors/knn_unsupervised.py +170 -0
  212. sklearnex/neighbors/tests/test_neighbors.py +80 -0
  213. sklearnex/preview/__init__.py +17 -0
  214. sklearnex/preview/covariance/__init__.py +19 -0
  215. sklearnex/preview/covariance/covariance.py +133 -0
  216. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  217. sklearnex/preview/decomposition/__init__.py +19 -0
  218. sklearnex/preview/decomposition/incremental_pca.py +228 -0
  219. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  220. sklearnex/preview/linear_model/__init__.py +19 -0
  221. sklearnex/preview/linear_model/ridge.py +419 -0
  222. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  223. sklearnex/spmd/__init__.py +25 -0
  224. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  225. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  226. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  227. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  228. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  229. sklearnex/spmd/cluster/__init__.py +30 -0
  230. sklearnex/spmd/cluster/dbscan.py +50 -0
  231. sklearnex/spmd/cluster/kmeans.py +21 -0
  232. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  233. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  234. sklearnex/spmd/covariance/__init__.py +20 -0
  235. sklearnex/spmd/covariance/covariance.py +21 -0
  236. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  237. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  238. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  239. sklearnex/spmd/decomposition/__init__.py +20 -0
  240. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  241. sklearnex/spmd/decomposition/pca.py +21 -0
  242. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  243. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  244. sklearnex/spmd/ensemble/__init__.py +19 -0
  245. sklearnex/spmd/ensemble/forest.py +71 -0
  246. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  247. sklearnex/spmd/linear_model/__init__.py +21 -0
  248. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  249. sklearnex/spmd/linear_model/linear_model.py +21 -0
  250. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  251. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  252. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  253. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
  254. sklearnex/spmd/neighbors/__init__.py +19 -0
  255. sklearnex/spmd/neighbors/neighbors.py +25 -0
  256. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  257. sklearnex/svm/__init__.py +29 -0
  258. sklearnex/svm/_common.py +328 -0
  259. sklearnex/svm/nusvc.py +332 -0
  260. sklearnex/svm/nusvr.py +148 -0
  261. sklearnex/svm/svc.py +360 -0
  262. sklearnex/svm/svr.py +149 -0
  263. sklearnex/svm/tests/test_svm.py +93 -0
  264. sklearnex/tests/_utils.py +328 -0
  265. sklearnex/tests/_utils_spmd.py +198 -0
  266. sklearnex/tests/test_common.py +54 -0
  267. sklearnex/tests/test_config.py +43 -0
  268. sklearnex/tests/test_memory_usage.py +291 -0
  269. sklearnex/tests/test_monkeypatch.py +276 -0
  270. sklearnex/tests/test_n_jobs_support.py +103 -0
  271. sklearnex/tests/test_parallel.py +48 -0
  272. sklearnex/tests/test_patching.py +385 -0
  273. sklearnex/tests/test_run_to_run_stability.py +296 -0
  274. sklearnex/utils/__init__.py +19 -0
  275. sklearnex/utils/_array_api.py +82 -0
  276. sklearnex/utils/parallel.py +59 -0
  277. sklearnex/utils/tests/test_finite.py +89 -0
  278. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,302 @@
1
+ # ===============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import logging
18
+ from abc import ABC
19
+
20
+ import numpy as np
21
+ from sklearn.exceptions import NotFittedError
22
+ from sklearn.linear_model import LinearRegression as sklearn_LinearRegression
23
+ from sklearn.metrics import r2_score
24
+
25
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
26
+ from daal4py.sklearn._utils import sklearn_check_version
27
+
28
+ from .._device_offload import dispatch, wrap_output_data
29
+ from .._utils import PatchingConditionsChain, get_patch_message, register_hyperparameters
30
+
31
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
32
+ from sklearn.linear_model._base import _deprecate_normalize
33
+
34
+ from scipy.sparse import issparse
35
+ from sklearn.utils.validation import check_X_y
36
+
37
+ from onedal.common.hyperparameters import get_hyperparameters
38
+ from onedal.linear_model import LinearRegression as onedal_LinearRegression
39
+ from onedal.utils import _num_features, _num_samples
40
+
41
+
42
# sklearnex drop-in replacement for sklearn.linear_model.LinearRegression that
# dispatches fit/predict/score to oneDAL when the inputs and parameters are
# supported, and falls back to the stock scikit-learn implementation otherwise.
@register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
@control_n_jobs(decorated_methods=["fit", "predict"])
class LinearRegression(sklearn_LinearRegression):
    __doc__ = sklearn_LinearRegression.__doc__

    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {**sklearn_LinearRegression._parameter_constraints}

        def __init__(
            self,
            fit_intercept=True,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    else:

        def __init__(
            self,
            fit_intercept=True,
            normalize="deprecated" if sklearn_check_version("1.0") else False,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                normalize=normalize,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    def fit(self, X, y, sample_weight=None):
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=True)
        if sklearn_check_version("1.2"):
            self._validate_params()

        # It is necessary to properly update coefs for predict if we
        # fallback to sklearn in dispatch
        if hasattr(self, "_onedal_estimator"):
            del self._onedal_estimator

        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_LinearRegression.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):

        if not hasattr(self, "coef_"):
            msg = (
                "This %(name)s instance is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this estimator."
            )
            raise NotFittedError(msg % {"name": self.__class__.__name__})

        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_LinearRegression.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": sklearn_LinearRegression.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    def _onedal_fit_supported(self, method_name, *data):
        """Build the chain of conditions deciding whether oneDAL can run ``fit``.

        ``data`` is expected to be ``(X, y, sample_weight)``.
        Returns a ``PatchingConditionsChain``.
        """
        assert method_name == "fit"
        assert len(data) == 3
        X, y, sample_weight = data

        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{class_name}.fit"
        )

        # ``normalize`` only exists on older sklearn; "deprecated" means unset.
        normalize_is_set = (
            hasattr(self, "normalize")
            and self.normalize
            and self.normalize != "deprecated"
        )
        positive_is_set = hasattr(self, "positive") and self.positive

        n_samples = _num_samples(X)
        n_features = _num_features(X, fallback_1d=True)

        # Check if equations are well defined
        is_underdetermined = n_samples < (n_features + int(self.fit_intercept))

        patching_status.and_conditions(
            [
                (sample_weight is None, "Sample weight is not supported."),
                (
                    not issparse(X) and not issparse(y),
                    "Sparse input is not supported.",
                ),
                (not normalize_is_set, "Normalization is not supported."),
                (
                    not positive_is_set,
                    "Forced positive coefficients are not supported.",
                ),
                (
                    not is_underdetermined,
                    # Trailing space: the two literals are concatenated.
                    "The shape of X (fitting) does not satisfy oneDAL requirements: "
                    "Number of features + 1 >= number of samples.",
                ),
            ]
        )

        return patching_status

    def _onedal_predict_supported(self, method_name, *data):
        """Build the chain of conditions for oneDAL ``predict``/``score``.

        ``data[0]`` is the input ``X``. Returns a ``PatchingConditionsChain``.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{class_name}.predict"
        )

        n_samples = _num_samples(data[0])
        model_is_sparse = issparse(self.coef_) or (
            self.fit_intercept and issparse(self.intercept_)
        )
        patching_status.and_conditions(
            [
                (n_samples > 0, "Number of samples is less than 1."),
                (not issparse(data[0]), "Sparse input is not supported."),
                (not model_is_sparse, "Sparse coefficients are not supported."),
            ]
        )

        return patching_status

    def _onedal_supported(self, method_name, *data):
        if method_name == "fit":
            return self._onedal_fit_supported(method_name, *data)
        if method_name in ["predict", "score"]:
            return self._onedal_predict_supported(method_name, *data)
        raise RuntimeError(f"Unknown method {method_name} in {self.__class__.__name__}")

    _onedal_gpu_supported = _onedal_supported
    _onedal_cpu_supported = _onedal_supported

    def _initialize_onedal_estimator(self):
        onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
        self._onedal_estimator = onedal_LinearRegression(**onedal_params)

    def _onedal_fit(self, X, y, sample_weight, queue=None):
        """Fit with oneDAL, falling back to sklearn if oneDAL raises RuntimeError."""
        assert sample_weight is None

        check_params = {
            "X": X,
            "y": y,
            "dtype": [np.float64, np.float32],
            "accept_sparse": ["csr", "csc", "coo"],
            "y_numeric": True,
            "multi_output": True,
        }
        if sklearn_check_version("1.2"):
            X, y = self._validate_data(**check_params)
        else:
            X, y = check_X_y(**check_params)

        if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
            self._normalize = _deprecate_normalize(
                self.normalize,
                default=False,
                estimator_name=self.__class__.__name__,
            )

        self._initialize_onedal_estimator()
        try:
            self._onedal_estimator.fit(X, y, queue=queue)
            self._save_attributes()

        except RuntimeError:
            logging.getLogger("sklearnex").info(
                f"{self.__class__.__name__}.fit "
                + get_patch_message("sklearn_after_onedal")
            )

            del self._onedal_estimator
            super().fit(X, y)

    def _onedal_predict(self, X, queue=None):
        """Predict with oneDAL using this estimator's current coef_/intercept_."""
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)

        X = self._validate_data(X, accept_sparse=False, reset=False)
        if not hasattr(self, "_onedal_estimator"):
            self._initialize_onedal_estimator()
            self._onedal_estimator.coef_ = self.coef_
            self._onedal_estimator.intercept_ = self.intercept_
        else:
            # After a oneDAL fit, coef_/intercept_ are the very objects held by
            # the oneDAL estimator. If the user reassigned either one, push the
            # new value down and invalidate the cached oneDAL model so that
            # predictions reflect it. In-place mutation of the same array is
            # not detected here.
            if self._onedal_estimator.coef_ is not self.coef_:
                self.set_coef_(self.coef_)
            if self._onedal_estimator.intercept_ is not self.intercept_:
                self.set_intercept_(self.intercept_)

        res = self._onedal_estimator.predict(X, queue=queue)
        return res

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    def get_coef_(self):
        return self.coef_

    def set_coef_(self, value):
        """Set coef_ and propagate it to the oneDAL estimator, if any."""
        self.__dict__["coef_"] = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.coef_ = value
            # The cached model may already have been dropped (e.g. by
            # set_intercept_), so only delete it when present.
            if hasattr(self._onedal_estimator, "_onedal_model"):
                del self._onedal_estimator._onedal_model

    def get_intercept_(self):
        return self.intercept_

    def set_intercept_(self, value):
        """Set intercept_ and propagate it to the oneDAL estimator, if any."""
        self.__dict__["intercept_"] = value
        if hasattr(self, "_onedal_estimator"):
            self._onedal_estimator.intercept_ = value
            # The cached model may already have been dropped (e.g. by
            # set_coef_), so only delete it when present.
            if hasattr(self._onedal_estimator, "_onedal_model"):
                del self._onedal_estimator._onedal_model

    def _save_attributes(self):
        """Copy fitted attributes from the oneDAL estimator onto ``self``."""
        # coef_/intercept_ are stored directly in the instance dict and share
        # the objects held by the oneDAL estimator; _onedal_predict relies on
        # that identity to detect later reassignment. (Assigning ``property``
        # objects to an instance has no effect -- properties only work as
        # class attributes -- so none are installed here.)
        self.n_features_in_ = self._onedal_estimator.n_features_in_
        self._sparse = False
        self.__dict__["coef_"] = self._onedal_estimator.coef_
        self.__dict__["intercept_"] = self._onedal_estimator.intercept_

    fit.__doc__ = sklearn_LinearRegression.fit.__doc__
    predict.__doc__ = sklearn_LinearRegression.predict.__doc__
    score.__doc__ = sklearn_LinearRegression.score.__doc__
@@ -0,0 +1,17 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from daal4py.sklearn.linear_model import LogisticRegression, logistic_regression_path