scikit-learn-intelex 2025.1.0__py311-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (280)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-311-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-311-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +222 -0
  62. onedal/_onedal_py_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-311-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +564 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +125 -0
  83. onedal/common/tests/test_policy.py +76 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +154 -0
  91. onedal/datatypes/tests/common.py +126 -0
  92. onedal/datatypes/tests/test_data.py +414 -0
  93. onedal/decomposition/__init__.py +20 -0
  94. onedal/decomposition/incremental_pca.py +204 -0
  95. onedal/decomposition/pca.py +186 -0
  96. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  97. onedal/ensemble/__init__.py +29 -0
  98. onedal/ensemble/forest.py +727 -0
  99. onedal/ensemble/tests/test_random_forest.py +97 -0
  100. onedal/linear_model/__init__.py +27 -0
  101. onedal/linear_model/incremental_linear_model.py +258 -0
  102. onedal/linear_model/linear_model.py +329 -0
  103. onedal/linear_model/logistic_regression.py +249 -0
  104. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  105. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  106. onedal/linear_model/tests/test_linear_regression.py +250 -0
  107. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  108. onedal/linear_model/tests/test_ridge.py +95 -0
  109. onedal/neighbors/__init__.py +19 -0
  110. onedal/neighbors/neighbors.py +767 -0
  111. onedal/neighbors/tests/test_knn_classification.py +49 -0
  112. onedal/primitives/__init__.py +27 -0
  113. onedal/primitives/get_tree.py +25 -0
  114. onedal/primitives/kernel_functions.py +153 -0
  115. onedal/primitives/tests/test_kernel_functions.py +159 -0
  116. onedal/spmd/__init__.py +25 -0
  117. onedal/spmd/_base.py +30 -0
  118. onedal/spmd/basic_statistics/__init__.py +20 -0
  119. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  120. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  121. onedal/spmd/cluster/__init__.py +28 -0
  122. onedal/spmd/cluster/dbscan.py +23 -0
  123. onedal/spmd/cluster/kmeans.py +56 -0
  124. onedal/spmd/covariance/__init__.py +20 -0
  125. onedal/spmd/covariance/covariance.py +26 -0
  126. onedal/spmd/covariance/incremental_covariance.py +82 -0
  127. onedal/spmd/decomposition/__init__.py +20 -0
  128. onedal/spmd/decomposition/incremental_pca.py +117 -0
  129. onedal/spmd/decomposition/pca.py +26 -0
  130. onedal/spmd/ensemble/__init__.py +19 -0
  131. onedal/spmd/ensemble/forest.py +28 -0
  132. onedal/spmd/linear_model/__init__.py +21 -0
  133. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  134. onedal/spmd/linear_model/linear_model.py +30 -0
  135. onedal/spmd/linear_model/logistic_regression.py +38 -0
  136. onedal/spmd/neighbors/__init__.py +19 -0
  137. onedal/spmd/neighbors/neighbors.py +75 -0
  138. onedal/svm/__init__.py +19 -0
  139. onedal/svm/svm.py +556 -0
  140. onedal/svm/tests/test_csr_svm.py +351 -0
  141. onedal/svm/tests/test_nusvc.py +204 -0
  142. onedal/svm/tests/test_nusvr.py +210 -0
  143. onedal/svm/tests/test_svc.py +176 -0
  144. onedal/svm/tests/test_svr.py +243 -0
  145. onedal/tests/test_common.py +57 -0
  146. onedal/tests/utils/_dataframes_support.py +162 -0
  147. onedal/tests/utils/_device_selection.py +102 -0
  148. onedal/utils/__init__.py +49 -0
  149. onedal/utils/_array_api.py +81 -0
  150. onedal/utils/_dpep_helpers.py +56 -0
  151. onedal/utils/validation.py +440 -0
  152. scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
  153. scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
  154. scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
  155. scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
  156. scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
  157. sklearnex/__init__.py +66 -0
  158. sklearnex/__main__.py +58 -0
  159. sklearnex/_config.py +116 -0
  160. sklearnex/_device_offload.py +126 -0
  161. sklearnex/_utils.py +132 -0
  162. sklearnex/basic_statistics/__init__.py +20 -0
  163. sklearnex/basic_statistics/basic_statistics.py +230 -0
  164. sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  165. sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  166. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  167. sklearnex/cluster/__init__.py +20 -0
  168. sklearnex/cluster/dbscan.py +197 -0
  169. sklearnex/cluster/k_means.py +395 -0
  170. sklearnex/cluster/tests/test_dbscan.py +38 -0
  171. sklearnex/cluster/tests/test_kmeans.py +159 -0
  172. sklearnex/conftest.py +82 -0
  173. sklearnex/covariance/__init__.py +19 -0
  174. sklearnex/covariance/incremental_covariance.py +398 -0
  175. sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  176. sklearnex/decomposition/__init__.py +19 -0
  177. sklearnex/decomposition/pca.py +425 -0
  178. sklearnex/decomposition/tests/test_pca.py +58 -0
  179. sklearnex/dispatcher.py +543 -0
  180. sklearnex/doc/third-party-programs.txt +424 -0
  181. sklearnex/ensemble/__init__.py +29 -0
  182. sklearnex/ensemble/_forest.py +2029 -0
  183. sklearnex/ensemble/tests/test_forest.py +135 -0
  184. sklearnex/glob/__main__.py +72 -0
  185. sklearnex/glob/dispatcher.py +101 -0
  186. sklearnex/linear_model/__init__.py +32 -0
  187. sklearnex/linear_model/coordinate_descent.py +30 -0
  188. sklearnex/linear_model/incremental_linear.py +482 -0
  189. sklearnex/linear_model/incremental_ridge.py +425 -0
  190. sklearnex/linear_model/linear.py +341 -0
  191. sklearnex/linear_model/logistic_regression.py +413 -0
  192. sklearnex/linear_model/ridge.py +24 -0
  193. sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  194. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  195. sklearnex/linear_model/tests/test_linear.py +167 -0
  196. sklearnex/linear_model/tests/test_logreg.py +134 -0
  197. sklearnex/manifold/__init__.py +19 -0
  198. sklearnex/manifold/t_sne.py +21 -0
  199. sklearnex/manifold/tests/test_tsne.py +26 -0
  200. sklearnex/metrics/__init__.py +23 -0
  201. sklearnex/metrics/pairwise.py +22 -0
  202. sklearnex/metrics/ranking.py +20 -0
  203. sklearnex/metrics/tests/test_metrics.py +39 -0
  204. sklearnex/model_selection/__init__.py +21 -0
  205. sklearnex/model_selection/split.py +22 -0
  206. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  207. sklearnex/neighbors/__init__.py +27 -0
  208. sklearnex/neighbors/_lof.py +236 -0
  209. sklearnex/neighbors/common.py +310 -0
  210. sklearnex/neighbors/knn_classification.py +231 -0
  211. sklearnex/neighbors/knn_regression.py +207 -0
  212. sklearnex/neighbors/knn_unsupervised.py +178 -0
  213. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  214. sklearnex/preview/__init__.py +17 -0
  215. sklearnex/preview/covariance/__init__.py +19 -0
  216. sklearnex/preview/covariance/covariance.py +138 -0
  217. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  218. sklearnex/preview/decomposition/__init__.py +19 -0
  219. sklearnex/preview/decomposition/incremental_pca.py +233 -0
  220. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  221. sklearnex/preview/linear_model/__init__.py +19 -0
  222. sklearnex/preview/linear_model/ridge.py +424 -0
  223. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  224. sklearnex/spmd/__init__.py +25 -0
  225. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  226. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  227. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  228. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  229. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  230. sklearnex/spmd/cluster/__init__.py +30 -0
  231. sklearnex/spmd/cluster/dbscan.py +50 -0
  232. sklearnex/spmd/cluster/kmeans.py +21 -0
  233. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  234. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  235. sklearnex/spmd/covariance/__init__.py +20 -0
  236. sklearnex/spmd/covariance/covariance.py +21 -0
  237. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  238. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  239. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  240. sklearnex/spmd/decomposition/__init__.py +20 -0
  241. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  242. sklearnex/spmd/decomposition/pca.py +21 -0
  243. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  244. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  245. sklearnex/spmd/ensemble/__init__.py +19 -0
  246. sklearnex/spmd/ensemble/forest.py +71 -0
  247. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  248. sklearnex/spmd/linear_model/__init__.py +21 -0
  249. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  250. sklearnex/spmd/linear_model/linear_model.py +21 -0
  251. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  252. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  253. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  254. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  255. sklearnex/spmd/neighbors/__init__.py +19 -0
  256. sklearnex/spmd/neighbors/neighbors.py +25 -0
  257. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  258. sklearnex/svm/__init__.py +29 -0
  259. sklearnex/svm/_common.py +339 -0
  260. sklearnex/svm/nusvc.py +371 -0
  261. sklearnex/svm/nusvr.py +170 -0
  262. sklearnex/svm/svc.py +399 -0
  263. sklearnex/svm/svr.py +167 -0
  264. sklearnex/svm/tests/test_svm.py +93 -0
  265. sklearnex/tests/test_common.py +390 -0
  266. sklearnex/tests/test_config.py +123 -0
  267. sklearnex/tests/test_memory_usage.py +379 -0
  268. sklearnex/tests/test_monkeypatch.py +276 -0
  269. sklearnex/tests/test_n_jobs_support.py +108 -0
  270. sklearnex/tests/test_parallel.py +48 -0
  271. sklearnex/tests/test_patching.py +385 -0
  272. sklearnex/tests/test_run_to_run_stability.py +321 -0
  273. sklearnex/tests/utils/__init__.py +44 -0
  274. sklearnex/tests/utils/base.py +371 -0
  275. sklearnex/tests/utils/spmd.py +198 -0
  276. sklearnex/utils/__init__.py +19 -0
  277. sklearnex/utils/_array_api.py +82 -0
  278. sklearnex/utils/parallel.py +59 -0
  279. sklearnex/utils/tests/test_finite.py +89 -0
  280. sklearnex/utils/validation.py +17 -0
