scikit-learn-intelex 2025.4.0__py313-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (282)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-313-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-313-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +696 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +204 -0
  62. onedal/_onedal_py_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-313-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +175 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +242 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
  70. onedal/basic_statistics/tests/utils.py +50 -0
  71. onedal/cluster/__init__.py +27 -0
  72. onedal/cluster/dbscan.py +105 -0
  73. onedal/cluster/kmeans.py +557 -0
  74. onedal/cluster/kmeans_init.py +112 -0
  75. onedal/cluster/tests/test_dbscan.py +125 -0
  76. onedal/cluster/tests/test_kmeans.py +88 -0
  77. onedal/cluster/tests/test_kmeans_init.py +93 -0
  78. onedal/common/_base.py +38 -0
  79. onedal/common/_estimator_checks.py +47 -0
  80. onedal/common/_mixin.py +62 -0
  81. onedal/common/_policy.py +55 -0
  82. onedal/common/_spmd_policy.py +30 -0
  83. onedal/common/hyperparameters.py +125 -0
  84. onedal/common/tests/test_policy.py +76 -0
  85. onedal/common/tests/test_sycl.py +128 -0
  86. onedal/covariance/__init__.py +20 -0
  87. onedal/covariance/covariance.py +122 -0
  88. onedal/covariance/incremental_covariance.py +161 -0
  89. onedal/covariance/tests/test_covariance.py +50 -0
  90. onedal/covariance/tests/test_incremental_covariance.py +190 -0
  91. onedal/datatypes/__init__.py +19 -0
  92. onedal/datatypes/_data_conversion.py +121 -0
  93. onedal/datatypes/tests/common.py +126 -0
  94. onedal/datatypes/tests/test_data.py +475 -0
  95. onedal/decomposition/__init__.py +20 -0
  96. onedal/decomposition/incremental_pca.py +214 -0
  97. onedal/decomposition/pca.py +186 -0
  98. onedal/decomposition/tests/test_incremental_pca.py +285 -0
  99. onedal/ensemble/__init__.py +29 -0
  100. onedal/ensemble/forest.py +736 -0
  101. onedal/ensemble/tests/test_random_forest.py +97 -0
  102. onedal/linear_model/__init__.py +27 -0
  103. onedal/linear_model/incremental_linear_model.py +292 -0
  104. onedal/linear_model/linear_model.py +325 -0
  105. onedal/linear_model/logistic_regression.py +247 -0
  106. onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
  107. onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
  108. onedal/linear_model/tests/test_linear_regression.py +259 -0
  109. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  110. onedal/linear_model/tests/test_ridge.py +95 -0
  111. onedal/neighbors/__init__.py +19 -0
  112. onedal/neighbors/neighbors.py +763 -0
  113. onedal/neighbors/tests/test_knn_classification.py +49 -0
  114. onedal/primitives/__init__.py +27 -0
  115. onedal/primitives/get_tree.py +25 -0
  116. onedal/primitives/kernel_functions.py +152 -0
  117. onedal/primitives/tests/test_kernel_functions.py +159 -0
  118. onedal/spmd/__init__.py +25 -0
  119. onedal/spmd/_base.py +30 -0
  120. onedal/spmd/basic_statistics/__init__.py +20 -0
  121. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  122. onedal/spmd/basic_statistics/incremental_basic_statistics.py +71 -0
  123. onedal/spmd/cluster/__init__.py +28 -0
  124. onedal/spmd/cluster/dbscan.py +23 -0
  125. onedal/spmd/cluster/kmeans.py +56 -0
  126. onedal/spmd/covariance/__init__.py +20 -0
  127. onedal/spmd/covariance/covariance.py +26 -0
  128. onedal/spmd/covariance/incremental_covariance.py +83 -0
  129. onedal/spmd/decomposition/__init__.py +20 -0
  130. onedal/spmd/decomposition/incremental_pca.py +124 -0
  131. onedal/spmd/decomposition/pca.py +26 -0
  132. onedal/spmd/ensemble/__init__.py +19 -0
  133. onedal/spmd/ensemble/forest.py +28 -0
  134. onedal/spmd/linear_model/__init__.py +21 -0
  135. onedal/spmd/linear_model/incremental_linear_model.py +101 -0
  136. onedal/spmd/linear_model/linear_model.py +30 -0
  137. onedal/spmd/linear_model/logistic_regression.py +38 -0
  138. onedal/spmd/neighbors/__init__.py +19 -0
  139. onedal/spmd/neighbors/neighbors.py +75 -0
  140. onedal/svm/__init__.py +19 -0
  141. onedal/svm/svm.py +556 -0
  142. onedal/svm/tests/test_csr_svm.py +351 -0
  143. onedal/svm/tests/test_nusvc.py +204 -0
  144. onedal/svm/tests/test_nusvr.py +210 -0
  145. onedal/svm/tests/test_svc.py +176 -0
  146. onedal/svm/tests/test_svr.py +243 -0
  147. onedal/tests/test_common.py +57 -0
  148. onedal/tests/utils/_dataframes_support.py +162 -0
  149. onedal/tests/utils/_device_selection.py +102 -0
  150. onedal/utils/__init__.py +49 -0
  151. onedal/utils/_array_api.py +81 -0
  152. onedal/utils/_dpep_helpers.py +56 -0
  153. onedal/utils/tests/test_validation.py +142 -0
  154. onedal/utils/validation.py +464 -0
  155. scikit_learn_intelex-2025.4.0.dist-info/LICENSE.txt +202 -0
  156. scikit_learn_intelex-2025.4.0.dist-info/METADATA +190 -0
  157. scikit_learn_intelex-2025.4.0.dist-info/RECORD +282 -0
  158. scikit_learn_intelex-2025.4.0.dist-info/WHEEL +5 -0
  159. scikit_learn_intelex-2025.4.0.dist-info/top_level.txt +3 -0
  160. sklearnex/__init__.py +66 -0
  161. sklearnex/__main__.py +58 -0
  162. sklearnex/_config.py +116 -0
  163. sklearnex/_device_offload.py +126 -0
  164. sklearnex/_utils.py +177 -0
  165. sklearnex/basic_statistics/__init__.py +20 -0
  166. sklearnex/basic_statistics/basic_statistics.py +261 -0
  167. sklearnex/basic_statistics/incremental_basic_statistics.py +352 -0
  168. sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
  169. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
  170. sklearnex/cluster/__init__.py +20 -0
  171. sklearnex/cluster/dbscan.py +197 -0
  172. sklearnex/cluster/k_means.py +397 -0
  173. sklearnex/cluster/tests/test_dbscan.py +38 -0
  174. sklearnex/cluster/tests/test_kmeans.py +157 -0
  175. sklearnex/conftest.py +82 -0
  176. sklearnex/covariance/__init__.py +19 -0
  177. sklearnex/covariance/incremental_covariance.py +405 -0
  178. sklearnex/covariance/tests/test_incremental_covariance.py +287 -0
  179. sklearnex/decomposition/__init__.py +19 -0
  180. sklearnex/decomposition/pca.py +427 -0
  181. sklearnex/decomposition/tests/test_pca.py +58 -0
  182. sklearnex/dispatcher.py +534 -0
  183. sklearnex/doc/third-party-programs.txt +424 -0
  184. sklearnex/ensemble/__init__.py +29 -0
  185. sklearnex/ensemble/_forest.py +2029 -0
  186. sklearnex/ensemble/tests/test_forest.py +140 -0
  187. sklearnex/glob/__main__.py +72 -0
  188. sklearnex/glob/dispatcher.py +101 -0
  189. sklearnex/linear_model/__init__.py +32 -0
  190. sklearnex/linear_model/coordinate_descent.py +30 -0
  191. sklearnex/linear_model/incremental_linear.py +495 -0
  192. sklearnex/linear_model/incremental_ridge.py +432 -0
  193. sklearnex/linear_model/linear.py +346 -0
  194. sklearnex/linear_model/logistic_regression.py +415 -0
  195. sklearnex/linear_model/ridge.py +390 -0
  196. sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
  197. sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
  198. sklearnex/linear_model/tests/test_linear.py +142 -0
  199. sklearnex/linear_model/tests/test_logreg.py +134 -0
  200. sklearnex/linear_model/tests/test_ridge.py +256 -0
  201. sklearnex/manifold/__init__.py +19 -0
  202. sklearnex/manifold/t_sne.py +26 -0
  203. sklearnex/manifold/tests/test_tsne.py +250 -0
  204. sklearnex/metrics/__init__.py +23 -0
  205. sklearnex/metrics/pairwise.py +22 -0
  206. sklearnex/metrics/ranking.py +20 -0
  207. sklearnex/metrics/tests/test_metrics.py +39 -0
  208. sklearnex/model_selection/__init__.py +21 -0
  209. sklearnex/model_selection/split.py +22 -0
  210. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  211. sklearnex/neighbors/__init__.py +27 -0
  212. sklearnex/neighbors/_lof.py +236 -0
  213. sklearnex/neighbors/common.py +310 -0
  214. sklearnex/neighbors/knn_classification.py +231 -0
  215. sklearnex/neighbors/knn_regression.py +207 -0
  216. sklearnex/neighbors/knn_unsupervised.py +178 -0
  217. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  218. sklearnex/preview/__init__.py +17 -0
  219. sklearnex/preview/covariance/__init__.py +19 -0
  220. sklearnex/preview/covariance/covariance.py +142 -0
  221. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  222. sklearnex/preview/decomposition/__init__.py +19 -0
  223. sklearnex/preview/decomposition/incremental_pca.py +244 -0
  224. sklearnex/preview/decomposition/tests/test_incremental_pca.py +336 -0
  225. sklearnex/spmd/__init__.py +25 -0
  226. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  227. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  228. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  229. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  230. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +306 -0
  231. sklearnex/spmd/cluster/__init__.py +30 -0
  232. sklearnex/spmd/cluster/dbscan.py +50 -0
  233. sklearnex/spmd/cluster/kmeans.py +21 -0
  234. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  235. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +173 -0
  236. sklearnex/spmd/covariance/__init__.py +20 -0
  237. sklearnex/spmd/covariance/covariance.py +21 -0
  238. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  239. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  240. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  241. sklearnex/spmd/decomposition/__init__.py +20 -0
  242. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  243. sklearnex/spmd/decomposition/pca.py +21 -0
  244. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  245. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  246. sklearnex/spmd/ensemble/__init__.py +19 -0
  247. sklearnex/spmd/ensemble/forest.py +71 -0
  248. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  249. sklearnex/spmd/linear_model/__init__.py +21 -0
  250. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  251. sklearnex/spmd/linear_model/linear_model.py +21 -0
  252. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  253. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +331 -0
  254. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  255. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  256. sklearnex/spmd/neighbors/__init__.py +19 -0
  257. sklearnex/spmd/neighbors/neighbors.py +25 -0
  258. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  259. sklearnex/svm/__init__.py +29 -0
  260. sklearnex/svm/_common.py +339 -0
  261. sklearnex/svm/nusvc.py +371 -0
  262. sklearnex/svm/nusvr.py +170 -0
  263. sklearnex/svm/svc.py +399 -0
  264. sklearnex/svm/svr.py +167 -0
  265. sklearnex/svm/tests/test_svm.py +93 -0
  266. sklearnex/tests/test_common.py +491 -0
  267. sklearnex/tests/test_config.py +123 -0
  268. sklearnex/tests/test_hyperparameters.py +43 -0
  269. sklearnex/tests/test_memory_usage.py +347 -0
  270. sklearnex/tests/test_monkeypatch.py +269 -0
  271. sklearnex/tests/test_n_jobs_support.py +108 -0
  272. sklearnex/tests/test_parallel.py +48 -0
  273. sklearnex/tests/test_patching.py +377 -0
  274. sklearnex/tests/test_run_to_run_stability.py +326 -0
  275. sklearnex/tests/utils/__init__.py +48 -0
  276. sklearnex/tests/utils/base.py +436 -0
  277. sklearnex/tests/utils/spmd.py +198 -0
  278. sklearnex/utils/__init__.py +19 -0
  279. sklearnex/utils/_array_api.py +82 -0
  280. sklearnex/utils/parallel.py +59 -0
  281. sklearnex/utils/tests/test_validation.py +238 -0
  282. sklearnex/utils/validation.py +208 -0
