scikit-learn-intelex 2025.0.0 — scikit_learn_intelex-2025.0.0-py312-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (278)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-312-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-312-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +242 -0
  10. daal4py/sklearn/_utils.py +241 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +155 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +53 -0
  61. onedal/_device_offload.py +229 -0
  62. onedal/_onedal_py_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-312-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +560 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +116 -0
  83. onedal/common/tests/test_policy.py +75 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +95 -0
  91. onedal/datatypes/tests/test_data.py +235 -0
  92. onedal/decomposition/__init__.py +20 -0
  93. onedal/decomposition/incremental_pca.py +204 -0
  94. onedal/decomposition/pca.py +186 -0
  95. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  96. onedal/ensemble/__init__.py +29 -0
  97. onedal/ensemble/forest.py +720 -0
  98. onedal/ensemble/tests/test_random_forest.py +97 -0
  99. onedal/linear_model/__init__.py +27 -0
  100. onedal/linear_model/incremental_linear_model.py +258 -0
  101. onedal/linear_model/linear_model.py +329 -0
  102. onedal/linear_model/logistic_regression.py +249 -0
  103. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  104. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  105. onedal/linear_model/tests/test_linear_regression.py +149 -0
  106. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  107. onedal/linear_model/tests/test_ridge.py +95 -0
  108. onedal/neighbors/__init__.py +19 -0
  109. onedal/neighbors/neighbors.py +778 -0
  110. onedal/neighbors/tests/test_knn_classification.py +49 -0
  111. onedal/primitives/__init__.py +27 -0
  112. onedal/primitives/get_tree.py +25 -0
  113. onedal/primitives/kernel_functions.py +153 -0
  114. onedal/primitives/tests/test_kernel_functions.py +159 -0
  115. onedal/spmd/__init__.py +25 -0
  116. onedal/spmd/_base.py +30 -0
  117. onedal/spmd/basic_statistics/__init__.py +20 -0
  118. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  119. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  120. onedal/spmd/cluster/__init__.py +28 -0
  121. onedal/spmd/cluster/dbscan.py +23 -0
  122. onedal/spmd/cluster/kmeans.py +56 -0
  123. onedal/spmd/covariance/__init__.py +20 -0
  124. onedal/spmd/covariance/covariance.py +26 -0
  125. onedal/spmd/covariance/incremental_covariance.py +82 -0
  126. onedal/spmd/decomposition/__init__.py +20 -0
  127. onedal/spmd/decomposition/incremental_pca.py +117 -0
  128. onedal/spmd/decomposition/pca.py +26 -0
  129. onedal/spmd/ensemble/__init__.py +19 -0
  130. onedal/spmd/ensemble/forest.py +28 -0
  131. onedal/spmd/linear_model/__init__.py +21 -0
  132. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  133. onedal/spmd/linear_model/linear_model.py +30 -0
  134. onedal/spmd/linear_model/logistic_regression.py +38 -0
  135. onedal/spmd/neighbors/__init__.py +19 -0
  136. onedal/spmd/neighbors/neighbors.py +75 -0
  137. onedal/svm/__init__.py +19 -0
  138. onedal/svm/svm.py +556 -0
  139. onedal/svm/tests/test_csr_svm.py +351 -0
  140. onedal/svm/tests/test_nusvc.py +204 -0
  141. onedal/svm/tests/test_nusvr.py +210 -0
  142. onedal/svm/tests/test_svc.py +168 -0
  143. onedal/svm/tests/test_svr.py +243 -0
  144. onedal/tests/test_common.py +41 -0
  145. onedal/tests/utils/_dataframes_support.py +168 -0
  146. onedal/tests/utils/_device_selection.py +107 -0
  147. onedal/utils/__init__.py +49 -0
  148. onedal/utils/_array_api.py +91 -0
  149. onedal/utils/validation.py +432 -0
  150. scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
  151. scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
  152. scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
  153. scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
  154. scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
  155. sklearnex/__init__.py +65 -0
  156. sklearnex/__main__.py +58 -0
  157. sklearnex/_config.py +98 -0
  158. sklearnex/_device_offload.py +121 -0
  159. sklearnex/_utils.py +109 -0
  160. sklearnex/basic_statistics/__init__.py +20 -0
  161. sklearnex/basic_statistics/basic_statistics.py +140 -0
  162. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  163. sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
  164. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
  165. sklearnex/cluster/__init__.py +20 -0
  166. sklearnex/cluster/dbscan.py +192 -0
  167. sklearnex/cluster/k_means.py +383 -0
  168. sklearnex/cluster/tests/test_dbscan.py +38 -0
  169. sklearnex/cluster/tests/test_kmeans.py +153 -0
  170. sklearnex/conftest.py +73 -0
  171. sklearnex/covariance/__init__.py +19 -0
  172. sklearnex/covariance/incremental_covariance.py +368 -0
  173. sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
  174. sklearnex/decomposition/__init__.py +19 -0
  175. sklearnex/decomposition/pca.py +414 -0
  176. sklearnex/decomposition/tests/test_pca.py +58 -0
  177. sklearnex/dispatcher.py +543 -0
  178. sklearnex/doc/third-party-programs.txt +424 -0
  179. sklearnex/ensemble/__init__.py +29 -0
  180. sklearnex/ensemble/_forest.py +2016 -0
  181. sklearnex/ensemble/tests/test_forest.py +120 -0
  182. sklearnex/glob/__main__.py +72 -0
  183. sklearnex/glob/dispatcher.py +101 -0
  184. sklearnex/linear_model/__init__.py +32 -0
  185. sklearnex/linear_model/coordinate_descent.py +30 -0
  186. sklearnex/linear_model/incremental_linear.py +463 -0
  187. sklearnex/linear_model/incremental_ridge.py +418 -0
  188. sklearnex/linear_model/linear.py +302 -0
  189. sklearnex/linear_model/logistic_path.py +17 -0
  190. sklearnex/linear_model/logistic_regression.py +403 -0
  191. sklearnex/linear_model/ridge.py +24 -0
  192. sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
  193. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  194. sklearnex/linear_model/tests/test_linear.py +142 -0
  195. sklearnex/linear_model/tests/test_logreg.py +134 -0
  196. sklearnex/manifold/__init__.py +19 -0
  197. sklearnex/manifold/t_sne.py +21 -0
  198. sklearnex/manifold/tests/test_tsne.py +26 -0
  199. sklearnex/metrics/__init__.py +23 -0
  200. sklearnex/metrics/pairwise.py +22 -0
  201. sklearnex/metrics/ranking.py +20 -0
  202. sklearnex/metrics/tests/test_metrics.py +39 -0
  203. sklearnex/model_selection/__init__.py +21 -0
  204. sklearnex/model_selection/split.py +22 -0
  205. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  206. sklearnex/neighbors/__init__.py +27 -0
  207. sklearnex/neighbors/_lof.py +231 -0
  208. sklearnex/neighbors/common.py +310 -0
  209. sklearnex/neighbors/knn_classification.py +226 -0
  210. sklearnex/neighbors/knn_regression.py +203 -0
  211. sklearnex/neighbors/knn_unsupervised.py +170 -0
  212. sklearnex/neighbors/tests/test_neighbors.py +80 -0
  213. sklearnex/preview/__init__.py +17 -0
  214. sklearnex/preview/covariance/__init__.py +19 -0
  215. sklearnex/preview/covariance/covariance.py +133 -0
  216. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  217. sklearnex/preview/decomposition/__init__.py +19 -0
  218. sklearnex/preview/decomposition/incremental_pca.py +228 -0
  219. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  220. sklearnex/preview/linear_model/__init__.py +19 -0
  221. sklearnex/preview/linear_model/ridge.py +419 -0
  222. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  223. sklearnex/spmd/__init__.py +25 -0
  224. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  225. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  226. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  227. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  228. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  229. sklearnex/spmd/cluster/__init__.py +30 -0
  230. sklearnex/spmd/cluster/dbscan.py +50 -0
  231. sklearnex/spmd/cluster/kmeans.py +21 -0
  232. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  233. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  234. sklearnex/spmd/covariance/__init__.py +20 -0
  235. sklearnex/spmd/covariance/covariance.py +21 -0
  236. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  237. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  238. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  239. sklearnex/spmd/decomposition/__init__.py +20 -0
  240. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  241. sklearnex/spmd/decomposition/pca.py +21 -0
  242. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  243. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  244. sklearnex/spmd/ensemble/__init__.py +19 -0
  245. sklearnex/spmd/ensemble/forest.py +71 -0
  246. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  247. sklearnex/spmd/linear_model/__init__.py +21 -0
  248. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  249. sklearnex/spmd/linear_model/linear_model.py +21 -0
  250. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  251. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  252. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  253. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
  254. sklearnex/spmd/neighbors/__init__.py +19 -0
  255. sklearnex/spmd/neighbors/neighbors.py +25 -0
  256. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  257. sklearnex/svm/__init__.py +29 -0
  258. sklearnex/svm/_common.py +328 -0
  259. sklearnex/svm/nusvc.py +332 -0
  260. sklearnex/svm/nusvr.py +148 -0
  261. sklearnex/svm/svc.py +360 -0
  262. sklearnex/svm/svr.py +149 -0
  263. sklearnex/svm/tests/test_svm.py +93 -0
  264. sklearnex/tests/_utils.py +328 -0
  265. sklearnex/tests/_utils_spmd.py +198 -0
  266. sklearnex/tests/test_common.py +54 -0
  267. sklearnex/tests/test_config.py +43 -0
  268. sklearnex/tests/test_memory_usage.py +291 -0
  269. sklearnex/tests/test_monkeypatch.py +276 -0
  270. sklearnex/tests/test_n_jobs_support.py +103 -0
  271. sklearnex/tests/test_parallel.py +48 -0
  272. sklearnex/tests/test_patching.py +385 -0
  273. sklearnex/tests/test_run_to_run_stability.py +296 -0
  274. sklearnex/utils/__init__.py +19 -0
  275. sklearnex/utils/_array_api.py +82 -0
  276. sklearnex/utils/parallel.py +59 -0
  277. sklearnex/utils/tests/test_finite.py +89 -0
  278. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,153 @@