sklearnex/svm/svr.py ADDED
@@ -0,0 +1,167 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numpy as np
18
+ from sklearn.svm import SVR as _sklearn_SVR
19
+ from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
20
+
21
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
22
+ from daal4py.sklearn._utils import sklearn_check_version
23
+ from onedal.svm import SVR as onedal_SVR
24
+
25
+ from .._device_offload import dispatch, wrap_output_data
26
+ from ._common import BaseSVR
27
+
28
# Resolve the ``validate_data`` helper once at import time. scikit-learn 1.6
# ships it as a public function in ``sklearn.utils.validation``; for older
# releases fall back to the ``_validate_data`` method available on ``BaseSVR``,
# taken unbound so that it is called as ``validate_data(self, X, ...)``.
if not sklearn_check_version("1.6"):
    validate_data = BaseSVR._validate_data
else:
    from sklearn.utils.validation import validate_data
32
+
33
+
34
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class SVR(_sklearn_SVR, BaseSVR):
    # Re-use the upstream scikit-learn docstring verbatim.
    __doc__ = _sklearn_SVR.__doc__

    if sklearn_check_version("1.2"):
        # Copy of the upstream parameter constraints, consumed by
        # ``self._validate_params()`` in ``fit`` on sklearn >= 1.2.
        _parameter_constraints: dict = {**_sklearn_SVR._parameter_constraints}

    @_deprecate_positional_args
    def __init__(
        self,
        *,
        kernel="rbf",
        degree=3,
        gamma="scale",
        coef0=0.0,
        tol=1e-3,
        C=1.0,
        epsilon=0.1,
        shrinking=True,
        cache_size=200,
        verbose=False,
        max_iter=-1,
    ):
        # All hyperparameters are stored by the upstream sklearn SVR initializer.
        super().__init__(
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            coef0=coef0,
            tol=tol,
            C=C,
            epsilon=epsilon,
            shrinking=shrinking,
            cache_size=cache_size,
            verbose=verbose,
            max_iter=max_iter,
        )

    def fit(self, X, y, sample_weight=None):
        if sklearn_check_version("1.2"):
            self._validate_params()
        elif self.C <= 0:
            # else if added to correct issues with sklearn tests:
            #   svm/tests/test_sparse.py::test_error
            #   svm/tests/test_svm.py::test_bad_input
            # for sklearn versions < 1.2 (i.e. without validate_params
            # parameter checking). Without this, a segmentation fault with
            # "Windows fatal exception: access violation" occurs.
            raise ValueError("C <= 0")
        # ``dispatch`` (sklearnex._device_offload) chooses between the
        # oneDAL-backed ``_onedal_fit`` and stock ``_sklearn_SVR.fit``; its
        # return value is ignored and the estimator itself is returned.
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": _sklearn_SVR.fit,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

        return self

    @wrap_output_data
    def predict(self, X):
        check_is_fitted(self)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": _sklearn_SVR.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        check_is_fitted(self)
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": _sklearn_SVR.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    def _onedal_fit(self, X, y, sample_weight=None, queue=None):
        """Fit through oneDAL.

        Inputs are checked with the ``BaseSVR._onedal_fit_checks`` helper, a
        ``onedal.svm.SVR`` is built from this estimator's hyperparameters and
        fitted on ``queue``, and ``self._save_attributes()`` is called
        afterwards (presumably to copy fitted attributes onto ``self`` --
        defined in ``_common.py``, not visible here).
        """
        # NOTE(review): the validated ``y`` returned by the checks is discarded
        # and the caller's ``y`` is passed to oneDAL -- confirm this is intended.
        X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight)
        onedal_params = {
            "C": self.C,
            "epsilon": self.epsilon,
            "kernel": self.kernel,
            "degree": self.degree,
            # ``_compute_gamma_sigma`` (BaseSVR helper) turns ``self.gamma``
            # into the value oneDAL expects, using the validated X.
            "gamma": self._compute_gamma_sigma(X),
            "coef0": self.coef0,
            "tol": self.tol,
            "shrinking": self.shrinking,
            "cache_size": self.cache_size,
            "max_iter": self.max_iter,
        }

        self._onedal_estimator = onedal_SVR(**onedal_params)
        self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
        self._save_attributes()

    def _onedal_predict(self, X, queue=None):
        """Validate ``X`` (float64/float32, CSR allowed, non-finite values not
        rejected) and predict with the fitted oneDAL estimator on ``queue``."""
        # scikit-learn 1.6 renamed ``force_all_finite`` to ``ensure_all_finite``;
        # the old keyword is deprecated there (FutureWarning) and scheduled for
        # removal in 1.8, so pick the keyword matching the installed version.
        if sklearn_check_version("1.6"):
            finite_kwarg = {"ensure_all_finite": False}
        else:
            finite_kwarg = {"force_all_finite": False}

        if sklearn_check_version("1.0"):
            X = validate_data(
                self,
                X,
                dtype=[np.float64, np.float32],
                accept_sparse="csr",
                reset=False,
                **finite_kwarg,
            )
        else:
            # ``check_array`` is not imported at module level; import it here
            # so this legacy (sklearn < 1.0) branch does not raise NameError.
            from sklearn.utils.validation import check_array

            X = check_array(
                X,
                dtype=[np.float64, np.float32],
                accept_sparse="csr",
                **finite_kwarg,
            )
        return self._onedal_estimator.predict(X, queue=queue)

    fit.__doc__ = _sklearn_SVR.fit.__doc__
    predict.__doc__ = _sklearn_SVR.predict.__doc__
    score.__doc__ = _sklearn_SVR.score.__doc__