@@ -0,0 +1,491 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import importlib.util
18
+ import io
19
+ import os
20
+ import pathlib
21
+ import pkgutil
22
+ import re
23
+ import sys
24
+ import trace
25
+ from contextlib import redirect_stdout
26
+ from multiprocessing import Pipe, Process, get_context
27
+
28
+ import pytest
29
+ from sklearn.utils import all_estimators
30
+
31
+ from daal4py.sklearn._utils import sklearn_check_version
32
+ from onedal.tests.test_common import _check_primitive_usage_ban
33
+ from sklearnex.tests.utils import (
34
+ PATCHED_MODELS,
35
+ SPECIAL_INSTANCES,
36
+ call_method,
37
+ gen_dataset,
38
+ gen_models_info,
39
+ )
40
+
41
# Locations handed to ``_check_primitive_usage_ban`` as ``allowed_locations`` in
# test_target_offload_ban; ``target_offload`` use is tolerated there. Entries are
# file-name / path fragments (the exact matching rule lives in
# onedal.tests.test_common, not shown here).
TARGET_OFFLOAD_ALLOWED_LOCATIONS = [
    "_config.py",
    "_device_offload.py",
    "test",
    "svc.py",
    os.path.join("svm", "_common.py"),
]
48
+
49
# Known, accepted exceptions to the design-rule checks in this file.
# Keys have the form "<estimator>-<method>-<check>", where <estimator> is a
# PATCHED_MODELS / SPECIAL_INSTANCES name and <check> names a check function
# (call_validate_data and n_jobs_check are both defined in this file).
# Values record the reason for the exception.
# NOTE(review): the consumer of this table is not visible in this chunk;
# presumably the parametrized design-rule test skips/xfails matching keys.
_DESIGN_RULE_VIOLATIONS = {
    "PCA-fit_transform-call_validate_data": "calls both 'fit' and 'transform'",
    "IncrementalEmpiricalCovariance-score-call_validate_data": "must call clone of itself",
    "SVC(probability=True)-fit-call_validate_data": "SVC fit can use sklearn estimator",
    "NuSVC(probability=True)-fit-call_validate_data": "NuSVC fit can use sklearn estimator",
    # LogisticRegression: CPU path goes through daal4py (per the reason string)
    "LogisticRegression-score-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression-fit-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression-predict-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression-predict_log_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression-predict_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
    # neighbors estimators: CPU path goes through daal4py inside onedal
    "KNeighborsClassifier-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier-score-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier-predict-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier-predict_proba-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor-score-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor-predict-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors-radius_neighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors-radius_neighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "LocalOutlierFactor-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "LocalOutlierFactor-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "LocalOutlierFactor-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    # the same exceptions for the SPECIAL_INSTANCES variants of these estimators
    "KNeighborsClassifier(algorithm='brute')-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier(algorithm='brute')-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier(algorithm='brute')-score-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier(algorithm='brute')-predict-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier(algorithm='brute')-predict_proba-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsClassifier(algorithm='brute')-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor(algorithm='brute')-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor(algorithm='brute')-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor(algorithm='brute')-score-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor(algorithm='brute')-predict-n_jobs_check": "uses daal4py for cpu in onedal",
    "KNeighborsRegressor(algorithm='brute')-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors(algorithm='brute')-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors(algorithm='brute')-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors(algorithm='brute')-radius_neighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors(algorithm='brute')-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "NearestNeighbors(algorithm='brute')-radius_neighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "LocalOutlierFactor(novelty=True)-fit-n_jobs_check": "uses daal4py for cpu in onedal",
    "LocalOutlierFactor(novelty=True)-kneighbors-n_jobs_check": "uses daal4py for cpu in onedal",
    "LocalOutlierFactor(novelty=True)-kneighbors_graph-n_jobs_check": "uses daal4py for cpu in onedal",
    "LogisticRegression(solver='newton-cg')-score-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression(solver='newton-cg')-fit-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression(solver='newton-cg')-predict-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression(solver='newton-cg')-predict_log_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
    "LogisticRegression(solver='newton-cg')-predict_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
}
103
+
104
+
105
def test_target_offload_ban():
    """Block the use of ``target_offload`` in sklearnex source files.

    Offloading computation to devices via target_offload should only occur
    externally, and not within the architecture of the sklearnex classes.
    This is for clarity, traceability and maintainability. Only the locations
    listed in TARGET_OFFLOAD_ALLOWED_LOCATIONS are exempt.
    """
    # one entry per offending location (joined below for the failure message)
    output = _check_primitive_usage_ban(
        primitive_name="target_offload",
        package="sklearnex",
        allowed_locations=TARGET_OFFLOAD_ALLOWED_LOCATIONS,
    )
    output = "\n".join(output)
    assert output == "", f"target offloading is occurring in: \n{output}"