1
+ # ===============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from daal4py.sklearn._utils import daal_check_version
18
+
19
+ if daal_check_version((2024, "P", 600)):
20
+ import numpy as np
21
+ import pytest
22
+ from numpy.testing import assert_allclose
23
+ from sklearn.exceptions import NotFittedError
24
+
25
+ from onedal.tests.utils._dataframes_support import (
26
+ _as_numpy,
27
+ _convert_to_dataframe,
28
+ get_dataframes_and_queues,
29
+ )
30
+ from sklearnex.linear_model import IncrementalRidge
31
+
32
def _compute_ridge_coefficients(X, y, alpha, fit_intercept):
    """Return a closed-form ridge regression reference solution.

    Solves ``(Xw.T @ Xw + alpha * I) coef = Xw.T @ yw``. ``Xw``/``yw`` are the
    column-centered data when ``fit_intercept`` is true, and the raw data
    otherwise.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
    y : ndarray of shape (n_samples,) or (n_samples, n_targets)
    alpha : float
        Regularization strength added to the diagonal of the Gram matrix.
    fit_intercept : bool
        Whether to center the data and recover an intercept from the means.

    Returns
    -------
    coefficients : ndarray of shape (n_features,) or (n_features, n_targets)
    intercept : scalar, ndarray of shape (n_targets,), or None
        ``None`` when ``fit_intercept`` is false.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    if fit_intercept:
        X_mean = np.mean(X, axis=0)
        # axis=0 keeps one mean per target for 2-D y; identical to np.mean(y) for 1-D y.
        y_mean = np.mean(y, axis=0)
        X_work = X - X_mean
        y_work = y - y_mean
    else:
        X_work, y_work = X, y

    gram = X_work.T @ X_work + alpha * np.eye(X_work.shape[1])
    # Solving the linear system is cheaper and numerically more stable than
    # forming the explicit inverse of the regularized Gram matrix.
    coefficients = np.linalg.solve(gram, X_work.T @ y_work)

    if not fit_intercept:
        return coefficients, None

    # With centered X and y the intercept is decoupled from the penalized
    # weights, so it is recovered from the means rather than solved for.
    intercept = y_mean - X_mean @ coefficients
    return coefficients, intercept
57
+
58
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("batch_size", [10, 100, 1000])
@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0])
@pytest.mark.parametrize("fit_intercept", [True, False])
def test_inc_ridge_fit_coefficients(
    dataframe, queue, alpha, batch_size, fit_intercept
):
    """``IncrementalRidge.fit`` with internal batching must match the closed form.

    ``batch_size`` is forwarded to the estimator constructor; the reference
    solution comes from ``_compute_ridge_coefficients`` on the full data.
    """
    sample_size, feature_size = 1000, 50
    # Seeded local generator: failures are reproducible and the global NumPy
    # RNG state shared with other tests is left untouched.
    rng = np.random.default_rng(42)
    X = rng.random((sample_size, feature_size))
    y = rng.random(sample_size)
    X_c = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
    y_c = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)

    inc_ridge = IncrementalRidge(
        fit_intercept=fit_intercept, alpha=alpha, batch_size=batch_size
    )
    inc_ridge.fit(X_c, y_c)

    coefficients_manual, intercept_manual = _compute_ridge_coefficients(
        X, y, alpha, fit_intercept
    )
    if fit_intercept:
        assert_allclose(
            _as_numpy(inc_ridge.intercept_), intercept_manual, rtol=1e-6, atol=1e-6
        )

    assert_allclose(
        _as_numpy(inc_ridge.coef_), coefficients_manual, rtol=1e-6, atol=1e-6
    )
83
+
84
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("batch_size", [2, 5])
@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0])
def test_inc_ridge_partial_fit_coefficients(dataframe, queue, alpha, batch_size):
    """Successive ``partial_fit`` calls must match the closed-form solution.

    Note: ``batch_size`` is used as the *number of chunks* passed to
    ``np.array_split``, not as the number of rows per chunk.
    """
    sample_size, feature_size = 1000, 50
    # Seeded local generator: failures are reproducible and the global NumPy
    # RNG state shared with other tests is left untouched.
    rng = np.random.default_rng(42)
    X = rng.random((sample_size, feature_size))
    y = rng.random(sample_size)

    inc_ridge = IncrementalRidge(fit_intercept=False, alpha=alpha)

    for X_chunk, y_chunk in zip(
        np.array_split(X, batch_size), np.array_split(y, batch_size)
    ):
        X_c = _convert_to_dataframe(X_chunk, sycl_queue=queue, target_df=dataframe)
        y_c = _convert_to_dataframe(y_chunk, sycl_queue=queue, target_df=dataframe)
        inc_ridge.partial_fit(X_c, y_c)

    # Reuse the shared reference implementation instead of re-deriving it here.
    coefficients_manual, _ = _compute_ridge_coefficients(
        X, y, alpha, fit_intercept=False
    )

    assert_allclose(
        _as_numpy(inc_ridge.coef_), coefficients_manual, rtol=1e-6, atol=1e-6
    )
111
+
112
def test_inc_ridge_score_before_fit():
    """Calling ``score`` on an unfitted ``IncrementalRidge`` raises ``NotFittedError``."""
    features = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    target = features @ np.array([1, 2]) + 3
    estimator = IncrementalRidge(alpha=0.5)
    with pytest.raises(NotFittedError):
        estimator.score(features, target)
118
+
119
def test_inc_ridge_predict_before_fit():
    """Calling ``predict`` on an unfitted ``IncrementalRidge`` raises ``NotFittedError``."""
    estimator = IncrementalRidge(alpha=0.5)
    features = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    with pytest.raises(NotFittedError):
        estimator.predict(features)
124
+
125
def test_inc_ridge_score_after_fit():
    """After fitting on y = x0 + 2*x1 + 3, ``score`` on the training data is >= 0.97."""
    features = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    target = features @ np.array([1, 2]) + 3

    fitted = IncrementalRidge(alpha=0.5)
    fitted.fit(features, target)

    assert fitted.score(features, target) >= 0.97
131
+
132
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("fit_intercept", [True, False])
def test_inc_ridge_predict_after_fit(dataframe, queue, fit_intercept):
    """``predict`` after ``fit`` must equal ``X @ coef (+ intercept)`` from the closed form."""
    alpha = 0.5
    sample_size, feature_size = 1000, 50
    # Seeded local generator: failures are reproducible and the global NumPy
    # RNG state shared with other tests is left untouched.
    rng = np.random.default_rng(42)
    X = rng.random((sample_size, feature_size))
    y = rng.random(sample_size)
    X_c = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
    y_c = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)

    inc_ridge = IncrementalRidge(fit_intercept=fit_intercept, alpha=alpha)
    inc_ridge.fit(X_c, y_c)

    y_pred = inc_ridge.predict(X_c)

    coefficients_manual, intercept_manual = _compute_ridge_coefficients(
        X, y, alpha, fit_intercept
    )
    y_pred_manual = X @ coefficients_manual
    if fit_intercept:
        y_pred_manual = y_pred_manual + intercept_manual

    assert_allclose(_as_numpy(y_pred), y_pred_manual, rtol=1e-6, atol=1e-6)
@@ -0,0 +1,142 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose
20
+ from sklearn.datasets import make_regression
21
+
22
+ from daal4py.sklearn._utils import daal_check_version
23
+ from daal4py.sklearn.linear_model.tests.test_ridge import (
24
+ _test_multivariate_ridge_alpha_shape,
25
+ _test_multivariate_ridge_coefficients,
26
+ )
27
+ from onedal.tests.utils._dataframes_support import (
28
+ _as_numpy,
29
+ _convert_to_dataframe,
30
+ get_dataframes_and_queues,
31
+ )
32
+
33
+
34
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("macro_block", [None, 1024])
def test_sklearnex_import_linear(dataframe, queue, dtype, macro_block):
    """sklearnex ``LinearRegression`` recovers y = x0 + 2*x1 + 3 on a tiny dataset.

    When ``macro_block`` is set (and the oneDAL version check passes), the
    ``cpu_macro_block``/``gpu_macro_block`` fit hyperparameters are overridden
    before fitting. Tolerance is looser for float32 coefficients.
    """
    from sklearnex.linear_model import LinearRegression

    X_int = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    y_int = np.dot(X_int, np.array([1, 2])) + 3
    X = _convert_to_dataframe(
        X_int.astype(dtype=dtype), sycl_queue=queue, target_df=dataframe
    )
    y = _convert_to_dataframe(
        y_int.astype(dtype=dtype), sycl_queue=queue, target_df=dataframe
    )

    linreg = LinearRegression()
    if daal_check_version((2024, "P", 0)) and macro_block is not None:
        fit_hparams = linreg.get_hyperparameters("fit")
        fit_hparams.cpu_macro_block = macro_block
        fit_hparams.gpu_macro_block = macro_block

    linreg.fit(X, y)

    assert hasattr(linreg, "_onedal_estimator")
    assert "sklearnex" in linreg.__module__
    assert linreg.n_features_in_ == 2

    coef = _as_numpy(linreg.coef_)
    intercept = _as_numpy(linreg.intercept_)
    tol = 1e-5 if coef.dtype == np.float32 else 1e-7
    assert_allclose(intercept, 3.0, rtol=tol)
    assert_allclose(coef, [1.0, 2.0], rtol=tol)
62
+
63
+
64
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_ridge(dataframe, queue):
    """sklearnex ``Ridge`` (default parameters) is served by daal4py and matches the reference fit."""
    from sklearnex.linear_model import Ridge

    features = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
    target = np.dot(features, np.array([1, 2])) + 3
    features_df = _convert_to_dataframe(features, sycl_queue=queue, target_df=dataframe)
    target_df = _convert_to_dataframe(target, sycl_queue=queue, target_df=dataframe)

    model = Ridge().fit(features_df, target_df)

    assert "daal4py" in model.__module__
    assert_allclose(model.intercept_, 4.5)
    assert_allclose(model.coef_, [0.8, 1.4])
76
+
77
+
78
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_lasso(dataframe, queue):
    """sklearnex ``Lasso(alpha=0.1)`` is served by daal4py and matches the reference fit."""
    from sklearnex.linear_model import Lasso

    raw_X = [[0, 0], [1, 1], [2, 2]]
    raw_y = [0, 1, 2]
    X_df = _convert_to_dataframe(raw_X, sycl_queue=queue, target_df=dataframe)
    y_df = _convert_to_dataframe(raw_y, sycl_queue=queue, target_df=dataframe)

    fitted = Lasso(alpha=0.1).fit(X_df, y_df)

    assert "daal4py" in fitted.__module__
    assert_allclose(fitted.intercept_, 0.15)
    assert_allclose(fitted.coef_, [0.85, 0.0])
90
+
91
+
92
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_elastic(dataframe, queue):
    """sklearnex ``ElasticNet`` is served by daal4py and matches the reference fit (atol 1e-3)."""
    from sklearnex.linear_model import ElasticNet

    X_np, y_np = make_regression(n_features=2, random_state=0)
    X_df = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe)
    y_df = _convert_to_dataframe(y_np, sycl_queue=queue, target_df=dataframe)

    model = ElasticNet(random_state=0).fit(X_df, y_df)

    assert "daal4py" in model.__module__
    assert_allclose(model.intercept_, 1.451, atol=1e-3)
    assert_allclose(model.coef_, [18.838, 64.559], atol=1e-3)
103
+
104
+
105
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_sklearnex_reconstruct_model(dataframe, queue, dtype):
    """A ``LinearRegression`` whose ``coef_``/``intercept_`` are assigned by hand
    (without calling ``fit``) must predict ``X @ coef.T + intercept``.
    """
    from sklearnex.linear_model import LinearRegression

    n_samples, n_features, n_targets = 3500, 14, 9
    rng = np.random.default_rng(42)

    # Draw order (intercept, coefficients, X) fixes the generated values.
    true_intercept = rng.random(size=n_targets, dtype=dtype)
    true_coef = rng.random(size=(n_targets, n_features), dtype=dtype)
    features = rng.random(size=(n_samples, n_features), dtype=dtype)
    expected = features @ true_coef.T + true_intercept[np.newaxis, :]

    features_df = _convert_to_dataframe(features, sycl_queue=queue, target_df=dataframe)

    model = LinearRegression(fit_intercept=True)
    model.coef_ = true_coef
    model.intercept_ = true_intercept

    predicted = _as_numpy(model.predict(features_df))

    tol = 1e-5 if predicted.dtype == np.float32 else 1e-7
    assert_allclose(expected, predicted, rtol=tol)
131
+
132
+
133
def test_sklearnex_multivariate_ridge_coefs():
    """Run the shared daal4py multi-target coefficient check against sklearnex ``Ridge``."""
    from sklearnex.linear_model import Ridge as estimator_cls

    _test_multivariate_ridge_coefficients(estimator_cls, random_state=0)
137
+
138
+
139
def test_sklearnex_multivariate_ridge_alpha_shape():
    """Run the shared daal4py multi-target alpha-shape check against sklearnex ``Ridge``."""
    from sklearnex.linear_model import Ridge as estimator_cls

    _test_multivariate_ridge_alpha_shape(estimator_cls, random_state=0)
@@ -0,0 +1,134 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose, assert_array_equal
20
+ from scipy.sparse import csr_matrix
21
+ from sklearn.datasets import load_breast_cancer, load_iris, make_classification
22
+ from sklearn.metrics import accuracy_score
23
+ from sklearn.model_selection import train_test_split
24
+
25
+ from daal4py.sklearn._utils import daal_check_version
26
+ from onedal.tests.utils._dataframes_support import (
27
+ _as_numpy,
28
+ _convert_to_dataframe,
29
+ get_dataframes_and_queues,
30
+ get_queues,
31
+ )
32
+ from sklearnex import config_context
33
+
34
+
35
def prepare_input(X, y, dataframe, queue):
    """Split ``(X, y)`` 80/20 (``random_state=42``) and convert to ``dataframe``.

    ``X_train``, ``y_train`` and ``X_test`` are converted for ``queue``;
    ``y_test`` is returned exactly as ``train_test_split`` produced it.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.8, random_state=42
    )

    def convert(array):
        return _convert_to_dataframe(array, sycl_queue=queue, target_df=dataframe)

    converted_X_train = convert(X_train)
    converted_y_train = convert(y_train)
    converted_X_test = convert(X_test)
    return converted_X_train, converted_X_test, converted_y_train, y_test