sklearnex/svm/tests/test_svm.py ADDED
@@ -0,0 +1,93 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose
20
+
21
+ from onedal.tests.utils._dataframes_support import (
22
+ _as_numpy,
23
+ _convert_to_dataframe,
24
+ get_dataframes_and_queues,
25
+ )
26
+
27
+
28
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_svc(dataframe, queue):
    """Linear-kernel sklearnex SVC reproduces reference dual coefficients.

    A six-point, two-class toy problem is converted to each (dataframe, queue)
    combination before fitting; GPU queues are skipped.
    """
    if queue and queue.sycl_device.is_gpu:
        pytest.skip("SVC fit for the GPU sycl_queue is buggy.")
    from sklearnex.svm import SVC

    points = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
    labels = np.array([1, 1, 1, 2, 2, 2])
    model = SVC(kernel="linear").fit(
        _convert_to_dataframe(points, sycl_queue=queue, target_df=dataframe),
        _convert_to_dataframe(labels, sycl_queue=queue, target_df=dataframe),
    )

    # The estimator must come from daal4py or sklearnex, not stock sklearn.
    assert any(pkg in model.__module__ for pkg in ("daal4py", "sklearnex"))
    assert_allclose(_as_numpy(model.dual_coef_), [[-0.25, 0.25]])
    assert_allclose(_as_numpy(model.support_), [1, 3])