119
+
120
+
121
def _sklearnex_walk(func):
    """Wrap a ``pkgutil.walk_packages``-like callable so it walks sklearnex.

    The returned generator function forwards all arguments to ``func`` after
    two keyword rewrites: ``prefix="sklearn."`` becomes ``"sklearnex."``, and
    any ``path`` keyword is replaced by the sklearnex package root (two levels
    above this file). Yielded package infos whose dotted ``name`` contains an
    ``spmd`` component are dropped. Nothing runs until the result is iterated.
    """

    def wrap(*args, **kwargs):
        if kwargs.get("prefix") == "sklearn.":
            kwargs["prefix"] = "sklearnex."
        if "path" in kwargs:
            # force the root of the walk to the sklearnex package directory
            sklearnex_root = pathlib.Path(__file__).parent.parent
            kwargs["path"] = [str(sklearnex_root)]
        for info in func(*args, **kwargs):
            if "spmd" in info.name.split("."):
                continue  # spmd packages are excluded from discovery
            yield info

    return wrap
137
+
138
+
139
def test_class_trailing_underscore_ban(monkeypatch):
    """Ban public trailing-underscore attributes on sklearnex estimator classes.

    In sklearn a trailing underscore marks an attribute of a fitted estimator
    instance; sklearnex extends this rule to the classes themselves. Every
    estimator returned by ``all_estimators`` (redirected to sklearnex by
    ``_sklearnex_walk``) is checked, except those whose module path contains
    "preview" or "daal4py". Names with a leading underscore and properties are
    exempt. Properties also occur in sklearn, especially in deprecations, and
    are expected to error if queried when the estimator is not fitted.
    """
    monkeypatch.setattr(pkgutil, "walk_packages", _sklearnex_walk(pkgutil.walk_packages))
    estimators = all_estimators()  # list of (name, class) tuples
    for name, obj in estimators:
        if "preview" in obj.__module__ or "daal4py" in obj.__module__:
            continue
        # Check the cheap name test first so getattr is only evaluated for
        # names that could actually violate the rule.
        offending = [
            attr
            for attr in dir(obj)
            if not attr.startswith("_")
            and attr.endswith("_")
            and not isinstance(getattr(obj, attr), property)
        ]
        assert not offending, (
            f"{name} contains class attributes which have a trailing underscore "
            f"but no leading one: {offending}"
        )