43
+
44
+
45
@pytest.mark.parametrize(
    "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
)
def test_sklearnex_multiclass_classification(dataframe, queue):
    """Iris with lbfgs ``LogisticRegression`` (CPU only): right backend module, accuracy > 0.99."""
    from sklearnex.linear_model import LogisticRegression

    X_train, X_test, y_train, y_test = prepare_input(
        *load_iris(return_X_y=True), dataframe, queue
    )

    logreg = LogisticRegression(
        fit_intercept=True, solver="lbfgs", max_iter=200
    ).fit(X_train, y_train)

    # Newer oneDAL versions route the estimator through sklearnex, older ones through daal4py.
    expected_module = "sklearnex" if daal_check_version((2024, "P", 1)) else "daal4py"
    assert expected_module in logreg.__module__

    y_pred = _as_numpy(logreg.predict(X_test))
    assert accuracy_score(y_test, y_pred) > 0.99
65
+
66
+
67
@pytest.mark.parametrize(
    "dataframe,queue",
    get_dataframes_and_queues(),
)
def test_sklearnex_binary_classification(dataframe, queue):
    """Breast-cancer with newton-cg ``LogisticRegression``: backend module, GPU path, accuracy > 0.95."""
    from sklearnex.linear_model import LogisticRegression

    X_train, X_test, y_train, y_test = prepare_input(
        *load_breast_cancer(return_X_y=True), dataframe, queue
    )

    logreg = LogisticRegression(
        fit_intercept=True, solver="newton-cg", max_iter=100
    ).fit(X_train, y_train)

    newer_onedal = daal_check_version((2024, "P", 1))
    assert ("sklearnex" if newer_onedal else "daal4py") in logreg.__module__

    on_gpu_device = (
        dataframe != "numpy"
        and queue is not None
        and queue.sycl_device.is_gpu
        and newer_onedal
    )
    if on_gpu_device:
        # fit is expected to have run on the GPU, leaving a oneDAL estimator behind
        assert hasattr(logreg, "_onedal_estimator")

    y_pred = _as_numpy(logreg.predict(X_test))
    assert accuracy_score(y_test, y_pred) > 0.95