42
+
43
+
44
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_nusvc(dataframe, queue):
    """Linear-kernel sklearnex NuSVC reproduces reference dual coefficients.

    A six-point, two-class toy problem is converted to each (dataframe, queue)
    combination before fitting; GPU queues are skipped.
    """
    if queue and queue.sycl_device.is_gpu:
        pytest.skip("NuSVC fit for the GPU sycl_queue is buggy.")
    from sklearnex.svm import NuSVC

    points = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
    labels = np.array([1, 1, 1, 2, 2, 2])
    model = NuSVC(kernel="linear").fit(
        _convert_to_dataframe(points, sycl_queue=queue, target_df=dataframe),
        _convert_to_dataframe(labels, sycl_queue=queue, target_df=dataframe),
    )

    # The estimator must come from daal4py or sklearnex, not stock sklearn.
    assert any(pkg in model.__module__ for pkg in ("daal4py", "sklearnex"))
    expected_dual = [[-0.04761905, -0.0952381, 0.0952381, 0.04761905]]
    assert_allclose(_as_numpy(model.dual_coef_), expected_dual)
    assert_allclose(_as_numpy(model.support_), [0, 1, 3, 4])
60
+
61
+
62
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_svr(dataframe, queue):
    """Linear-kernel sklearnex SVR reproduces reference dual coefficients.

    A six-point toy regression problem is converted to each (dataframe, queue)
    combination before fitting; GPU queues are skipped.
    """
    if queue and queue.sycl_device.is_gpu:
        pytest.skip("SVR fit for the GPU sycl_queue is buggy.")
    from sklearnex.svm import SVR

    points = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
    targets = np.array([1, 1, 1, 2, 2, 2])
    regressor = SVR(kernel="linear").fit(
        _convert_to_dataframe(points, sycl_queue=queue, target_df=dataframe),
        _convert_to_dataframe(targets, sycl_queue=queue, target_df=dataframe),
    )

    # The estimator must come from daal4py or sklearnex, not stock sklearn.
    assert any(pkg in regressor.__module__ for pkg in ("daal4py", "sklearnex"))
    assert_allclose(_as_numpy(regressor.dual_coef_), [[-0.1, 0.1]])
    assert_allclose(_as_numpy(regressor.support_), [1, 3])
76
+
77
+
78
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_nusvr(dataframe, queue):
    """Linear-kernel sklearnex NuSVR (nu=0.9) reproduces reference dual
    coefficients to rtol=1e-3.

    A six-point toy regression problem is converted to each (dataframe, queue)
    combination before fitting; GPU queues are skipped.
    """
    if queue and queue.sycl_device.is_gpu:
        pytest.skip("NuSVR fit for the GPU sycl_queue is buggy.")
    from sklearnex.svm import NuSVR

    points = np.array([[-2, -1], [-1, -1], [-1, -2], [+1, +1], [+1, +2], [+2, +1]])
    targets = np.array([1, 1, 1, 2, 2, 2])
    regressor = NuSVR(kernel="linear", nu=0.9).fit(
        _convert_to_dataframe(points, sycl_queue=queue, target_df=dataframe),
        _convert_to_dataframe(targets, sycl_queue=queue, target_df=dataframe),
    )

    # The estimator must come from daal4py or sklearnex, not stock sklearn.
    assert any(pkg in regressor.__module__ for pkg in ("daal4py", "sklearnex"))
    expected_dual = [[-1.0, 0.611111, 1.0, -0.611111]]
    assert_allclose(_as_numpy(regressor.dual_coef_), expected_dual, rtol=1e-3)
    assert_allclose(_as_numpy(regressor.support_), [1, 2, 3, 5])
