scikit-learn-intelex 2025.1.0__py310-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the registry's release page for more details.

Files changed (280):
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-310-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-310-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +222 -0
  62. onedal/_onedal_py_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-310-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +564 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +125 -0
  83. onedal/common/tests/test_policy.py +76 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +154 -0
  91. onedal/datatypes/tests/common.py +126 -0
  92. onedal/datatypes/tests/test_data.py +414 -0
  93. onedal/decomposition/__init__.py +20 -0
  94. onedal/decomposition/incremental_pca.py +204 -0
  95. onedal/decomposition/pca.py +186 -0
  96. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  97. onedal/ensemble/__init__.py +29 -0
  98. onedal/ensemble/forest.py +727 -0
  99. onedal/ensemble/tests/test_random_forest.py +97 -0
  100. onedal/linear_model/__init__.py +27 -0
  101. onedal/linear_model/incremental_linear_model.py +258 -0
  102. onedal/linear_model/linear_model.py +329 -0
  103. onedal/linear_model/logistic_regression.py +249 -0
  104. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  105. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  106. onedal/linear_model/tests/test_linear_regression.py +250 -0
  107. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  108. onedal/linear_model/tests/test_ridge.py +95 -0
  109. onedal/neighbors/__init__.py +19 -0
  110. onedal/neighbors/neighbors.py +767 -0
  111. onedal/neighbors/tests/test_knn_classification.py +49 -0
  112. onedal/primitives/__init__.py +27 -0
  113. onedal/primitives/get_tree.py +25 -0
  114. onedal/primitives/kernel_functions.py +153 -0
  115. onedal/primitives/tests/test_kernel_functions.py +159 -0
  116. onedal/spmd/__init__.py +25 -0
  117. onedal/spmd/_base.py +30 -0
  118. onedal/spmd/basic_statistics/__init__.py +20 -0
  119. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  120. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  121. onedal/spmd/cluster/__init__.py +28 -0
  122. onedal/spmd/cluster/dbscan.py +23 -0
  123. onedal/spmd/cluster/kmeans.py +56 -0
  124. onedal/spmd/covariance/__init__.py +20 -0
  125. onedal/spmd/covariance/covariance.py +26 -0
  126. onedal/spmd/covariance/incremental_covariance.py +82 -0
  127. onedal/spmd/decomposition/__init__.py +20 -0
  128. onedal/spmd/decomposition/incremental_pca.py +117 -0
  129. onedal/spmd/decomposition/pca.py +26 -0
  130. onedal/spmd/ensemble/__init__.py +19 -0
  131. onedal/spmd/ensemble/forest.py +28 -0
  132. onedal/spmd/linear_model/__init__.py +21 -0
  133. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  134. onedal/spmd/linear_model/linear_model.py +30 -0
  135. onedal/spmd/linear_model/logistic_regression.py +38 -0
  136. onedal/spmd/neighbors/__init__.py +19 -0
  137. onedal/spmd/neighbors/neighbors.py +75 -0
  138. onedal/svm/__init__.py +19 -0
  139. onedal/svm/svm.py +556 -0
  140. onedal/svm/tests/test_csr_svm.py +351 -0
  141. onedal/svm/tests/test_nusvc.py +204 -0
  142. onedal/svm/tests/test_nusvr.py +210 -0
  143. onedal/svm/tests/test_svc.py +176 -0
  144. onedal/svm/tests/test_svr.py +243 -0
  145. onedal/tests/test_common.py +57 -0
  146. onedal/tests/utils/_dataframes_support.py +162 -0
  147. onedal/tests/utils/_device_selection.py +102 -0
  148. onedal/utils/__init__.py +49 -0
  149. onedal/utils/_array_api.py +81 -0
  150. onedal/utils/_dpep_helpers.py +56 -0
  151. onedal/utils/validation.py +440 -0
  152. scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
  153. scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
  154. scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
  155. scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
  156. scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
  157. sklearnex/__init__.py +66 -0
  158. sklearnex/__main__.py +58 -0
  159. sklearnex/_config.py +116 -0
  160. sklearnex/_device_offload.py +126 -0
  161. sklearnex/_utils.py +132 -0
  162. sklearnex/basic_statistics/__init__.py +20 -0
  163. sklearnex/basic_statistics/basic_statistics.py +230 -0
  164. sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  165. sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  166. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  167. sklearnex/cluster/__init__.py +20 -0
  168. sklearnex/cluster/dbscan.py +197 -0
  169. sklearnex/cluster/k_means.py +395 -0
  170. sklearnex/cluster/tests/test_dbscan.py +38 -0
  171. sklearnex/cluster/tests/test_kmeans.py +159 -0
  172. sklearnex/conftest.py +82 -0
  173. sklearnex/covariance/__init__.py +19 -0
  174. sklearnex/covariance/incremental_covariance.py +398 -0
  175. sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  176. sklearnex/decomposition/__init__.py +19 -0
  177. sklearnex/decomposition/pca.py +425 -0
  178. sklearnex/decomposition/tests/test_pca.py +58 -0
  179. sklearnex/dispatcher.py +543 -0
  180. sklearnex/doc/third-party-programs.txt +424 -0
  181. sklearnex/ensemble/__init__.py +29 -0
  182. sklearnex/ensemble/_forest.py +2029 -0
  183. sklearnex/ensemble/tests/test_forest.py +135 -0
  184. sklearnex/glob/__main__.py +72 -0
  185. sklearnex/glob/dispatcher.py +101 -0
  186. sklearnex/linear_model/__init__.py +32 -0
  187. sklearnex/linear_model/coordinate_descent.py +30 -0
  188. sklearnex/linear_model/incremental_linear.py +482 -0
  189. sklearnex/linear_model/incremental_ridge.py +425 -0
  190. sklearnex/linear_model/linear.py +341 -0
  191. sklearnex/linear_model/logistic_regression.py +413 -0
  192. sklearnex/linear_model/ridge.py +24 -0
  193. sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  194. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  195. sklearnex/linear_model/tests/test_linear.py +167 -0
  196. sklearnex/linear_model/tests/test_logreg.py +134 -0
  197. sklearnex/manifold/__init__.py +19 -0
  198. sklearnex/manifold/t_sne.py +21 -0
  199. sklearnex/manifold/tests/test_tsne.py +26 -0
  200. sklearnex/metrics/__init__.py +23 -0
  201. sklearnex/metrics/pairwise.py +22 -0
  202. sklearnex/metrics/ranking.py +20 -0
  203. sklearnex/metrics/tests/test_metrics.py +39 -0
  204. sklearnex/model_selection/__init__.py +21 -0
  205. sklearnex/model_selection/split.py +22 -0
  206. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  207. sklearnex/neighbors/__init__.py +27 -0
  208. sklearnex/neighbors/_lof.py +236 -0
  209. sklearnex/neighbors/common.py +310 -0
  210. sklearnex/neighbors/knn_classification.py +231 -0
  211. sklearnex/neighbors/knn_regression.py +207 -0
  212. sklearnex/neighbors/knn_unsupervised.py +178 -0
  213. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  214. sklearnex/preview/__init__.py +17 -0
  215. sklearnex/preview/covariance/__init__.py +19 -0
  216. sklearnex/preview/covariance/covariance.py +138 -0
  217. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  218. sklearnex/preview/decomposition/__init__.py +19 -0
  219. sklearnex/preview/decomposition/incremental_pca.py +233 -0
  220. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  221. sklearnex/preview/linear_model/__init__.py +19 -0
  222. sklearnex/preview/linear_model/ridge.py +424 -0
  223. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  224. sklearnex/spmd/__init__.py +25 -0
  225. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  226. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  227. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  228. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  229. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  230. sklearnex/spmd/cluster/__init__.py +30 -0
  231. sklearnex/spmd/cluster/dbscan.py +50 -0
  232. sklearnex/spmd/cluster/kmeans.py +21 -0
  233. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  234. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  235. sklearnex/spmd/covariance/__init__.py +20 -0
  236. sklearnex/spmd/covariance/covariance.py +21 -0
  237. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  238. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  239. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  240. sklearnex/spmd/decomposition/__init__.py +20 -0
  241. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  242. sklearnex/spmd/decomposition/pca.py +21 -0
  243. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  244. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  245. sklearnex/spmd/ensemble/__init__.py +19 -0
  246. sklearnex/spmd/ensemble/forest.py +71 -0
  247. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  248. sklearnex/spmd/linear_model/__init__.py +21 -0
  249. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  250. sklearnex/spmd/linear_model/linear_model.py +21 -0
  251. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  252. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  253. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  254. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  255. sklearnex/spmd/neighbors/__init__.py +19 -0
  256. sklearnex/spmd/neighbors/neighbors.py +25 -0
  257. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  258. sklearnex/svm/__init__.py +29 -0
  259. sklearnex/svm/_common.py +339 -0
  260. sklearnex/svm/nusvc.py +371 -0
  261. sklearnex/svm/nusvr.py +170 -0
  262. sklearnex/svm/svc.py +399 -0
  263. sklearnex/svm/svr.py +167 -0
  264. sklearnex/svm/tests/test_svm.py +93 -0
  265. sklearnex/tests/test_common.py +390 -0
  266. sklearnex/tests/test_config.py +123 -0
  267. sklearnex/tests/test_memory_usage.py +379 -0
  268. sklearnex/tests/test_monkeypatch.py +276 -0
  269. sklearnex/tests/test_n_jobs_support.py +108 -0
  270. sklearnex/tests/test_parallel.py +48 -0
  271. sklearnex/tests/test_patching.py +385 -0
  272. sklearnex/tests/test_run_to_run_stability.py +321 -0
  273. sklearnex/tests/utils/__init__.py +44 -0
  274. sklearnex/tests/utils/base.py +371 -0
  275. sklearnex/tests/utils/spmd.py +198 -0
  276. sklearnex/utils/__init__.py +19 -0
  277. sklearnex/utils/_array_api.py +82 -0
  278. sklearnex/utils/parallel.py +59 -0
  279. sklearnex/utils/tests/test_finite.py +89 -0
  280. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,385 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+
