scikit-learn-intelex 2024.2.0__py310-none-manylinux1_x86_64.whl → 2024.4.0__py310-none-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/RECORD +45 -45
- sklearnex/__init__.py +9 -7
- sklearnex/_device_offload.py +31 -4
- sklearnex/basic_statistics/__init__.py +2 -1
- sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
- sklearnex/cluster/dbscan.py +3 -1
- sklearnex/conftest.py +63 -0
- sklearnex/decomposition/pca.py +319 -1
- sklearnex/decomposition/tests/test_pca.py +34 -5
- sklearnex/dispatcher.py +74 -43
- sklearnex/ensemble/_forest.py +78 -89
- sklearnex/ensemble/tests/test_forest.py +15 -19
- sklearnex/linear_model/linear.py +275 -340
- sklearnex/linear_model/logistic_regression.py +63 -11
- sklearnex/linear_model/tests/test_linear.py +40 -5
- sklearnex/linear_model/tests/test_logreg.py +0 -2
- sklearnex/neighbors/_lof.py +74 -20
- sklearnex/neighbors/common.py +4 -1
- sklearnex/neighbors/knn_classification.py +44 -131
- sklearnex/neighbors/knn_regression.py +16 -126
- sklearnex/neighbors/knn_unsupervised.py +11 -86
- sklearnex/neighbors/tests/test_neighbors.py +0 -5
- sklearnex/preview/__init__.py +1 -1
- sklearnex/preview/cluster/k_means.py +5 -73
- sklearnex/preview/covariance/covariance.py +6 -5
- sklearnex/preview/covariance/tests/test_covariance.py +18 -5
- sklearnex/spmd/ensemble/forest.py +4 -12
- sklearnex/svm/_common.py +4 -7
- sklearnex/svm/nusvc.py +70 -50
- sklearnex/svm/nusvr.py +6 -52
- sklearnex/svm/svc.py +70 -51
- sklearnex/svm/svr.py +3 -49
- sklearnex/tests/_utils.py +164 -0
- sklearnex/tests/test_memory_usage.py +8 -3
- sklearnex/tests/test_monkeypatch.py +177 -149
- sklearnex/tests/test_n_jobs_support.py +8 -2
- sklearnex/tests/test_parallel.py +6 -8
- sklearnex/tests/test_patching.py +322 -87
- sklearnex/utils/__init__.py +2 -1
- sklearnex/utils/_namespace.py +97 -0
- sklearnex/preview/decomposition/__init__.py +0 -19
- sklearnex/preview/decomposition/pca.py +0 -374
- sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -42
- sklearnex/tests/_models_info.py +0 -170
- sklearnex/tests/utils/_launch_algorithms.py +0 -118
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
sklearnex/tests/test_patching.py
CHANGED
|
@@ -14,107 +14,306 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
# ==============================================================================
|
|
16
16
|
|
|
17
|
+
|
|
18
|
+
import importlib
|
|
17
19
|
import inspect
|
|
20
|
+
import logging
|
|
18
21
|
import os
|
|
19
|
-
import pathlib
|
|
20
22
|
import re
|
|
21
|
-
import subprocess
|
|
22
23
|
import sys
|
|
23
|
-
from inspect import
|
|
24
|
+
from inspect import signature
|
|
24
25
|
|
|
26
|
+
import numpy as np
|
|
27
|
+
import numpy.random as nprnd
|
|
25
28
|
import pytest
|
|
26
|
-
from _models_info import TO_SKIP
|
|
27
29
|
from sklearn.base import BaseEstimator
|
|
28
30
|
|
|
29
|
-
from
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
31
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
32
|
+
from onedal.tests.utils._dataframes_support import (
|
|
33
|
+
_convert_to_dataframe,
|
|
34
|
+
get_dataframes_and_queues,
|
|
35
|
+
)
|
|
36
|
+
from sklearnex import is_patched_instance
|
|
37
|
+
from sklearnex.dispatcher import _is_preview_enabled
|
|
38
|
+
from sklearnex.metrics import pairwise_distances, roc_auc_score
|
|
39
|
+
from sklearnex.tests._utils import (
|
|
40
|
+
DTYPES,
|
|
41
|
+
PATCHED_FUNCTIONS,
|
|
42
|
+
PATCHED_MODELS,
|
|
43
|
+
SPECIAL_INSTANCES,
|
|
44
|
+
UNPATCHED_FUNCTIONS,
|
|
45
|
+
UNPATCHED_MODELS,
|
|
46
|
+
gen_dataset,
|
|
47
|
+
gen_models_info,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@pytest.mark.parametrize("dtype", DTYPES)
|
|
52
|
+
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
|
|
53
|
+
@pytest.mark.parametrize("metric", ["cosine", "correlation"])
|
|
54
|
+
def test_pairwise_distances_patching(caplog, dataframe, queue, dtype, metric):
|
|
55
|
+
with caplog.at_level(logging.WARNING, logger="sklearnex"):
|
|
56
|
+
if dtype == np.float16 and queue and not queue.sycl_device.has_aspect_fp16:
|
|
57
|
+
pytest.skip("Hardware does not support fp16 SYCL testing")
|
|
58
|
+
elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
|
|
59
|
+
pytest.skip("Hardware does not support fp64 SYCL testing")
|
|
60
|
+
elif queue and queue.sycl_device.is_gpu:
|
|
61
|
+
pytest.skip("pairwise_distances does not support GPU queues")
|
|
62
|
+
|
|
63
|
+
rng = nprnd.default_rng()
|
|
64
|
+
X = _convert_to_dataframe(
|
|
65
|
+
rng.random(size=1000).reshape(1, -1),
|
|
66
|
+
sycl_queue=queue,
|
|
67
|
+
target_df=dataframe,
|
|
68
|
+
dtype=dtype,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
_ = pairwise_distances(X, metric=metric)
|
|
72
|
+
assert all(
|
|
73
|
+
[
|
|
74
|
+
"running accelerated version" in i.message
|
|
75
|
+
or "fallback to original Scikit-learn" in i.message
|
|
76
|
+
for i in caplog.records
|
|
77
|
+
]
|
|
78
|
+
), f"sklearnex patching issue in pairwise_distances with log: \n{caplog.text}"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@pytest.mark.parametrize(
|
|
82
|
+
"dtype", [i for i in DTYPES if "32" in i.__name__ or "64" in i.__name__]
|
|
83
|
+
)
|
|
84
|
+
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
|
|
85
|
+
def test_roc_auc_score_patching(caplog, dataframe, queue, dtype):
|
|
86
|
+
if dtype in [np.uint32, np.uint64] and sys.platform == "win32":
|
|
87
|
+
pytest.skip("Windows issue with unsigned ints")
|
|
88
|
+
elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
|
|
89
|
+
pytest.skip("Hardware does not support fp64 SYCL testing")
|
|
90
|
+
|
|
91
|
+
with caplog.at_level(logging.WARNING, logger="sklearnex"):
|
|
92
|
+
rng = nprnd.default_rng()
|
|
93
|
+
X = _convert_to_dataframe(
|
|
94
|
+
rng.integers(2, size=1000),
|
|
95
|
+
sycl_queue=queue,
|
|
96
|
+
target_df=dataframe,
|
|
97
|
+
dtype=dtype,
|
|
98
|
+
)
|
|
99
|
+
y = _convert_to_dataframe(
|
|
100
|
+
rng.integers(2, size=1000),
|
|
101
|
+
sycl_queue=queue,
|
|
102
|
+
target_df=dataframe,
|
|
103
|
+
dtype=dtype,
|
|
64
104
|
)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
if (
|
|
104
|
-
maybe_class is not None
|
|
105
|
-
and isclass(maybe_class)
|
|
106
|
-
and issubclass(maybe_class, BaseEstimator)
|
|
105
|
+
|
|
106
|
+
_ = roc_auc_score(X, y)
|
|
107
|
+
assert all(
|
|
108
|
+
[
|
|
109
|
+
"running accelerated version" in i.message
|
|
110
|
+
or "fallback to original Scikit-learn" in i.message
|
|
111
|
+
for i in caplog.records
|
|
112
|
+
]
|
|
113
|
+
), f"sklearnex patching issue in roc_auc_score with log: \n{caplog.text}"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@pytest.mark.parametrize("dtype", DTYPES)
|
|
117
|
+
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
|
|
118
|
+
@pytest.mark.parametrize("estimator, method", gen_models_info(PATCHED_MODELS))
|
|
119
|
+
def test_standard_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
|
|
120
|
+
with caplog.at_level(logging.WARNING, logger="sklearnex"):
|
|
121
|
+
est = PATCHED_MODELS[estimator]()
|
|
122
|
+
|
|
123
|
+
if queue:
|
|
124
|
+
if dtype == np.float16 and not queue.sycl_device.has_aspect_fp16:
|
|
125
|
+
pytest.skip("Hardware does not support fp16 SYCL testing")
|
|
126
|
+
elif dtype == np.float64 and not queue.sycl_device.has_aspect_fp64:
|
|
127
|
+
pytest.skip("Hardware does not support fp64 SYCL testing")
|
|
128
|
+
elif queue.sycl_device.is_gpu and estimator in [
|
|
129
|
+
"KMeans",
|
|
130
|
+
"ElasticNet",
|
|
131
|
+
"Lasso",
|
|
132
|
+
"Ridge",
|
|
133
|
+
]:
|
|
134
|
+
pytest.skip(f"{estimator} does not support GPU queues")
|
|
135
|
+
|
|
136
|
+
if estimator == "TSNE" and method == "fit_transform":
|
|
137
|
+
pytest.skip("TSNE.fit_transform is too slow for common testing")
|
|
138
|
+
elif (
|
|
139
|
+
estimator == "Ridge"
|
|
140
|
+
and method in ["predict", "score"]
|
|
141
|
+
and sys.platform == "win32"
|
|
142
|
+
and dtype in [np.uint32, np.uint64]
|
|
107
143
|
):
|
|
108
|
-
|
|
144
|
+
pytest.skip("Windows segmentation fault for Ridge.predict for unsigned ints")
|
|
145
|
+
elif method and not hasattr(est, method):
|
|
146
|
+
pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")
|
|
147
|
+
|
|
148
|
+
X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)
|
|
149
|
+
est.fit(X, y)
|
|
150
|
+
|
|
151
|
+
if method:
|
|
152
|
+
if method != "score":
|
|
153
|
+
getattr(est, method)(X)
|
|
154
|
+
else:
|
|
155
|
+
est.score(X, y)
|
|
156
|
+
assert all(
|
|
157
|
+
[
|
|
158
|
+
"running accelerated version" in i.message
|
|
159
|
+
or "fallback to original Scikit-learn" in i.message
|
|
160
|
+
for i in caplog.records
|
|
161
|
+
]
|
|
162
|
+
), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@pytest.mark.parametrize("dtype", DTYPES)
|
|
166
|
+
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
|
|
167
|
+
@pytest.mark.parametrize("estimator, method", gen_models_info(SPECIAL_INSTANCES))
|
|
168
|
+
def test_special_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
|
|
169
|
+
# prepare logging
|
|
170
|
+
|
|
171
|
+
with caplog.at_level(logging.WARNING, logger="sklearnex"):
|
|
172
|
+
est = SPECIAL_INSTANCES[estimator]
|
|
173
|
+
|
|
174
|
+
# It's not possible to get the dpnp/dpctl arrays to be in the proper dtype
|
|
175
|
+
if dtype == np.float16 and queue and not queue.sycl_device.has_aspect_fp16:
|
|
176
|
+
pytest.skip("Hardware does not support fp16 SYCL testing")
|
|
177
|
+
elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
|
|
178
|
+
pytest.skip("Hardware does not support fp64 SYCL testing")
|
|
179
|
+
|
|
180
|
+
X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)
|
|
181
|
+
est.fit(X, y)
|
|
182
|
+
|
|
183
|
+
if method and not hasattr(est, method):
|
|
184
|
+
pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")
|
|
185
|
+
|
|
186
|
+
if method:
|
|
187
|
+
if method != "score":
|
|
188
|
+
getattr(est, method)(X)
|
|
189
|
+
else:
|
|
190
|
+
est.score(X, y)
|
|
191
|
+
|
|
192
|
+
assert all(
|
|
193
|
+
[
|
|
194
|
+
"running accelerated version" in i.message
|
|
195
|
+
or "fallback to original Scikit-learn" in i.message
|
|
196
|
+
for i in caplog.records
|
|
197
|
+
]
|
|
198
|
+
), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
|
|
202
|
+
def test_standard_estimator_signatures(estimator):
|
|
203
|
+
est = PATCHED_MODELS[estimator]()
|
|
204
|
+
unpatched_est = UNPATCHED_MODELS[estimator]()
|
|
205
|
+
|
|
206
|
+
# all public sklearn methods should have signature matches in sklearnex
|
|
207
|
+
|
|
208
|
+
unpatched_est_methods = [
|
|
209
|
+
i
|
|
210
|
+
for i in dir(unpatched_est)
|
|
211
|
+
if not i.startswith("_") and not i.endswith("_") and hasattr(unpatched_est, i)
|
|
212
|
+
]
|
|
213
|
+
for method in unpatched_est_methods:
|
|
214
|
+
est_method = getattr(est, method)
|
|
215
|
+
unpatched_est_method = getattr(unpatched_est, method)
|
|
216
|
+
if callable(unpatched_est_method):
|
|
217
|
+
regex = rf"(?:sklearn|daal4py)\S*{estimator}" # needed due to differences in module structure
|
|
218
|
+
patched_sig = re.sub(regex, estimator, str(signature(est_method)))
|
|
219
|
+
unpatched_sig = re.sub(regex, estimator, str(signature(unpatched_est_method)))
|
|
220
|
+
assert (
|
|
221
|
+
patched_sig == unpatched_sig
|
|
222
|
+
), f"Signature of {estimator}.{method} does not match sklearn"
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
|
|
226
|
+
def test_standard_estimator_init_signatures(estimator):
|
|
227
|
+
# Several estimators have additional parameters that are user-accessible
|
|
228
|
+
# which are sklearnex-specific. They will fail and are removed from tests.
|
|
229
|
+
# remove n_jobs due to estimator patching for sklearnex (known deviation)
|
|
230
|
+
patched_sig = str(signature(PATCHED_MODELS[estimator].__init__))
|
|
231
|
+
unpatched_sig = str(signature(UNPATCHED_MODELS[estimator].__init__))
|
|
232
|
+
|
|
233
|
+
# Sklearnex allows for positional kwargs and n_jobs, when sklearn doesn't
|
|
234
|
+
for kwarg in ["n_jobs=None", "*"]:
|
|
235
|
+
patched_sig = patched_sig.replace(", " + kwarg, "")
|
|
236
|
+
unpatched_sig = unpatched_sig.replace(", " + kwarg, "")
|
|
237
|
+
|
|
238
|
+
# Special sklearnex-specific kwargs are removed from signatures here
|
|
239
|
+
if estimator in [
|
|
240
|
+
"RandomForestRegressor",
|
|
241
|
+
"RandomForestClassifier",
|
|
242
|
+
"ExtraTreesRegressor",
|
|
243
|
+
"ExtraTreesClassifier",
|
|
244
|
+
]:
|
|
245
|
+
for kwarg in ["min_bin_size=1", "max_bins=256"]:
|
|
246
|
+
patched_sig = patched_sig.replace(", " + kwarg, "")
|
|
247
|
+
|
|
248
|
+
assert (
|
|
249
|
+
patched_sig == unpatched_sig
|
|
250
|
+
), f"Signature of {estimator}.__init__ does not match sklearn"
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@pytest.mark.parametrize(
|
|
254
|
+
"function",
|
|
255
|
+
[
|
|
256
|
+
i
|
|
257
|
+
for i in UNPATCHED_FUNCTIONS.keys()
|
|
258
|
+
if i not in ["train_test_split", "set_config", "config_context"]
|
|
259
|
+
],
|
|
260
|
+
)
|
|
261
|
+
def test_patched_function_signatures(function):
|
|
262
|
+
# certain functions are dropped from the test
|
|
263
|
+
# as they add functionality to the underlying sklearn function
|
|
264
|
+
if not sklearn_check_version("1.1") and function == "_assert_all_finite":
|
|
265
|
+
pytest.skip("Sklearn versioning not added to _assert_all_finite")
|
|
266
|
+
func = PATCHED_FUNCTIONS[function]
|
|
267
|
+
unpatched_func = UNPATCHED_FUNCTIONS[function]
|
|
268
|
+
|
|
269
|
+
if callable(unpatched_func):
|
|
270
|
+
assert str(signature(func)) == str(
|
|
271
|
+
signature(unpatched_func)
|
|
272
|
+
), f"Signature of {func} does not match sklearn"
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def test_patch_map_match():
|
|
276
|
+
# This rule applies to functions and classes which are out of preview.
|
|
277
|
+
# Items listed in a matching submodule's __all__ attribute should be
|
|
278
|
+
# in get_patch_map. There should not be any missing or additional elements.
|
|
109
279
|
|
|
110
|
-
|
|
111
|
-
|
|
280
|
+
def list_all_attr(string):
|
|
281
|
+
try:
|
|
282
|
+
modules = set(importlib.import_module(string).__all__)
|
|
283
|
+
except ModuleNotFoundError:
|
|
284
|
+
modules = set([None])
|
|
285
|
+
return modules
|
|
112
286
|
|
|
113
|
-
|
|
287
|
+
if _is_preview_enabled():
|
|
288
|
+
pytest.skip("preview sklearnex has been activated")
|
|
289
|
+
patched = {**PATCHED_MODELS, **PATCHED_FUNCTIONS}
|
|
114
290
|
|
|
291
|
+
sklearnex__all__ = list_all_attr("sklearnex")
|
|
292
|
+
sklearn__all__ = list_all_attr("sklearn")
|
|
115
293
|
|
|
116
|
-
|
|
117
|
-
|
|
294
|
+
module_map = {i: i for i in sklearnex__all__.intersection(sklearn__all__)}
|
|
295
|
+
|
|
296
|
+
# _assert_all_finite patches an internal sklearn function which isn't
|
|
297
|
+
# exposed via __all__ in sklearn. It is a special case where this rule
|
|
298
|
+
# is not applied (e.g. it is grandfathered in).
|
|
299
|
+
del patched["_assert_all_finite"]
|
|
300
|
+
|
|
301
|
+
# remove all scikit-learn-intelex-only estimators
|
|
302
|
+
for i in patched.copy():
|
|
303
|
+
if i not in UNPATCHED_MODELS and i not in UNPATCHED_FUNCTIONS:
|
|
304
|
+
del patched[i]
|
|
305
|
+
|
|
306
|
+
for module in module_map:
|
|
307
|
+
sklearn_module__all__ = list_all_attr("sklearn." + module_map[module])
|
|
308
|
+
sklearnex_module__all__ = list_all_attr("sklearnex." + module)
|
|
309
|
+
intersect = sklearnex_module__all__.intersection(sklearn_module__all__)
|
|
310
|
+
|
|
311
|
+
for i in intersect:
|
|
312
|
+
if i:
|
|
313
|
+
del patched[i]
|
|
314
|
+
else:
|
|
315
|
+
del patched[module]
|
|
316
|
+
assert patched == {}, f"{patched.keys()} were not properly patched"
|
|
118
317
|
|
|
119
318
|
|
|
120
319
|
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
|
|
@@ -125,6 +324,42 @@ def test_is_patched_instance(estimator):
|
|
|
125
324
|
assert not is_patched_instance(unpatched), f"{unpatched} is an unpatched instance"
|
|
126
325
|
|
|
127
326
|
|
|
327
|
+
@pytest.mark.parametrize("estimator", PATCHED_MODELS.keys())
|
|
328
|
+
def test_if_estimator_inherits_sklearn(estimator):
|
|
329
|
+
est = PATCHED_MODELS[estimator]
|
|
330
|
+
if estimator in UNPATCHED_MODELS:
|
|
331
|
+
assert issubclass(
|
|
332
|
+
est, UNPATCHED_MODELS[estimator]
|
|
333
|
+
), f"{estimator} does not inherit from the patched sklearn estimator"
|
|
334
|
+
else:
|
|
335
|
+
assert issubclass(est, BaseEstimator)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
|
|
339
|
+
def test_docstring_patching_match(estimator):
|
|
340
|
+
patched = PATCHED_MODELS[estimator]
|
|
341
|
+
unpatched = UNPATCHED_MODELS[estimator]
|
|
342
|
+
patched_docstrings = {
|
|
343
|
+
i: getattr(patched, i).__doc__
|
|
344
|
+
for i in dir(patched)
|
|
345
|
+
if not i.startswith("_") and not i.endswith("_") and hasattr(patched, i)
|
|
346
|
+
}
|
|
347
|
+
unpatched_docstrings = {
|
|
348
|
+
i: getattr(unpatched, i).__doc__
|
|
349
|
+
for i in dir(unpatched)
|
|
350
|
+
if not i.startswith("_") and not i.endswith("_") and hasattr(unpatched, i)
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
# check class docstring match if a docstring is available
|
|
354
|
+
|
|
355
|
+
assert (patched.__doc__ is None) == (unpatched.__doc__ is None)
|
|
356
|
+
|
|
357
|
+
# check class attribute docstrings
|
|
358
|
+
|
|
359
|
+
for i in unpatched_docstrings:
|
|
360
|
+
assert (patched_docstrings[i] is None) == (unpatched_docstrings[i] is None)
|
|
361
|
+
|
|
362
|
+
|
|
128
363
|
@pytest.mark.parametrize("member", ["_onedal_cpu_supported", "_onedal_gpu_supported"])
|
|
129
364
|
@pytest.mark.parametrize(
|
|
130
365
|
"name",
|
sklearnex/utils/__init__.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
# ===============================================================================
|
|
16
16
|
|
|
17
|
+
from ._namespace import get_namespace
|
|
17
18
|
from .validation import _assert_all_finite
|
|
18
19
|
|
|
19
|
-
__all__ = ["_assert_all_finite"]
|
|
20
|
+
__all__ = ["get_namespace", "_assert_all_finite"]
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
20
|
+
|
|
21
|
+
from .._device_offload import dpnp_available
|
|
22
|
+
|
|
23
|
+
if sklearn_check_version("1.2"):
|
|
24
|
+
from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
|
|
25
|
+
|
|
26
|
+
if dpnp_available:
|
|
27
|
+
import dpnp
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_namespace(*arrays):
|
|
31
|
+
"""Get namespace of arrays.
|
|
32
|
+
|
|
33
|
+
Introspect `arrays` arguments and return their common Array API
|
|
34
|
+
compatible namespace object, if any. NumPy 1.22 and later can
|
|
35
|
+
construct such containers using the `numpy.array_api` namespace
|
|
36
|
+
for instance.
|
|
37
|
+
|
|
38
|
+
This function will return the namespace of SYCL-related arrays
|
|
39
|
+
which define the __sycl_usm_array_interface__ attribute
|
|
40
|
+
regardless of array_api support, the configuration of
|
|
41
|
+
array_api_dispatch, or scikit-learn version.
|
|
42
|
+
|
|
43
|
+
See: https://numpy.org/neps/nep-0047-array-api-standard.html
|
|
44
|
+
|
|
45
|
+
If `arrays` are regular numpy arrays, an instance of the
|
|
46
|
+
`_NumPyApiWrapper` compatibility wrapper is returned instead.
|
|
47
|
+
|
|
48
|
+
Namespace support is not enabled by default. To enable it
|
|
49
|
+
call:
|
|
50
|
+
|
|
51
|
+
sklearn.set_config(array_api_dispatch=True)
|
|
52
|
+
|
|
53
|
+
or:
|
|
54
|
+
|
|
55
|
+
with sklearn.config_context(array_api_dispatch=True):
|
|
56
|
+
# your code here
|
|
57
|
+
|
|
58
|
+
Otherwise an instance of the `_NumPyApiWrapper`
|
|
59
|
+
compatibility wrapper is always returned irrespective of
|
|
60
|
+
the fact that arrays implement the `__array_namespace__`
|
|
61
|
+
protocol or not.
|
|
62
|
+
|
|
63
|
+
Parameters
|
|
64
|
+
----------
|
|
65
|
+
*arrays : array objects
|
|
66
|
+
Array objects.
|
|
67
|
+
|
|
68
|
+
Returns
|
|
69
|
+
-------
|
|
70
|
+
namespace : module
|
|
71
|
+
Namespace shared by array objects.
|
|
72
|
+
|
|
73
|
+
is_array_api : bool
|
|
74
|
+
True if the arrays are containers that implement the Array API spec.
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
# sycl support designed to work regardless of array_api_dispatch sklearn global value
|
|
78
|
+
sycl_type = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")}
|
|
79
|
+
|
|
80
|
+
if len(sycl_type) > 1:
|
|
81
|
+
raise ValueError(f"Multiple SYCL types for array inputs: {sycl_type}")
|
|
82
|
+
|
|
83
|
+
if sycl_type:
|
|
84
|
+
|
|
85
|
+
(X,) = sycl_type.values()
|
|
86
|
+
|
|
87
|
+
if hasattr(X, "__array_namespace__"):
|
|
88
|
+
return X.__array_namespace__(), True
|
|
89
|
+
elif dpnp_available and isinstance(X, dpnp.ndarray):
|
|
90
|
+
return dpnp, False
|
|
91
|
+
else:
|
|
92
|
+
raise ValueError(f"SYCL type not recognized: {sycl_type}")
|
|
93
|
+
|
|
94
|
+
elif sklearn_check_version("1.2"):
|
|
95
|
+
return sklearn_get_namespace(*arrays)
|
|
96
|
+
else:
|
|
97
|
+
return np, True
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
# ===============================================================================
|
|
2
|
-
# Copyright 2023 Intel Corporation
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
# ===============================================================================
|
|
16
|
-
|
|
17
|
-
from .pca import PCA
|
|
18
|
-
|
|
19
|
-
__all__ = ["PCA"]
|