sklearnex/tests/test_common.py ADDED
@@ -0,0 +1,390 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import importlib.util
18
+ import os
19
+ import pathlib
20
+ import pkgutil
21
+ import re
22
+ import sys
23
+ import trace
24
+
25
+ import pytest
26
+ from sklearn.utils import all_estimators
27
+
28
+ from daal4py.sklearn._utils import sklearn_check_version
29
+ from onedal.tests.test_common import _check_primitive_usage_ban
30
+ from sklearnex.tests.utils import (
31
+ PATCHED_MODELS,
32
+ SPECIAL_INSTANCES,
33
+ call_method,
34
+ gen_dataset,
35
+ gen_models_info,
36
+ )
37
+
38
# Path fragments of sklearnex sources in which ``target_offload`` may appear;
# passed as ``allowed_locations`` by ``test_target_offload_ban``.
TARGET_OFFLOAD_ALLOWED_LOCATIONS = [
    "_config.py",
    "_device_offload.py",
    "test",
    "svc.py",
    os.path.join("svm", "_common.py"),
]

# Known design-rule exceptions: each key appears to be an
# "<estimator>-<method>-<check>" identifier mapped to the reason the check is
# waived (the consumer is not visible in this chunk -- TODO confirm format).
_DESIGN_RULE_VIOLATIONS = {
    "PCA-fit_transform-call_validate_data": "calls both 'fit' and 'transform'",
    "IncrementalEmpiricalCovariance-score-call_validate_data": "must call clone of itself",
    "SVC(probability=True)-fit-call_validate_data": "SVC fit can use sklearn estimator",
    "NuSVC(probability=True)-fit-call_validate_data": "NuSVC fit can use sklearn estimator",
    # n_jobs_check exemptions, expanded estimator-major in the order listed.
    **{
        f"{estimator}-{method}-n_jobs_check": reason
        for estimator, methods, reason in (
            (
                "LogisticRegression",
                ("score", "fit", "predict", "predict_log_proba", "predict_proba"),
                "uses daal4py for cpu in sklearnex",
            ),
            (
                "KNeighborsClassifier",
                ("kneighbors", "fit", "score", "predict", "predict_proba", "kneighbors_graph"),
                "uses daal4py for cpu in onedal",
            ),
            (
                "KNeighborsRegressor",
                ("kneighbors", "fit", "score", "predict", "kneighbors_graph"),
                "uses daal4py for cpu in onedal",
            ),
            (
                "NearestNeighbors",
                (
                    "kneighbors",
                    "fit",
                    "radius_neighbors",
                    "kneighbors_graph",
                    "radius_neighbors_graph",
                ),
                "uses daal4py for cpu in onedal",
            ),
            (
                "LocalOutlierFactor",
                ("fit", "kneighbors", "kneighbors_graph"),
                "uses daal4py for cpu in onedal",
            ),
            (
                "KNeighborsClassifier(algorithm='brute')",
                ("kneighbors", "fit", "score", "predict", "predict_proba", "kneighbors_graph"),
                "uses daal4py for cpu in onedal",
            ),
            (
                "KNeighborsRegressor(algorithm='brute')",
                ("kneighbors", "fit", "score", "predict", "kneighbors_graph"),
                "uses daal4py for cpu in onedal",
            ),
            (
                "NearestNeighbors(algorithm='brute')",
                (
                    "kneighbors",
                    "fit",
                    "radius_neighbors",
                    "kneighbors_graph",
                    "radius_neighbors_graph",
                ),
                "uses daal4py for cpu in onedal",
            ),
            (
                "LocalOutlierFactor(novelty=True)",
                ("fit", "kneighbors", "kneighbors_graph"),
                "uses daal4py for cpu in onedal",
            ),
            (
                "LogisticRegression(solver='newton-cg')",
                ("score", "fit", "predict", "predict_log_proba", "predict_proba"),
                "uses daal4py for cpu in sklearnex",
            ),
        )
        for method in methods
    },
}
100
+
101
+
102
def test_target_offload_ban():
    """Ban the use of ``target_offload`` in sklearnex source files.

    Offloading computation to devices via target_offload should only occur
    externally, and not within the architecture of the sklearnex classes.
    This is for clarity, traceability and maintainability. Locations listed
    in ``TARGET_OFFLOAD_ALLOWED_LOCATIONS`` are passed to the checker as
    ``allowed_locations``.
    """
    output = _check_primitive_usage_ban(
        primitive_name="target_offload",
        package="sklearnex",
        allowed_locations=TARGET_OFFLOAD_ALLOWED_LOCATIONS,
    )
    # The checker returns one entry per offending location; an empty join
    # means no violations were found.
    output = "\n".join(output)
    assert output == "", f"target offloading is occurring in: \n{output}"
