scikit-learn-intelex 2024.0.1__py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (90)
  1. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/__init__.py +61 -0
  2. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/__main__.py +59 -0
  3. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_config.py +110 -0
  4. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_device_offload.py +223 -0
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  7. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +17 -0
  8. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +21 -0
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +18 -0
  11. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +37 -0
  12. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +31 -0
  13. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +20 -0
  14. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +18 -0
  15. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +28 -0
  16. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/dispatcher.py +329 -0
  17. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +424 -0
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +30 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  20. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  21. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/glob/__main__.py +73 -0
  22. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +88 -0
  23. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +30 -0
  24. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +18 -0
  25. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +373 -0
  26. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +18 -0
  27. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +18 -0
  28. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +77 -0
  29. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +29 -0
  30. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +20 -0
  31. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +18 -0
  32. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +27 -0
  33. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +24 -0
  34. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +18 -0
  35. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +18 -0
  36. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +40 -0
  37. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +22 -0
  38. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/split.py +18 -0
  39. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +35 -0
  40. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +28 -0
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/common.py +264 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  43. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  44. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +220 -0
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +437 -0
  46. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  47. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +18 -0
  48. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +20 -0
  49. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +84 -0
  50. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +370 -0
  51. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +20 -0
  52. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +376 -0
  53. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +38 -0
  54. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +24 -0
  55. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +19 -0
  56. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
  58. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  59. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
  60. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +19 -0
  61. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +21 -0
  62. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
  63. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +79 -0
  64. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +19 -0
  65. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +21 -0
  66. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +19 -0
  67. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +25 -0
  68. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/__init__.py +30 -0
  69. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/_common.py +188 -0
  70. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +272 -0
  71. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +163 -0
  72. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/svc.py +301 -0
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/svr.py +164 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  75. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  76. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_config.py +39 -0
  77. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +225 -0
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +210 -0
  79. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +122 -0
  81. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  82. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +118 -0
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
  84. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  85. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/validation.py +18 -0
  86. scikit_learn_intelex-2024.0.1.dist-info/LICENSE.txt +202 -0
  87. scikit_learn_intelex-2024.0.1.dist-info/METADATA +230 -0
  88. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  89. scikit_learn_intelex-2024.0.1.dist-info/WHEEL +5 -0
  90. scikit_learn_intelex-2024.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,210 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import sklearnex
18
+ from daal4py.sklearn._utils import daal_check_version
19
+
20
+
21
def test_monkey_patching():
    """Patch and unpatch every name returned by ``sklearnex.get_patch_names()``.

    Four passes are made over the patch names:

    1. patch everything at once and check each attribute now comes from
       ``daal4py`` or ``sklearnex``;
    2. unpatch one name at a time and check the attribute comes from
       ``sklearn`` again;
    3. unpatch everything at once and re-check;
    4. patch one name at a time and check again.

    ``unpatch_sklearn()`` runs in ``finally`` so that a failing assertion does
    not leave scikit-learn patched for the tests that run after this one.
    """
    _tokens = sklearnex.get_patch_names()
    _values = sklearnex.get_patch_map().values()
    _classes = []

    # Flatten the patch map: every descriptor ``c`` contributes ``c[0]``. Its
    # first two items are the object owning the patched attribute and the
    # attribute name (see the ``getattr(p, n)`` calls below).
    # NOTE(review): ``_tokens[i]`` is paired with ``_classes[i]`` by position,
    # which is only exact when every patch-map entry holds a single descriptor.
    for v in _values:
        for c in v:
            _classes.append(c[0])

    try:
        sklearnex.patch_sklearn()

        for i, _ in enumerate(_tokens):
            p = _classes[i][0]
            n = _classes[i][1]

            class_module = getattr(p, n).__module__
            assert class_module.startswith(
                ("daal4py", "sklearnex")
            ), "Patching has completed with error."

        for i, t in enumerate(_tokens):
            p = _classes[i][0]
            n = _classes[i][1]

            sklearnex.unpatch_sklearn(t)
            class_module = getattr(p, n).__module__
            assert class_module.startswith("sklearn"), "Unpatching has completed with error."

        sklearnex.unpatch_sklearn()

        for i, _ in enumerate(_tokens):
            p = _classes[i][0]
            n = _classes[i][1]

            class_module = getattr(p, n).__module__
            assert class_module.startswith("sklearn"), "Unpatching has completed with error."

        sklearnex.unpatch_sklearn()

        for i, t in enumerate(_tokens):
            p = _classes[i][0]
            n = _classes[i][1]

            sklearnex.patch_sklearn(t)

            class_module = getattr(p, n).__module__
            assert class_module.startswith(
                ("daal4py", "sklearnex")
            ), "Patching has completed with error."
    finally:
        # Always restore stock scikit-learn, even when an assertion failed.
        sklearnex.unpatch_sklearn()