18
+ import importlib
19
+ import inspect
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ from inspect import signature
25
+
26
+ import numpy as np
27
+ import numpy.random as nprnd
28
+ import pytest
29
+ from sklearn.base import BaseEstimator
30
+
31
+ from daal4py.sklearn._utils import sklearn_check_version
32
+ from onedal.tests.utils._dataframes_support import (
33
+ _convert_to_dataframe,
34
+ get_dataframes_and_queues,
35
+ )
36
+ from sklearnex import is_patched_instance
37
+ from sklearnex.dispatcher import _is_preview_enabled
38
+ from sklearnex.metrics import pairwise_distances, roc_auc_score
39
+ from sklearnex.tests.utils import (
40
+ DTYPES,
41
+ PATCHED_FUNCTIONS,
42
+ PATCHED_MODELS,
43
+ SPECIAL_INSTANCES,
44
+ UNPATCHED_FUNCTIONS,
45
+ UNPATCHED_MODELS,
46
+ call_method,
47
+ gen_dataset,
48
+ gen_models_info,
49
+ )
50
+
51
+
52
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
@pytest.mark.parametrize("metric", ["cosine", "correlation"])
def test_pairwise_distances_patching(caplog, dataframe, queue, dtype, metric):
    """Check that ``pairwise_distances`` only logs sklearnex dispatch messages.

    Every WARNING-or-higher record captured from the ``sklearnex`` logger must
    announce either the accelerated path or the fallback to stock
    scikit-learn; any other record is reported as a patching issue.
    """
    with caplog.at_level(logging.WARNING, logger="sklearnex"):
        # Device capability checks only matter when a SYCL queue is supplied.
        if queue:
            device = queue.sycl_device
            if dtype == np.float16 and not device.has_aspect_fp16:
                pytest.skip("Hardware does not support fp16 SYCL testing")
            elif dtype == np.float64 and not device.has_aspect_fp64:
                pytest.skip("Hardware does not support fp64 SYCL testing")
            elif device.is_gpu:
                pytest.skip("pairwise_distances does not support GPU queues")

        generator = nprnd.default_rng()
        samples = generator.random(size=1000)
        if dataframe != "pandas":
            # Non-pandas containers take the dtype/queue at conversion time and
            # are then viewed as a single-row 2D array.
            X = _convert_to_dataframe(
                samples, sycl_queue=queue, target_df=dataframe, dtype=dtype
            )[None, :]
        else:
            row = samples.astype(dtype).reshape(1, -1)
            X = _convert_to_dataframe(row, target_df=dataframe)

        _ = pairwise_distances(X, metric=metric)

    only_dispatch_logs = all(
        "running accelerated version" in record.message
        or "fallback to original Scikit-learn" in record.message
        for record in caplog.records
    )
    assert (
        only_dispatch_logs
    ), f"sklearnex patching issue in pairwise_distances with log: \n{caplog.text}"