155
+
156
+
157
def test_all_estimators_covered(monkeypatch):
    """Require every sklearnex estimator to be covered by the common tests.

    ``all_estimators`` is redirected to the sklearnex package via
    ``_sklearnex_walk``; it reports classes that inherit sklearn's
    BaseEstimator and have no leading underscore. Each reported class whose
    module path does not contain "preview" must be a base class of some
    PATCHED_MODELS entry or of the class of some SPECIAL_INSTANCES entry.
    The sklearnex.spmd and sklearnex.preview packages are not tested.
    """
    monkeypatch.setattr(pkgutil, "walk_packages", _sklearnex_walk(pkgutil.walk_packages))
    found = all_estimators()  # list of (name, class) tuples
    patched_classes = list(PATCHED_MODELS.values())
    special_classes = [instance.__class__ for instance in SPECIAL_INSTANCES.values()]

    uncovered_estimators = []
    for name, obj in found:
        if "preview" in obj.__module__:
            continue  # preview estimators are exempt
        covered_by_patch = any(issubclass(cls, obj) for cls in patched_classes)
        covered_by_special = any(issubclass(cls, obj) for cls in special_classes)
        if not (covered_by_patch or covered_by_special):
            uncovered_estimators.append(f"{obj.__module__}.{name}")

    assert (
        uncovered_estimators == []
    ), f"{uncovered_estimators} are currently not included"