76
+
77
+
78
def test_patch_by_list_simple():
    """Patching a one-element list replaces only that estimator.

    ``unpatch_sklearn()`` runs in ``finally`` so a failing assertion does not
    leave scikit-learn patched for later tests.
    """
    try:
        sklearnex.patch_sklearn(["LogisticRegression"])

        # Imported after patching so the names resolve to whatever the sklearn
        # modules expose in the current patch state.
        from sklearn.ensemble import RandomForestRegressor
        from sklearn.linear_model import LogisticRegression
        from sklearn.neighbors import KNeighborsRegressor
        from sklearn.svm import SVC

        assert RandomForestRegressor.__module__.startswith("sklearn")
        assert KNeighborsRegressor.__module__.startswith("sklearn")
        assert LogisticRegression.__module__.startswith("daal4py")
        assert SVC.__module__.startswith("sklearn")
    finally:
        sklearnex.unpatch_sklearn()
92
+
93
+
94
def test_patch_by_list_many_estimators():
    """Patching a list of names replaces exactly the listed estimators.

    ``unpatch_sklearn()`` runs in ``finally`` so a failing assertion does not
    leave scikit-learn patched for later tests.
    """
    try:
        sklearnex.patch_sklearn(["LogisticRegression", "SVC"])

        # Imported after patching so the names resolve to whatever the sklearn
        # modules expose in the current patch state.
        from sklearn.ensemble import RandomForestRegressor
        from sklearn.linear_model import LogisticRegression
        from sklearn.neighbors import KNeighborsRegressor
        from sklearn.svm import SVC

        assert RandomForestRegressor.__module__.startswith("sklearn")
        assert KNeighborsRegressor.__module__.startswith("sklearn")
        assert LogisticRegression.__module__.startswith("daal4py")
        assert SVC.__module__.startswith(("daal4py", "sklearnex"))
    finally:
        sklearnex.unpatch_sklearn()
108
+
109
+
110
def test_unpatch_by_list_many_estimators():
    """Unpatching a list of names restores only the listed estimators.

    After the partial unpatch, LogisticRegression and SVC are still patched.
    ``unpatch_sklearn()`` in ``finally`` removes them so they do not leak into
    later tests, and also cleans up if an assertion fails part-way.
    """
    try:
        sklearnex.patch_sklearn()

        from sklearn.ensemble import RandomForestRegressor
        from sklearn.linear_model import LogisticRegression
        from sklearn.neighbors import KNeighborsRegressor
        from sklearn.svm import SVC

        assert RandomForestRegressor.__module__.startswith("sklearnex")
        assert KNeighborsRegressor.__module__.startswith(("daal4py", "sklearnex"))
        assert LogisticRegression.__module__.startswith("daal4py")
        assert SVC.__module__.startswith(("daal4py", "sklearnex"))

        sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])

        # Re-import so the local names pick up the attributes restored by the
        # partial unpatch above.
        from sklearn.ensemble import RandomForestRegressor
        from sklearn.linear_model import LogisticRegression
        from sklearn.neighbors import KNeighborsRegressor
        from sklearn.svm import SVC

        assert RandomForestRegressor.__module__.startswith("sklearn")
        assert KNeighborsRegressor.__module__.startswith("sklearn")
        assert LogisticRegression.__module__.startswith("daal4py")
        assert SVC.__module__.startswith(("daal4py", "sklearnex"))
    finally:
        sklearnex.unpatch_sklearn()