116
+
117
+
118
def _sklearnex_walk(func):
    """Wrap ``pkgutil.walk_packages`` so it scans sklearnex instead of sklearn.

    A ``prefix="sklearn."`` keyword is rewritten to ``"sklearnex."``, any
    ``path`` keyword is redirected to the sklearnex package root (the parent
    of this tests directory), and modules inside an ``spmd`` subpackage are
    filtered out of the results.

    Parameters
    ----------
    func : callable
        The generator function to wrap, normally ``pkgutil.walk_packages``.

    Returns
    -------
    callable
        A generator function with the same calling convention as ``func``.
    """
    import functools

    @functools.wraps(func)
    def wrap(*args, **kwargs):
        if kwargs.get("prefix") == "sklearn.":
            kwargs["prefix"] = "sklearnex."
        if "path" in kwargs:
            # force root to sklearnex
            kwargs["path"] = [str(pathlib.Path(__file__).parent.parent)]
        for pkginfo in func(*args, **kwargs):
            # Do not allow spmd to be yielded
            if "spmd" not in pkginfo.name.split("."):
                yield pkginfo

    return wrap
134
+
135
+
136
def test_class_trailing_underscore_ban(monkeypatch):
    """Check that sklearnex estimator classes have no public trailing-underscore attributes.

    Trailing underscores are defined for sklearn to be signatures of a fitted
    estimator instance; sklearnex extends this to the classes as well.
    Properties are exempt: they also occur in sklearn, especially in
    deprecations, and are expected to error if queried when the estimator is
    not fitted. Estimators from ``preview`` or ``daal4py`` modules are skipped.
    """
    monkeypatch.setattr(pkgutil, "walk_packages", _sklearnex_walk(pkgutil.walk_packages))
    estimators = all_estimators()  # list of (name, class) tuples
    for name, obj in estimators:
        if "preview" in obj.__module__ or "daal4py" in obj.__module__:
            continue
        # public (no leading "_") names ending in "_" that are not properties
        offending = [
            attr
            for attr in dir(obj)
            if attr.endswith("_")
            and not attr.startswith("_")
            and not isinstance(getattr(obj, attr), property)
        ]
        assert not offending, (
            f"{name} contains class attributes which have a trailing underscore "
            f"but no leading one: {offending}"
        )
152
+
153
+
154
def test_all_estimators_covered(monkeypatch):
    """Ensure every sklearnex estimator is exercised by the common tests.

    Discovery uses sklearn's ``all_estimators`` redirected to sklearnex, so the
    candidates are public estimators inheriting sklearn's BaseEstimator; the
    sklearnex.spmd packages are filtered out by the walker, and estimators
    whose module path contains ``preview`` are ignored here.

    An estimator counts as covered when some class in ``PATCHED_MODELS``, or
    the class of some instance in ``SPECIAL_INSTANCES``, is a subclass of it.
    """
    monkeypatch.setattr(pkgutil, "walk_packages", _sklearnex_walk(pkgutil.walk_packages))

    patched_classes = list(PATCHED_MODELS.values())
    special_classes = [instance.__class__ for instance in SPECIAL_INSTANCES.values()]

    def _is_covered(candidate):
        if any(issubclass(cls, candidate) for cls in patched_classes):
            return True
        return any(issubclass(cls, candidate) for cls in special_classes)

    uncovered_estimators = [
        ".".join([obj.__module__, name])
        for name, obj in all_estimators()
        if "preview" not in obj.__module__ and not _is_covered(obj)
    ]

    assert (
        uncovered_estimators == []
    ), f"{uncovered_estimators} are currently not included"
174
+
175
+
176
def _fullpath(path):
    """Return the canonical absolute form of ``path``.

    A leading ``~`` is expanded to the user's home directory first, then the
    result is made absolute with symlinks resolved by ``os.path.realpath``.
    """
    expanded = os.path.expanduser(path)
    return os.path.realpath(expanded)