177
+
178
+
179
def _fullpath(path):
    """Return *path* with ``~`` expanded and symlinks resolved to a real path."""
    expanded = os.path.expanduser(path)
    return os.path.realpath(expanded)
181
+
182
+
183
# Map each package whose code may appear in traces to the canonical directory
# containing its ``__init__`` (resolved through ``_fullpath``).
_TRACE_ALLOW_DICT = dict(
    (package, _fullpath(os.path.dirname(importlib.util.find_spec(package).origin)))
    for package in ("sklearn", "sklearnex", "onedal", "daal4py")
)
187
+
188
+
189
def _whitelist_to_blacklist():
    """Build the list of directories the tracer must ignore.

    Block all standard library, built-in or site packages which are not
    related to sklearn, daal4py, onedal or sklearnex. Every ``sys.path``
    entry is resolved with ``_fullpath`` and compared with the allowed
    package directories in ``_TRACE_ALLOW_DICT``:

    * an entry that is a parent of (or equal to) an allowed directory is not
      blocked itself; each of its children that is neither a parent nor a
      sub-path of an allowed directory is blocked instead;
    * an entry that is not a sub-path of any allowed directory is blocked;
    * an entry that is a sub-path of an allowed directory is left alone;
    * an entry whose directory cannot be scanned (``OSError``) is blocked.

    Returns
    -------
    list of str
        Resolved paths to pass to ``trace.Trace(ignoredirs=...)``.
    """

    def _commonpath(inp):
        # os.path.commonpath raises ValueError e.g. when paths are on separate drives
        try:
            return os.path.commonpath(inp)
        except ValueError:
            return ""

    allowed = tuple(_TRACE_ALLOW_DICT.values())
    blacklist = []
    for path in sys.path:
        fpath = _fullpath(path)
        try:
            if any(_commonpath([i, fpath]) == fpath for i in allowed):
                # fpath is a parent of an allowed directory: block only the
                # children that share no parent/child relation with it
                with os.scandir(fpath) as entries:
                    for entry in entries:
                        temppath = _fullpath(entry.path)
                        if all(_commonpath([i, temppath]) not in (i, temppath) for i in allowed):
                            blacklist.append(temppath)
            elif all(_commonpath([i, fpath]) != i for i in allowed):
                # not a sub-path of anything in the whitelist
                blacklist.append(fpath)
        except OSError:
            # missing, unreadable or non-directory entries (FileNotFoundError,
            # PermissionError, NotADirectoryError, ...) are simply blocked
            blacklist.append(fpath)
    return blacklist
226
+
227
+
228
# Directories ignored by the tracer in sklearnex_trace (passed as ``ignoredirs``
# to trace.Trace); computed once, at import time, from the current sys.path.
_TRACE_BLOCK_LIST = _whitelist_to_blacklist()
229
+
230
+
231
def sklearnex_trace(estimator_name, method_name):
    """Trace every Python call made while running ``estimator.method``.

    Parameters
    ----------
    estimator_name : str
        Key into PATCHED_MODELS (a class, instantiated here) or, failing that,
        SPECIAL_INSTANCES (an instance, used as-is).

    method_name : str
        Name of the estimator method to run under the tracer.

    Returns
    -------
    text: str
        Captured stdout of ``trace.Trace.runfunc``. While tracing,
        ``trace._modname`` is replaced by ``_fullpath`` so module names in the
        output are resolved file paths; the original is always restored.
    """
    if estimator_name in PATCHED_MODELS:
        est = PATCHED_MODELS[estimator_name]()
    else:
        est = SPECIAL_INSTANCES[estimator_name]

    X, y = gen_dataset(est)[0]
    # methods that are not part of the fit family need a fitted estimator
    if "fit" not in method_name:
        est.fit(X, y)

    saved_modname = trace._modname
    try:
        # more verbose module naming; this also affects ignoremods, which is unused
        trace._modname = _fullpath
        tracer = trace.Trace(count=0, trace=1, ignoredirs=_TRACE_BLOCK_LIST)
        with redirect_stdout(io.StringIO()) as captured:
            tracer.runfunc(call_method, est, method_name, X, y)
        return captured.getvalue()
    finally:
        trace._modname = saved_modname