96
+
97
+
98
if daal_check_version((2024, "P", 700)):

    @pytest.mark.parametrize("queue", get_queues("gpu"))
    @pytest.mark.parametrize("dtype", [np.float32, np.float64])
    @pytest.mark.parametrize(
        "dims", [(3007, 17, 0.05), (50000, 100, 0.01), (512, 10, 0.5)]
    )
    def test_csr(queue, dtype, dims):
        """Dense and CSR inputs must give matching ``LogisticRegression`` results on GPU.

        ``dims`` is ``(n_samples, n_features, density)``. Each entry of ``X`` is
        kept with probability ``density`` and zeroed otherwise. ``X_sp`` is the
        CSR form of the same matrix.
        """
        from sklearnex.linear_model import LogisticRegression

        n, p, density = dims

        # Create sparse dataset for classification
        X, y = make_classification(n, p, random_state=42)
        X = X.astype(dtype)
        y = y.astype(dtype)
        # Local generator: np.random.seed would clobber the global RNG state
        # seen by every other test in the session.
        rng = np.random.default_rng(2007 + n + p)
        # Cast the 0/1 mask to ``dtype``. Multiplying by the default int64 mask
        # would silently promote float32 X to float64.
        mask = rng.binomial(1, density, (n, p)).astype(dtype)
        X = X * mask
        X_sp = csr_matrix(X)

        model = LogisticRegression(fit_intercept=True, solver="newton-cg")
        model_sp = LogisticRegression(fit_intercept=True, solver="newton-cg")

        with config_context(target_offload="gpu:0"):
            model.fit(X, y)
            pred = model.predict(X)
            prob = model.predict_proba(X)
            model_sp.fit(X_sp, y)
            pred_sp = model_sp.predict(X_sp)
            prob_sp = model_sp.predict_proba(X_sp)

        rtol = 2e-4
        assert_allclose(pred, pred_sp, rtol=rtol)
        assert_allclose(prob, prob_sp, rtol=rtol)
        assert_allclose(model.coef_, model_sp.coef_, rtol=rtol)
        assert_allclose(model.intercept_, model_sp.intercept_, rtol=rtol)