83
+
84
+
85
@pytest.mark.parametrize(
    "dtype",
    [dt for dt in DTYPES if any(width in dt.__name__ for width in ("32", "64"))],
)
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
def test_roc_auc_score_patching(caplog, dataframe, queue, dtype):
    """Check that ``roc_auc_score`` only logs sklearnex dispatch messages.

    Only 32- and 64-bit dtypes are exercised. Two random 0/1 vectors are
    converted to the requested container and scored; every captured
    ``sklearnex`` WARNING record must be an accelerated/fallback message.
    """
    unsigned_on_windows = dtype in [np.uint32, np.uint64] and sys.platform == "win32"
    if unsigned_on_windows:
        pytest.skip("Windows issue with unsigned ints")
    if dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
        pytest.skip("Hardware does not support fp64 SYCL testing")

    with caplog.at_level(logging.WARNING, logger="sklearnex"):
        generator = nprnd.default_rng()
        raw_pair = [generator.integers(2, size=1000) for _ in range(2)]
        X, y = (
            _convert_to_dataframe(
                arr,
                sycl_queue=queue,
                target_df=dataframe,
                dtype=dtype,
            )
            for arr in raw_pair
        )

        _ = roc_auc_score(X, y)

    only_dispatch_logs = all(
        "running accelerated version" in record.message
        or "fallback to original Scikit-learn" in record.message
        for record in caplog.records
    )
    assert (
        only_dispatch_logs
    ), f"sklearnex patching issue in roc_auc_score with log: \n{caplog.text}"