280
+
281
+
282
def _trace_daemon(pipe):
    """Serve trace requests received over *pipe* until a falsy value arrives.

    Each received key is unpacked into ``sklearnex_trace`` and the resulting
    text is sent back. Any exception, including BaseException subclasses, is
    swallowed and answered with ``""``, so every request gets a reply and the
    serving process keeps running.
    """
    while True:
        key = pipe.recv()
        if not key:
            break  # False / empty key is the shutdown signal
        try:
            reply = sklearnex_trace(*key)
        except BaseException:
            # deliberately catch everything: the parent is blocked waiting on a reply
            reply = ""
        pipe.send(reply)
296
+
297
+
298
class _FakePipe:
    """In-process stand-in for a multiprocessing pipe end, for test development.

    ``send`` runs ``sklearnex_trace`` right away in the current process and
    stores its text; ``recv`` returns the stored text. This allows running
    sklearnex_trace in the parent process.
    """

    # text from the most recent send(); empty until the first send
    _text = ""

    def send(self, key):
        traced = sklearnex_trace(*key)
        self._text = traced

    def recv(self):
        latest = self._text
        return latest
309
+
310
+
311
@pytest.fixture(scope="module")
def isolated_trace():
    """Generates a separate python process for isolated sklearnex traces.

    It is a module scope fixture due to the overhead of importing all the
    various dependencies and is done once before all the various tests.
    Each test will first check a cached value, if not existent it will have
    the waiting child process generate the trace and return the text for
    caching on its behalf. The isolated process is stopped at test teardown.

    Setup failures (context, pipe or process creation, ``start``) propagate
    unchanged. Teardown only signals and joins the child if it was actually
    started, so it never masks the original error.

    Yields
    -------
    pipe_parent: multiprocessing.Connection
        one end of a duplex pipe to be used by other pytest fixtures for
        communicating with the special isolated tracing python instance
        for sklearnex estimators.
    """
    # To debug traces in-process, yield _FakePipe() instead of the real pipe.
    # force use of 'spawn' to guarantee a clean python environment
    # from possible coverage arc tracing
    ctx = get_context("spawn")
    pipe_parent, pipe_child = ctx.Pipe()
    p = ctx.Process(target=_trace_daemon, args=(pipe_child,), daemon=True)
    started = False
    try:
        p.start()
        started = True
        yield pipe_parent
    finally:
        if started:
            try:
                # passing False terminates _trace_daemon's loop
                pipe_parent.send(False)
            except OSError:
                # best effort: the child may already be gone; still clean up below
                pass
        pipe_parent.close()
        pipe_child.close()
        if started:
            p.join()  # only a started process can be joined
        p.close()
345
+
346
+
347
@pytest.fixture
def estimator_trace(estimator, method, cache, isolated_trace):
    """Create cache of all function calls in calling estimator.method.

    Parameters
    ----------
    estimator : str
        name of estimator which is a key from PATCHED_MODELS or SPECIAL_INSTANCES

    method : str
        name of estimator method which is to be traced and stored

    cache: pytest.fixture (standard)

    isolated_trace: pytest.fixture (test_common.py)

    Returns
    -------
    dict with keys "funcs", "trace", "modules", "callingline"
        "funcs" is the list of called function names. "trace" is the full
        text output of the trace, with the sklearn, sklearnex, onedal and
        daal4py directories replaced by the package names. "modules" are the
        module locations of the called functions in dotted form (expected to
        be from daal4py, onedal, sklearn or sklearnex). "callingline" holds
        the line that triggered each call, with a leading "" (presumably so
        index i lines up with "funcs"[i], since the first traced call has no
        preceding line).
    """
    key = "-".join((str(estimator), method))
    if cache.get("key", "") == key:
        # The previous request had the same estimator/method: reuse its trace.
        # NOTE(review): pytest's cache persists on disk between sessions, so a
        # key left over from an earlier run can also hit here.
        return cache.get("text", None)

    isolated_trace.send((estimator, method))
    text = isolated_trace.recv()
    # if tracing does not function in isolated_trace, run it in parent process and error
    if text == "":
        sklearnex_trace(estimator, method)
        # guarantee failure if intermittent
        assert text, f"sklearnex_trace failure for {estimator}.{method}"

    # shorten the allowed package directories to their package names
    for modulename, file in _TRACE_ALLOW_DICT.items():
        text = text.replace(file, modulename)
    regex_func = (
        r"(?<=funcname: )\S*(?=\n)"  # needed due to differences in module structure
    )
    regex_mod = r"(?<=--- modulename: )\S*(?=\.py)"  # needed due to differences in module structure

    regex_callingline = r"(?<=\n)\S.*(?=\n --- modulename: )"

    result = {
        "funcs": re.findall(regex_func, text),
        "trace": text,
        "modules": [i.replace(os.sep, ".") for i in re.findall(regex_mod, text)],
        "callingline": [""] + re.findall(regex_callingline, text),
    }
    # Write the text before the key so a stored key always refers to a complete
    # entry, and return the freshly built dict rather than re-reading the shared
    # on-disk cache, which another process may have overwritten meanwhile.
    cache.set("text", result)
    cache.set("key", key)
    return result