136
+
137
+
138
def test_patching_checker():
    """``sklearn_is_patched`` must track patch/unpatch calls.

    The check is done for all estimators (``name=None``) and for single names.
    It is then repeated through the per-name status map returned with
    ``return_map=True``. ``unpatch_sklearn()`` in ``finally`` keeps a failure
    from leaking patching into later tests.
    """
    try:
        for name in [None, "SVC", "PCA"]:
            sklearnex.patch_sklearn(name=name)
            assert sklearnex.sklearn_is_patched(name=name)

            sklearnex.unpatch_sklearn(name=name)
            assert not sklearnex.sklearn_is_patched(name=name)

        sklearnex.patch_sklearn()
        patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
        # One status entry is expected per patchable name.
        assert len(patching_status_map) == len(sklearnex.get_patch_names())
        for status in patching_status_map.values():
            assert status

        sklearnex.unpatch_sklearn()
        patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
        assert len(patching_status_map) == len(sklearnex.get_patch_names())
        for status in patching_status_map.values():
            assert not status
    finally:
        sklearnex.unpatch_sklearn()
157
+
158
+
159
def test_preview_namespace():
    """Check which implementation each estimator resolves to.

    Three states are covered: patched with ``preview=True``, unpatched, and
    patched with the default settings. ``unpatch_sklearn()`` in ``finally``
    restores stock scikit-learn even if an assertion fails while preview
    patching is active.
    """

    def get_estimators():
        # Imported on every call so the classes reflect the current patch state.
        from sklearn.cluster import DBSCAN
        from sklearn.decomposition import PCA
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.linear_model import LinearRegression
        from sklearn.svm import SVC

        return LinearRegression(), PCA(), DBSCAN(), SVC(), RandomForestClassifier()

    try:
        # Defensive reset: earlier patching tests may have left PCA patched
        # with daal4py, which would break the preview expectations below.
        sklearnex.unpatch_sklearn()
        # behavior with enabled preview
        sklearnex.patch_sklearn(preview=True)
        assert sklearnex.dispatcher._is_preview_enabled()

        lr, pca, dbscan, svc, rfc = get_estimators()
        assert "sklearnex" in rfc.__module__

        if daal_check_version((2023, "P", 100)):
            assert "sklearnex" in lr.__module__
        else:
            assert "daal4py" in lr.__module__

        assert "sklearnex.preview" in pca.__module__
        assert "sklearnex" in dbscan.__module__
        assert "sklearnex" in svc.__module__
        sklearnex.unpatch_sklearn()

        # no patching behavior
        lr, pca, dbscan, svc, rfc = get_estimators()
        assert "sklearn." in lr.__module__
        assert "sklearn." in pca.__module__
        assert "sklearn." in dbscan.__module__
        assert "sklearn." in svc.__module__
        assert "sklearn." in rfc.__module__

        # default patching behavior
        sklearnex.patch_sklearn()
        assert not sklearnex.dispatcher._is_preview_enabled()

        lr, pca, dbscan, svc, rfc = get_estimators()
        if daal_check_version((2023, "P", 100)):
            assert "sklearnex" in lr.__module__
        else:
            assert "daal4py" in lr.__module__
        assert "daal4py" in pca.__module__
        assert "sklearnex" in rfc.__module__
        assert "sklearnex" in dbscan.__module__
        assert "sklearnex" in svc.__module__
    finally:
        sklearnex.unpatch_sklearn()
@@ -0,0 +1,50 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ import pytest
17
+
18
+ from sklearnex import config_context, patch_sklearn
19
+
20
+ patch_sklearn()
21
+
22
+ from sklearn.datasets import make_classification
23
+ from sklearn.ensemble import BaggingClassifier
24
+ from sklearn.svm import SVC
25
+
26
# ``dpctl`` is optional. Record whether it is importable and whether it reports
# a GPU, for the ``skipif`` condition of the test in this module. Both flags
# are bound on every path, so reading ``gpu_is_available`` never raises
# NameError even when dpctl is missing.
try:
    import dpctl

    dpctl_is_available = True
    gpu_is_available = dpctl.has_gpu_devices()
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    dpctl_is_available = False
    gpu_is_available = False