121
+
122
+
123
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
@pytest.mark.parametrize("estimator, method", gen_models_info(PATCHED_MODELS))
def test_standard_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
    """Fit a freshly constructed patched estimator and optionally call ``method``.

    Every WARNING-or-higher record emitted by the ``sklearnex`` logger during
    fit and the method call must be a dispatch message (accelerated path or
    fallback to stock scikit-learn).

    ``method`` may be falsy, in which case only ``fit`` is exercised.
    """
    with caplog.at_level(logging.WARNING, logger="sklearnex"):
        est = PATCHED_MODELS[estimator]()

        if queue:
            if dtype == np.float16 and not queue.sycl_device.has_aspect_fp16:
                pytest.skip("Hardware does not support fp16 SYCL testing")
            elif dtype == np.float64 and not queue.sycl_device.has_aspect_fp64:
                pytest.skip("Hardware does not support fp64 SYCL testing")
            elif queue.sycl_device.is_gpu and estimator in [
                "ElasticNet",
                "Lasso",
                "Ridge",
            ]:
                pytest.skip(f"{estimator} does not support GPU queues")

        # Guard against a falsy ``method``: ``"radius" in None`` would raise
        # TypeError instead of letting the fit-only case run.
        if "NearestNeighbors" in estimator and method and "radius" in method:
            pytest.skip("RadiusNeighbors estimator not implemented in sklearnex")

        if estimator == "TSNE" and method == "fit_transform":
            pytest.skip("TSNE.fit_transform is too slow for common testing")
        elif (
            estimator == "Ridge"
            and method in ["predict", "score"]
            and sys.platform == "win32"
            and dtype in [np.uint32, np.uint64]
        ):
            pytest.skip("Windows segmentation fault for Ridge.predict for unsigned ints")
        elif estimator == "IncrementalLinearRegression" and np.issubdtype(
            dtype, np.integer
        ):
            pytest.skip(
                "IncrementalLinearRegression fails on oneDAL side with int types because dataset is filled by zeroes"
            )
        elif method and not hasattr(est, method):
            # sklearn's ``available_if`` hides some methods for this configuration.
            pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")

        X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)[0]
        est.fit(X, y)

        if method:
            call_method(est, method, X, y)

    assert all(
        "running accelerated version" in record.message
        or "fallback to original Scikit-learn" in record.message
        for record in caplog.records
    ), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