178
+
179
+
180
# Package name -> canonical directory containing that package's top-level
# module file. estimator_trace replaces these directories with the package
# names in trace output, and _whitelist_to_blacklist keeps them traceable.
# NOTE: importlib.util.find_spec returns None for a package that is not
# installed, which makes this raise AttributeError at import time.
_TRACE_ALLOW_DICT = {
    package: _fullpath(os.path.dirname(importlib.util.find_spec(package).origin))
    for package in ("sklearn", "sklearnex", "onedal", "daal4py")
}
184
+
185
+
186
def _whitelist_to_blacklist():
    """Block all standard library, built-in or site packages which are not
    related to sklearn, daal4py, onedal or sklearnex.

    Each ``sys.path`` entry is canonicalized and compared with the allowed
    directories in ``_TRACE_ALLOW_DICT``:

    * if the entry is an ancestor of (or equal to) an allowed directory, its
      direct children are scanned and every child that is neither an ancestor
      nor a descendant of an allowed directory is blocked;
    * otherwise, if the entry is not inside any allowed directory, the whole
      entry is blocked;
    * an entry that is inside an allowed directory is not blocked.

    An entry for which ``FileNotFoundError`` is raised is blocked outright.

    Returns
    -------
    list of str
        Canonical directories, used as ``ignoredirs`` for ``trace.Trace``.
    """
    allowed_dirs = list(_TRACE_ALLOW_DICT.values())

    def _commonpath(paths):
        # os.path.commonpath raises ValueError when the paths cannot share a
        # prefix (e.g. different drives, or absolute mixed with relative);
        # treat that as "no common path".
        try:
            return os.path.commonpath(paths)
        except ValueError:
            return ""

    blocked = []
    for raw_entry in sys.path:
        entry = _fullpath(raw_entry)
        try:
            is_parent_of_allowed = any(
                _commonpath([allowed, entry]) == entry for allowed in allowed_dirs
            )
            if is_parent_of_allowed:
                # Keep only children that lead to (or live in) an allowed
                # directory; block every sibling branch.
                with os.scandir(entry) as children:
                    for child in children:
                        child_path = _fullpath(child.path)
                        unrelated = all(
                            _commonpath([allowed, child_path]) not in [allowed, child_path]
                            for allowed in allowed_dirs
                        )
                        if unrelated:
                            blocked.append(child_path)
            else:
                is_outside_allowed = all(
                    _commonpath([allowed, entry]) != allowed for allowed in allowed_dirs
                )
                if is_outside_allowed:
                    blocked.append(entry)
        except FileNotFoundError:
            blocked.append(entry)
    return blocked


# Computed once at import time from the current sys.path.
_TRACE_BLOCK_LIST = _whitelist_to_blacklist()
226
+
227
+
228
@pytest.fixture
def estimator_trace(estimator, method, cache, capsys, monkeypatch):
    """Generate a trace of all function calls in calling estimator.method with cache.

    The parsed trace is stored in pytest's ``cache`` under ``"text"``, and the
    ``"<estimator>-<method>"`` identifier is stored under ``"key"``. Design
    rules parametrized over the same estimator/method pair therefore reuse one
    trace instead of re-running the tracer.

    Parameters
    ----------
    estimator : str
        name of estimator which is a key from PATCHED_MODELS or
        SPECIAL_INSTANCES

    method : str
        name of estimator method which is to be traced and stored

    cache: pytest.fixture (standard)

    capsys: pytest.fixture (standard)

    monkeypatch: pytest.fixture (standard)

    Returns
    -------
    dict with keys "funcs", "trace", "modules", "callingline"
        "funcs" is the list of called function names. "trace" is the total
        text output of the trace as a string. "modules" are the module
        locations of the called functions as dotted paths; the sklearn,
        sklearnex, onedal and daal4py directories are replaced by the package
        name, and other directories are excluded via _TRACE_BLOCK_LIST.
        "callingline" is the list of trace lines that precede each
        "--- modulename:" line, with an empty string prepended (presumably so
        it aligns index-for-index with "funcs" -- confirm).
    """
    key = "-".join((str(estimator), method))
    # Only re-trace when the cached trace belongs to a different pair.
    # NOTE(review): pytest's ``cache`` fixture persists between test sessions,
    # so a trace cached by a previous run under the same key is reused as-is;
    # run pytest with ``--cache-clear`` after changing traced code.
    if cache.get("key", "") != key:
        # Get the estimator. PATCHED_MODELS maps names to classes and
        # SPECIAL_INSTANCES maps names to ready-made instances. The name is
        # looked up explicitly so that a KeyError raised *inside* a
        # constructor is not misreported as a missing SPECIAL_INSTANCES entry.
        if estimator in PATCHED_MODELS:
            est = PATCHED_MODELS[estimator]()
        else:
            est = SPECIAL_INSTANCES[estimator]

        # get dataset
        X, y = gen_dataset(est)[0]
        # fit first unless the traced method's name contains 'fit'
        if "fit" not in method:
            est.fit(X, y)

        # Make trace report full file paths as module names (via _fullpath)
        # so they can be mapped back to package names below. This impacts
        # trace's ignoremods handling, but it is not used.
        monkeypatch.setattr(trace, "_modname", _fullpath)
        tracer = trace.Trace(
            count=0,
            trace=1,
            ignoredirs=_TRACE_BLOCK_LIST,
        )
        # call trace on method with dataset; trace lines are printed to stdout
        tracer.runfunc(call_method, est, method, X, y)

        # collect trace for analysis, replacing each allowed package directory
        # with the package name
        text = capsys.readouterr().out
        for modulename, file in _TRACE_ALLOW_DICT.items():
            text = text.replace(file, modulename)
        # these patterns are needed due to differences in module structure
        regex_func = r"(?<=funcname: )\S*(?=\n)"
        regex_mod = r"(?<=--- modulename: )\S*(?=\.py)"
        regex_callingline = r"(?<=\n)\S.*(?=\n --- modulename: )"

        cache.set("key", key)
        cache.set(
            "text",
            {
                "funcs": re.findall(regex_func, text),
                "trace": text,
                "modules": [i.replace(os.sep, ".") for i in re.findall(regex_mod, text)],
                "callingline": [""] + re.findall(regex_callingline, text),
            },
        )

    return cache.get("text", None)