33
+
34
+
35
@pytest.mark.skipif(
    not dpctl_is_available or gpu_is_available,
    reason="GPU device should not be available for this test "
    "to see raised 'SyclQueueCreationError'. "
    "'dpctl' module is required for test.",
)
def test_config_context_in_parallel():
    """``config_context`` settings must also apply to a parallel (``n_jobs=2``) fit.

    On a host without a GPU, requesting ``target_offload="gpu"`` with host
    fallback disabled is expected to raise ``SyclQueueCreationError``.
    """
    x, y = make_classification(random_state=42)
    # pytest.raises fails the test ("DID NOT RAISE") when nothing is raised and
    # lets any other exception propagate unchanged.
    with pytest.raises(dpctl._sycl_queue.SyclQueueCreationError):
        with config_context(target_offload="gpu", allow_fallback_to_host=False):
            BaggingClassifier(SVC(), n_jobs=2).fit(x, y)
@@ -0,0 +1,122 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import os
18
+ import pathlib
19
+ import re
20
+ import subprocess
21
+ import sys
22
+ from inspect import isclass
23
+
24
+ import pytest
25
+ from _models_info import TO_SKIP
26
+ from sklearn.base import BaseEstimator
27
+
28
+ from sklearnex import get_patch_map, is_patched_instance, patch_sklearn, unpatch_sklearn
29
+
30
+
31
def get_branch(s):
    """Classify which code path a group of verbose log lines reports.

    Returns:
        ``"NO INFO"`` for an empty group;
        ``"was in OPT, but go in Scikit"`` if any line reports a fallback
        after a failed accelerated run (this check wins over the next one);
        ``"OPT"`` if any line reports the accelerated version running;
        ``"Scikit"`` otherwise.
    """
    if not s:
        return "NO INFO"
    fallback_marker = "failed to run accelerated version, fallback to original Scikit-learn"
    if any(fallback_marker in line for line in s):
        return "was in OPT, but go in Scikit"
    if any("running accelerated version" in line for line in s):
        return "OPT"
    return "Scikit"
41
+
42
+
43
def run_parse(mas, result):
    """Fold one captured log group into ``result``.

    ``mas[0]`` is the group header and must split into exactly two
    whitespace-separated fields, ``name`` and ``dtype``. Every following line
    has its leading ``'SKLEARNEX INFO: '`` prefix sliced off, and ``mas`` is
    updated in place with the stripped text. Stripped lines starting with
    ``"sklearn"`` are buffered. Any other line closes the buffer: it stores
    ``get_branch(buffer)`` under the key ``"<name> <dtype> <line>"`` and then
    empties the buffer. Lines still buffered after the last closing line are
    dropped.
    """
    prefix_len = len("SKLEARNEX INFO: ")  # == 16
    name, dtype = mas[0].split()
    pending = []
    for pos in range(1, len(mas)):
        line = mas[pos][prefix_len:]
        mas[pos] = line
        if line.startswith("sklearn"):
            pending.append(line)
            continue
        result[f"{name} {dtype} {line}"] = get_branch(pending)
        pending.clear()