176
+
177
+
178
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
@pytest.mark.parametrize("estimator, method", gen_models_info(SPECIAL_INSTANCES))
def test_special_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
    """Fit a ``SPECIAL_INSTANCES`` estimator and optionally call ``method``.

    Unlike the standard test, the entry is used as-is (it is fitted directly,
    not called as a constructor). Every WARNING-or-higher ``sklearnex`` record
    emitted must be a dispatch message (accelerated path or sklearn fallback).
    ``method`` may be falsy, in which case only ``fit`` is exercised.
    """
    with caplog.at_level(logging.WARNING, logger="sklearnex"):
        est = SPECIAL_INSTANCES[estimator]

        if queue:
            # It's not possible to get the dpnp/dpctl arrays to be in the proper dtype
            if dtype == np.float16 and not queue.sycl_device.has_aspect_fp16:
                pytest.skip("Hardware does not support fp16 SYCL testing")
            elif dtype == np.float64 and not queue.sycl_device.has_aspect_fp64:
                pytest.skip("Hardware does not support fp64 SYCL testing")

        # Guard against a falsy ``method``: ``"radius" in None`` would raise
        # TypeError instead of letting the fit-only case run.
        if "NearestNeighbors" in estimator and method and "radius" in method:
            pytest.skip("RadiusNeighbors estimator not implemented in sklearnex")

        X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)[0]
        est.fit(X, y)

        # Checked after fit, matching the original ordering for these instances.
        if method and not hasattr(est, method):
            pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")

        if method:
            call_method(est, method, X, y)

    assert all(
        "running accelerated version" in record.message
        or "fallback to original Scikit-learn" in record.message
        for record in caplog.records
    ), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
213
+
214
+
215
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_standard_estimator_signatures(estimator):
    """Every public callable of the sklearn estimator keeps its signature in sklearnex.

    Module-qualified class names differ between sklearn/daal4py and sklearnex,
    so they are collapsed to the bare estimator name before comparison.
    """
    patched_instance = PATCHED_MODELS[estimator]()
    sklearn_instance = UNPATCHED_MODELS[estimator]()

    # needed due to differences in module structure
    qualified_name = re.compile(rf"(?:sklearn|daal4py)\S*{estimator}")

    def normalized_signature(fn):
        return qualified_name.sub(estimator, str(signature(fn)))

    # hasattr filters out names listed by dir() that are not actually
    # accessible on this instance.
    public_names = [
        name
        for name in dir(sklearn_instance)
        if not name.startswith("_")
        and not name.endswith("_")
        and hasattr(sklearn_instance, name)
    ]
    for name in public_names:
        patched_member = getattr(patched_instance, name)
        sklearn_member = getattr(sklearn_instance, name)
        if not callable(sklearn_member):
            continue
        assert normalized_signature(patched_member) == normalized_signature(
            sklearn_member
        ), f"Signature of {estimator}.{name} does not match sklearn"