305
+
306
+
307
def call_validate_data(text, estimator, method):
    """Test that sklearn's validate_data is called exactly once before offloading to oneDAL.

    Only the functions traced before the last ``to_table`` call are considered;
    that call marks the end of the oneDAL input portion of the code. Within
    that window, sklearn's data validation (``validate_data`` for sklearn >= 1.6,
    ``_validate_data`` for older versions) and ``_check_feature_names`` must
    each be called exactly once.

    Skips when ``to_table`` never appears in the trace, i.e. the oneDAL
    backend is not used by ``estimator.method``.
    """
    funcs = text["funcs"]
    if "to_table" not in funcs:
        pytest.skip("onedal backend not used in this function")
    # index of the last to_table call
    idx = len(funcs) - 1 - funcs[::-1].index("to_table")
    validfuncs = funcs[:idx]

    validate_data = "validate_data" if sklearn_check_version("1.6") else "_validate_data"

    assert (
        validfuncs.count(validate_data) == 1
    ), f"sklearn's {validate_data} should be called exactly once"
    assert (
        validfuncs.count("_check_feature_names") == 1
    ), "estimator should check feature names in validate_data"
325
+
326
+
327
def n_jobs_check(text, estimator, method):
    """Verify the n_jobs is being set if '_get_backend' or 'to_table' is called.

    The method counts as reaching the backend when the trace contains a
    ``to_table`` call, or a ``_get_backend`` call from a module outside
    sklearnex. The sklearnex-level ``_get_backend`` is not counted. In that
    case ``n_jobs_wrapper`` must appear in the trace. Conversely, it must not
    appear when neither call happens.
    """
    funcs = text["funcs"]
    # modules are only consulted for _get_backend entries
    backend_lookups = sum(
        1
        for idx, name in enumerate(funcs)
        if name == "_get_backend" and "sklearnex" not in text["modules"][idx]
    )
    uses_backend = funcs.count("to_table") > 0 or backend_lookups > 0
    wraps_n_jobs = funcs.count("n_jobs_wrapper") > 0

    assert (
        uses_backend == wraps_n_jobs
    ), f"verify if {method} should be in control_n_jobs' decorated_methods for {estimator}"
346
+
347
+
348
def runtime_property_check(text, estimator, method):
    """Check that Python's ``property`` is not called during the traced method.

    ``property`` should only be used at class instantiation, not at runtime.
    Any ``property(`` occurrence in the raw trace text fails the check.
    """
    property_calls = re.findall(r"property\(", text["trace"])
    assert not property_calls, f"{estimator}.{method} should only use 'property' at instantiation"
353
+
354
+
355
def fit_check_before_support_check(text, estimator, method):
    """Check that ``check_is_fitted`` runs before oneDAL support is checked.

    For methods whose name does not contain ``fit``, sklearn's
    ``check_is_fitted`` must appear in the trace before the last ``dispatch``
    call. Skips fitting methods and traces that contain no ``dispatch`` call.
    """
    if "fit" in method:
        pytest.skip(f"fitting occurs in {estimator}.{method}")

    funcs = text["funcs"]
    if "dispatch" not in funcs:
        pytest.skip(f"onedal dispatching not used in {estimator}.{method}")

    last_dispatch = max(pos for pos, name in enumerate(funcs) if name == "dispatch")
    assert (
        "check_is_fitted" in funcs[:last_dispatch]
    ), "sklearn's check_is_fitted must be called before checking oneDAL support"
367
+
368
+
369
# Design-rule checks applied by test_estimator to every traced
# estimator/method pair; each is called as rule(trace, estimator, method).
DESIGN_RULES = [n_jobs_check, runtime_property_check, fit_check_before_support_check]

# The validate_data rule is only registered when sklearn_check_version("1.0").
if sklearn_check_version("1.0"):
    DESIGN_RULES.append(call_validate_data)
374
+
375
+
376
@pytest.mark.parametrize("design_pattern", DESIGN_RULES)
@pytest.mark.parametrize(
    "estimator, method",
    gen_models_info({**PATCHED_MODELS, **SPECIAL_INSTANCES}, fit=True, daal4py=False),
)
def test_estimator(estimator, method, design_pattern, estimator_trace):
    """Apply one design rule to the trace of ``estimator.method``.

    These tests only apply to sklearnex estimators. Consider a rule failure
    whose ``"<estimator>-<method>-<rule name>"`` key is listed in
    ``_DESIGN_RULE_VIOLATIONS``: it is reported as an expected failure (xfail)
    with the recorded reason. Any other failure is re-raised.
    """
    try:
        design_pattern(estimator_trace, estimator, method)
    except AssertionError:
        violation = "-".join([estimator, method, design_pattern.__name__])
        if violation not in _DESIGN_RULE_VIOLATIONS:
            raise
        pytest.xfail(_DESIGN_RULE_VIOLATIONS[violation])