55
+
56
+
57
def get_result_log():
    """Run ``utils/_launch_algorithms.py`` with INFO verbosity and parse its log.

    The script runs in a child interpreter whose environment has
    ``SKLEARNEX_VERBOSE=INFO``. Its stdout is split into groups, each a
    non-INFO header line followed by ``SKLEARNEX INFO`` lines, and each group
    is folded into the returned dict by ``run_parse``. ``SKLEARNEX WARNING``
    lines are ignored.

    Raises:
        RuntimeError: if the child process exits with a non-zero status. The
            child's captured stdout is included in the message.
    """
    script = pathlib.Path(__file__).parent.absolute() / "utils" / "_launch_algorithms.py"
    # Give the verbosity to the child via its environment instead of setting
    # and later deleting os.environ in this process. That approach overwrote
    # and then removed any value the user had set, and left the variable set
    # if the run failed.
    env = dict(os.environ, SKLEARNEX_VERBOSE="INFO")
    try:
        output = subprocess.check_output([sys.executable, str(script)], env=env)
    except subprocess.CalledProcessError as e:
        child_output = (e.output or b"").decode(errors="replace")
        raise RuntimeError(
            f"{script} exited with status {e.returncode}:\n{child_output}"
        ) from e

    mas = []
    result = {}
    for line in output.decode().split("\n"):
        if line.startswith("SKLEARNEX WARNING"):
            continue
        # A non-INFO line starts a new group: flush the collected one first.
        if not line.startswith("SKLEARNEX INFO") and mas:
            run_parse(mas, result)
            mas.clear()
        mas.append(line.strip())
    return result
80
+
81
+
82
# Run the launcher script once at import (collection) time. The resulting
# "<name> <dtype> <line>" -> branch-label mapping parametrizes test_patching.
result_log = get_result_log()
83
+
84
+
85
@pytest.mark.parametrize("configuration", result_log)
def test_patching(configuration):
    """Fail unless the logged configuration took an accelerated branch.

    A configuration whose branch label contains ``"OPT"`` passes. Otherwise
    the test is skipped if the configuration matches any regex in ``TO_SKIP``,
    and fails with ``ValueError`` if it matches none.
    """
    branch_label = result_log[configuration]
    # NOTE(review): this is a substring test, so the fallback label
    # "was in OPT, but go in Scikit" (from get_branch) also passes here;
    # confirm that leniency is intended.
    if "OPT" in branch_label:
        return
    if any(re.search(pattern, configuration) is not None for pattern in TO_SKIP):
        pytest.skip("SKIPPED", allow_module_level=False)
    raise ValueError("Test patching failed: " + configuration)
93
+
94
+
95
def _load_all_models(patched):
    """Instantiate the estimator class named by each patch-map entry.

    For every entry of ``get_patch_map()``, the first descriptor's owner
    object and attribute name are looked up. If the attribute is a
    ``BaseEstimator`` subclass, an instance built with default arguments is
    collected.

    Parameters:
        patched: when true, the lookup happens while scikit-learn is patched.
            Patching is always removed again before returning, even if
            instantiating a model raises.

    Returns:
        list of estimator instances, in patch-map iteration order.
    """
    if patched:
        patch_sklearn()

    try:
        models = []
        for patch_infos in get_patch_map().values():
            owner, attr_name = patch_infos[0][0][0], patch_infos[0][0][1]
            # Default to None so an attribute absent from its owner is skipped
            # instead of raising AttributeError; the ``is not None`` test below
            # depends on this.
            maybe_class = getattr(owner, attr_name, None)
            if (
                maybe_class is not None
                and isclass(maybe_class)
                and issubclass(maybe_class, BaseEstimator)
            ):
                models.append(maybe_class())
    finally:
        if patched:
            unpatch_sklearn()

    return models
113
+
114
+
115
# Estimator instances built with and without patching, each in patch-map
# iteration order. test_is_patched_instance pairs them positionally via zip.
PATCHED_MODELS = _load_all_models(patched=True)
UNPATCHED_MODELS = _load_all_models(patched=False)
117
+
118
+
119
@pytest.mark.parametrize(("patched", "unpatched"), zip(PATCHED_MODELS, UNPATCHED_MODELS))
def test_is_patched_instance(patched, unpatched):
    """``is_patched_instance`` must tell patched and stock estimators apart."""
    # Each message describes the failure being reported when its assertion trips.
    assert is_patched_instance(patched), f"{patched} is not reported as a patched instance"
    assert not is_patched_instance(
        unpatched
    ), f"{unpatched} is unexpectedly reported as a patched instance"
+ assert not is_patched_instance(unpatched), f"{unpatched} is an unpatched instance"