@@ -0,0 +1,19 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from .t_sne import TSNE
18
+
19
# Public API of this subpackage.
__all__ = [
    "TSNE",
]
@@ -0,0 +1,21 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from daal4py.sklearn.manifold import TSNE
18
+ from onedal._device_offload import support_input_format
19
+
20
# Wrap daal4py's TSNE ``fit`` and ``fit_transform`` with
# ``support_input_format(queue_param=False)``. The assignment is made on the
# class object imported from ``daal4py.sklearn.manifold``, so the wrapped
# methods replace the originals on that class itself.
for _method_name in ("fit", "fit_transform"):
    setattr(
        TSNE,
        _method_name,
        support_input_format(queue_param=False)(getattr(TSNE, _method_name)),
    )
del _method_name
@@ -0,0 +1,26 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ from numpy.testing import assert_allclose
19
+
20
+
21
def test_sklearnex_import():
    """Fit ``sklearnex.manifold.TSNE`` on a tiny dataset and check where it comes from.

    The fitted estimator's class must live in a ``daal4py`` module.
    """
    from sklearnex.manifold import TSNE

    samples = np.array(
        [
            [0, 0, 0],
            [0, 1, 1],
            [1, 0, 1],
            [1, 1, 1],
        ]
    )
    estimator = TSNE(n_components=2, perplexity=2.0)
    fitted = estimator.fit(samples)
    assert "daal4py" in fitted.__module__