237
+
238
+
239
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_standard_estimator_init_signatures(estimator):
    """``__init__`` signatures match sklearn once known deviations are removed.

    Known deviations that are stripped before comparing:
    - ``n_jobs=None`` and the keyword-only marker ``*``, removed from both
      signatures (sklearnex estimator patching adds n_jobs and allows
      positional kwargs).
    - ``min_bin_size=1`` and ``max_bins=256``, sklearnex-specific forest
      kwargs removed from the patched signature only.
    """

    def without(sig_text, fragments):
        # Each fragment is removed together with its leading ", " separator.
        for fragment in fragments:
            sig_text = sig_text.replace(", " + fragment, "")
        return sig_text

    patched_sig = str(signature(PATCHED_MODELS[estimator].__init__))
    unpatched_sig = str(signature(UNPATCHED_MODELS[estimator].__init__))

    shared_deviations = ("n_jobs=None", "*")
    patched_sig = without(patched_sig, shared_deviations)
    unpatched_sig = without(unpatched_sig, shared_deviations)

    forest_estimators = (
        "RandomForestRegressor",
        "RandomForestClassifier",
        "ExtraTreesRegressor",
        "ExtraTreesClassifier",
    )
    if estimator in forest_estimators:
        patched_sig = without(patched_sig, ("min_bin_size=1", "max_bins=256"))

    assert (
        patched_sig == unpatched_sig
    ), f"Signature of {estimator}.__init__ does not match sklearn"
265
+
266
+
267
@pytest.mark.parametrize(
    "function",
    [
        name
        for name in UNPATCHED_FUNCTIONS
        if name not in ("train_test_split", "set_config", "config_context")
    ],
)
def test_patched_function_signatures(function):
    """Patched sklearn functions keep the sklearn signature.

    ``train_test_split``, ``set_config`` and ``config_context`` are excluded
    because they add functionality to the underlying sklearn function.
    """
    if not sklearn_check_version("1.1") and function == "_assert_all_finite":
        pytest.skip("Sklearn versioning not added to _assert_all_finite")

    patched_func = PATCHED_FUNCTIONS[function]
    sklearn_func = UNPATCHED_FUNCTIONS[function]

    # Non-callable entries have no signature to compare.
    if not callable(sklearn_func):
        return

    patched_sig = str(signature(patched_func))
    sklearn_sig = str(signature(sklearn_func))
    assert patched_sig == sklearn_sig, f"Signature of {patched_func} does not match sklearn"
287
+
288
+
289
def test_patch_map_match():
    """Non-preview patched items match the public ``__all__`` of sklearn submodules.

    Every non-preview function or class in ``__all__`` of a submodule shared by
    sklearn and sklearnex must appear in the patch map, with nothing missing or
    extra. Top-level names that cannot be imported as submodules are treated
    as patched items themselves.
    """

    def exported_names(module_name):
        # ``{None}`` marks "not importable as a module", e.g. a top-level
        # function that also appears in ``__all__``.
        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError:
            return {None}
        return set(module.__all__)

    if _is_preview_enabled():
        pytest.skip("preview sklearnex has been activated")
    remaining = {**PATCHED_MODELS, **PATCHED_FUNCTIONS}

    shared_top_level = exported_names("sklearnex") & exported_names("sklearn")

    # _assert_all_finite patches an internal sklearn function which isn't
    # exposed via __all__ in sklearn; it is grandfathered in.
    del remaining["_assert_all_finite"]

    # Drop scikit-learn-intelex-only estimators/functions.
    remaining = {
        key: value
        for key, value in remaining.items()
        if key in UNPATCHED_MODELS or key in UNPATCHED_FUNCTIONS
    }

    for module in shared_top_level:
        sklearn_names = exported_names("sklearn." + module)
        sklearnex_names = exported_names("sklearnex." + module)
        # A ``None`` entry means the top-level name itself is the patched item.
        for name in sklearnex_names & sklearn_names:
            del remaining[name or module]

    assert remaining == {}, f"{remaining.keys()} were not properly patched"