406
+
407
+
408
def call_validate_data(text, estimator, method):
    """Test that sklearn's validate_data is called exactly once before
    offloading to oneDAL in sklearnex.

    Only the calls made before the last ``to_table`` call (the end of the
    oneDAL input portion of the code) are inspected. The test is skipped when
    ``to_table`` is never called.

    Parameters
    ----------
    text : dict
        Trace dictionary as passed in by test_estimator; only its "funcs"
        list is read here.
    estimator : str
        Estimator name, used in failure messages.
    method : str
        Method name, used in failure messages.
    """
    funcs = text["funcs"]
    try:
        # index of the last to_table call
        idx = len(funcs) - 1 - funcs[::-1].index("to_table")
    except ValueError:
        pytest.skip("onedal backend not used in this function")
    validfuncs = funcs[:idx]

    # The traced function name depends on the sklearn version.
    validate_data = "validate_data" if sklearn_check_version("1.6") else "_validate_data"

    n_validate = validfuncs.count(validate_data)
    assert n_validate == 1, (
        f"sklearn's {validate_data} should be called exactly once before "
        f"oneDAL offloading in {estimator}.{method} (called {n_validate} times)"
    )

    n_feature_checks = validfuncs.count("_check_feature_names")
    assert n_feature_checks == 1, (
        f"{estimator}.{method} should check feature names once in "
        f"{validate_data} (called {n_feature_checks} times)"
    )
426
+
427
+
428
def n_jobs_check(text, estimator, method):
    """Verify that n_jobs handling accompanies any use of the oneDAL backend.

    Backend use is detected either by ``to_table`` calls or by ``_get_backend``
    calls whose module is not part of sklearnex (the sklearnex ``_get_backend``
    is ignored). The check passes only when backend use and the presence of
    ``n_jobs_wrapper`` in the trace agree: both present or both absent.
    """
    funcs = text["funcs"]
    modules = text["modules"]

    external_backend_calls = sum(
        1
        for position, name in enumerate(funcs)
        if name == "_get_backend" and "sklearnex" not in modules[position]
    )
    backend_used = max(funcs.count("to_table"), external_backend_calls) > 0
    n_jobs_wrapped = funcs.count("n_jobs_wrapper") > 0

    assert backend_used == n_jobs_wrapped, (
        f"verify if {method} should be in control_n_jobs' decorated_methods "
        f"for {estimator}"
    )
447
+
448
+
449
def runtime_property_check(text, estimator, method):
    """Fail when ``property(`` appears anywhere in the traced call text.

    Python's ``property`` is meant to be used while classes are defined, not
    at runtime while the estimator method executes.
    """
    runtime_use = re.search(r"property\(", text["trace"])
    assert runtime_use is None, (
        f"{estimator}.{method} should only use 'property' at instantiation"
    )
454
+
455
+
456
def fit_check_before_support_check(text, estimator, method):
    """Check that non-fitting methods call check_is_fitted before dispatch.

    Methods whose name contains "fit" are skipped. Otherwise, the calls made
    before the last ``dispatch`` (the oneDAL support check) must include
    sklearn's ``check_is_fitted``. The test is skipped when ``dispatch`` is
    never called.

    Parameters
    ----------
    text : dict
        Trace dictionary; only its "funcs" list is read here.
    estimator : str
        Estimator name, used in skip/failure messages.
    method : str
        Method name, used in skip/failure messages.
    """
    if "fit" in method:
        pytest.skip(f"fitting occurs in {estimator}.{method}")

    funcs = text["funcs"]
    if "dispatch" not in funcs:
        pytest.skip(f"onedal dispatching not used in {estimator}.{method}")

    # index of the last dispatch call
    idx = len(funcs) - 1 - funcs[::-1].index("dispatch")
    assert "check_is_fitted" in funcs[:idx], (
        "sklearn's check_is_fitted must be called before checking oneDAL "
        f"support in {estimator}.{method}"
    )
468
+
469
+
470
# Design-rule checks applied by test_estimator to every traced estimator method.
DESIGN_RULES = [n_jobs_check, runtime_property_check, fit_check_before_support_check]

# The validate_data rule is only registered when sklearn_check_version("1.0")
# holds.
if sklearn_check_version("1.0"):
    DESIGN_RULES.append(call_validate_data)
475
+
476
+
477
@pytest.mark.parametrize("design_pattern", DESIGN_RULES)
@pytest.mark.parametrize(
    "estimator, method",
    gen_models_info({**PATCHED_MODELS, **SPECIAL_INSTANCES}, fit=True, daal4py=False),
)
def test_estimator(estimator, method, design_pattern, estimator_trace):
    """Apply one design rule to the call trace of one sklearnex estimator method.

    A rule failure whose "<estimator>-<method>-<rule name>" key is listed in
    _DESIGN_RULE_VIOLATIONS is reported as an expected failure (with the listed
    reason); every other failure propagates unchanged.
    """
    # These tests only apply to sklearnex estimators
    try:
        design_pattern(estimator_trace, estimator, method)
    except AssertionError:
        violation = "-".join([estimator, method, design_pattern.__name__])
        if violation not in _DESIGN_RULE_VIOLATIONS:
            raise
        pytest.xfail(_DESIGN_RULE_VIOLATIONS[violation])