@@ -0,0 +1,23 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from .pairwise import pairwise_distances
18
+ from .ranking import roc_auc_score
19
+
20
# Public API of this subpackage (order preserved from the original list).
__all__ = ["roc_auc_score", "pairwise_distances"]
@@ -0,0 +1,22 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from daal4py.sklearn.metrics import pairwise_distances
18
+ from onedal._device_offload import support_input_format
19
+
20
# Re-export daal4py's ``pairwise_distances`` wrapped by ``support_input_format``
# (called with ``freefunc=True, queue_param=False``).
_input_format_wrapper = support_input_format(freefunc=True, queue_param=False)
pairwise_distances = _input_format_wrapper(pairwise_distances)
del _input_format_wrapper
@@ -0,0 +1,20 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from daal4py.sklearn.metrics import roc_auc_score
18
+ from onedal._device_offload import support_input_format
19
+
20
# Re-export daal4py's ``roc_auc_score`` wrapped by ``support_input_format``
# (called with ``freefunc=True, queue_param=False``).
_input_format_wrapper = support_input_format(freefunc=True, queue_param=False)
roc_auc_score = _input_format_wrapper(roc_auc_score)
del _input_format_wrapper
@@ -0,0 +1,39 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ from numpy.testing import assert_allclose
19
+ from sklearn.datasets import load_breast_cancer
20
+
21
+
22
def test_sklearnex_import_roc_auc():
    """Check that sklearnex's ``roc_auc_score`` gives ~0.99 on breast cancer data.

    A liblinear ``LogisticRegression`` is fitted on the full dataset and scored
    with its own decision function.
    """
    from sklearnex.linear_model import LogisticRegression
    from sklearnex.metrics import roc_auc_score

    features, labels = load_breast_cancer(return_X_y=True)
    classifier = LogisticRegression(solver="liblinear", random_state=0)
    classifier.fit(features, labels)
    scores = classifier.decision_function(features)
    auc = roc_auc_score(labels, scores)
    assert_allclose(auc, 0.99, atol=1e-2)
30
+
31
+
32
def test_sklearnex_import_pairwise_distances():
    """Check that cosine distances between two identical rows are ~0."""
    from sklearnex.metrics import pairwise_distances

    rng = np.random.RandomState(0)
    row = np.abs(rng.rand(4), dtype=np.float64)
    data = np.vstack((row, row))
    distances = pairwise_distances(data, metric="cosine")
    expected = [[0.0, 0.0], [0.0, 0.0]]
    assert_allclose(distances, expected, atol=1e-2)
@@ -0,0 +1,21 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from .split import train_test_split
18
+
19
# Public API of this subpackage.
__all__ = ["train_test_split"]
@@ -0,0 +1,22 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from daal4py.sklearn.model_selection import train_test_split
18
+ from onedal._device_offload import support_input_format
19
+
20
# Re-export daal4py's ``train_test_split`` wrapped by ``support_input_format``
# (called with ``freefunc=True, queue_param=False``).
_input_format_wrapper = support_input_format(freefunc=True, queue_param=False)
train_test_split = _input_format_wrapper(train_test_split)
del _input_format_wrapper