331
+
332
+
333
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_is_patched_instance(estimator):
    """``is_patched_instance`` accepts the sklearnex class and rejects the sklearn one.

    The assertion messages describe the failure condition, because they are
    only displayed when the corresponding assertion fails.
    """
    patched = PATCHED_MODELS[estimator]
    unpatched = UNPATCHED_MODELS[estimator]
    assert is_patched_instance(
        patched
    ), f"{patched} is not recognised as a patched instance"
    assert not is_patched_instance(
        unpatched
    ), f"{unpatched} is wrongly reported as a patched instance"
339
+
340
+
341
@pytest.mark.parametrize("estimator", PATCHED_MODELS.keys())
def test_if_estimator_inherits_sklearn(estimator):
    """Patched estimators subclass the sklearn class they replace.

    sklearnex-only estimators have no sklearn counterpart. They must at least
    derive from ``sklearn.base.BaseEstimator``.
    """
    est = PATCHED_MODELS[estimator]
    if estimator in UNPATCHED_MODELS:
        assert issubclass(
            est, UNPATCHED_MODELS[estimator]
        ), f"{estimator} does not inherit from the sklearn estimator it patches"
    else:
        assert issubclass(
            est, BaseEstimator
        ), f"{estimator} does not inherit from sklearn.base.BaseEstimator"
350
+
351
+
352
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_docstring_patching_match(estimator):
    """Docstring presence on the sklearnex class mirrors the sklearn class.

    Only presence (``None`` vs. not ``None``) is compared, for the class
    itself and for every public attribute of the sklearn class.
    """
    patched = PATCHED_MODELS[estimator]
    unpatched = UNPATCHED_MODELS[estimator]

    def public_docstrings(cls):
        # Public = no leading/trailing underscore; hasattr drops names that
        # dir() lists but that cannot actually be accessed on the class.
        return {
            name: getattr(cls, name).__doc__
            for name in dir(cls)
            if not name.startswith("_") and not name.endswith("_") and hasattr(cls, name)
        }

    patched_docstrings = public_docstrings(patched)
    unpatched_docstrings = public_docstrings(unpatched)

    # check class docstring match if a docstring is available
    assert (patched.__doc__ is None) == (
        unpatched.__doc__ is None
    ), f"{estimator} class docstring presence differs from sklearn"

    # check class attribute docstrings
    for name, unpatched_doc in unpatched_docstrings.items():
        # Report a missing attribute clearly instead of a bare KeyError.
        assert (
            name in patched_docstrings
        ), f"{estimator}.{name} exists in sklearn but not in sklearnex"
        assert (patched_docstrings[name] is None) == (
            unpatched_doc is None
        ), f"{estimator}.{name} docstring presence differs from sklearn"
375
+
376
+
377
@pytest.mark.parametrize("member", ["_onedal_cpu_supported", "_onedal_gpu_supported"])
@pytest.mark.parametrize(
    "name",
    [key for key, cls in PATCHED_MODELS.items() if "sklearnex" in cls.__module__],
)
def test_onedal_supported_member(name, member):
    """The oneDAL support hooks of sklearnex-defined models share one signature."""
    hook = getattr(PATCHED_MODELS[name], member)
    actual_signature = str(signature(hook))
    assert "(self, method_name, *data)" == actual_signature