@@ -0,0 +1,123 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import sklearn
18
+
19
+ import onedal
20
+ import sklearnex
21
+
22
+
23
def test_get_config_contains_sklearn_params():
    """Every key of sklearn's global config must also be in sklearnex's config."""
    skex_config = sklearnex.get_config()
    sk_config = sklearn.get_config()

    # Collect the offending keys so a failure says which ones are missing.
    missing = [key for key in sk_config if key not in skex_config]
    assert not missing, f"sklearnex config is missing sklearn parameters: {missing}"
28
+
29
+
30
def test_set_config_works():
    """Test validates that the config settings were applied correctly by
    set_config, both in sklearnex and in the onedal backend config.
    """
    # Snapshot the current sklearnex configuration so it can be restored.
    default_config = sklearnex.get_config()

    # These variables define the new configuration settings
    # that will be tested.
    assume_finite = True
    target_offload = "cpu:0"
    allow_fallback_to_host = True
    allow_sklearn_after_onedal = False

    # Everything that can fail after the snapshot is inside try/finally:
    # applying the settings, reading them back, and the assertions. This
    # guarantees the modified settings never leak into later tests.
    try:
        sklearnex.set_config(
            assume_finite=assume_finite,
            target_offload=target_offload,
            allow_fallback_to_host=allow_fallback_to_host,
            allow_sklearn_after_onedal=allow_sklearn_after_onedal,
        )

        config = sklearnex.get_config()
        onedal_config = onedal._config._get_config()

        assert config["target_offload"] == target_offload
        assert config["allow_fallback_to_host"] == allow_fallback_to_host
        assert config["allow_sklearn_after_onedal"] == allow_sklearn_after_onedal
        assert config["assume_finite"] == assume_finite
        assert onedal_config["target_offload"] == target_offload
        assert onedal_config["allow_fallback_to_host"] == allow_fallback_to_host
    finally:
        # Restore the original configuration whether or not the checks passed.
        sklearnex.set_config(**default_config)
70
+
71
+
72
def test_config_context_works():
    """Check that config_context applies settings only within its scope.

    Inside four nested config_context blocks (one option each), the sklearnex
    config and the onedal backend config must report the new values. After
    all blocks have exited, the values observed before entering them must be
    back.
    """
    from sklearnex import config_context, get_config

    sklearnex_before = get_config()
    onedal_before = onedal._config._get_config()

    assume_finite = True
    target_offload = "cpu:0"
    allow_fallback_to_host = True
    allow_sklearn_after_onedal = False

    # Parameters checked on the sklearnex config, with the value each must
    # take inside the nested scopes.
    expected = {
        "target_offload": target_offload,
        "allow_fallback_to_host": allow_fallback_to_host,
        "allow_sklearn_after_onedal": allow_sklearn_after_onedal,
        "assume_finite": assume_finite,
    }
    # Subset of the parameters that the onedal backend config also exposes.
    onedal_params = ("target_offload", "allow_fallback_to_host")

    # Each nested config_context layers one more setting on top of the others.
    with config_context(assume_finite=assume_finite):
        with config_context(target_offload=target_offload):
            with config_context(allow_fallback_to_host=allow_fallback_to_host):
                with config_context(
                    allow_sklearn_after_onedal=allow_sklearn_after_onedal
                ):
                    inner_config = sklearnex.get_config()
                    inner_onedal_config = onedal._config._get_config()

                    for name, wanted in expected.items():
                        assert inner_config[name] == wanted
                    for name in onedal_params:
                        assert inner_onedal_config[name] == expected[name]

    # Once outside every context manager, the prior settings must be in effect.
    sklearnex_after = get_config()
    onedal_after = onedal._config._get_config()
    for name in expected:
        assert sklearnex_after[name] == sklearnex_before[name]
    for name in onedal_params:
        assert onedal_after[name] == onedal_before[name]
@@ -0,0 +1,43 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+
18
+ import pytest
19
+
20
+ from sklearnex._utils import register_hyperparameters
21
+
22
+
23
def test_register_hyperparameters():
    """A class decorated with register_hyperparameters returns the mapped
    value from the class-level get_hyperparameters call."""

    @register_hyperparameters({"op": "hyperparameters"})
    class Test:
        pass

    retrieved = Test.get_hyperparameters("op")
    assert retrieved == "hyperparameters"
32
+
33
+
34
def test_register_hyperparameters_issues_warning():
    """Creating an instance and calling get_hyperparameters on it (rather than
    on the class) must emit a warning."""

    @register_hyperparameters({"op": "hyperparameters"})
    class Test:
        pass

    # Construction stays inside pytest.warns so a warning raised while
    # creating the instance is captured as well.
    with pytest.warns(Warning):
        instance = Test()
        instance.get_hyperparameters("op")