scikit-learn-intelex 2024.1.0__py311-none-manylinux1_x86_64.whl → 2024.4.0__py311-none-manylinux1_x86_64.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of scikit-learn-intelex might be problematic.

Files changed (62)
  1. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
  2. scikit_learn_intelex-2024.4.0.dist-info/RECORD +101 -0
  3. sklearnex/__init__.py +9 -7
  4. sklearnex/_device_offload.py +31 -4
  5. sklearnex/basic_statistics/__init__.py +2 -1
  6. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  7. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
  8. sklearnex/cluster/dbscan.py +6 -4
  9. sklearnex/conftest.py +63 -0
  10. sklearnex/{preview/decomposition → covariance}/__init__.py +19 -19
  11. sklearnex/covariance/incremental_covariance.py +130 -0
  12. sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  13. sklearnex/decomposition/pca.py +319 -1
  14. sklearnex/decomposition/tests/test_pca.py +34 -5
  15. sklearnex/dispatcher.py +93 -61
  16. sklearnex/ensemble/_forest.py +81 -97
  17. sklearnex/ensemble/tests/test_forest.py +15 -19
  18. sklearnex/linear_model/__init__.py +1 -2
  19. sklearnex/linear_model/linear.py +275 -347
  20. sklearnex/{preview/linear_model → linear_model}/logistic_regression.py +83 -50
  21. sklearnex/linear_model/tests/test_linear.py +40 -5
  22. sklearnex/linear_model/tests/test_logreg.py +70 -7
  23. sklearnex/neighbors/__init__.py +1 -1
  24. sklearnex/neighbors/_lof.py +221 -0
  25. sklearnex/neighbors/common.py +4 -1
  26. sklearnex/neighbors/knn_classification.py +47 -137
  27. sklearnex/neighbors/knn_regression.py +20 -132
  28. sklearnex/neighbors/knn_unsupervised.py +16 -93
  29. sklearnex/neighbors/tests/test_neighbors.py +12 -16
  30. sklearnex/preview/__init__.py +1 -1
  31. sklearnex/preview/cluster/k_means.py +8 -81
  32. sklearnex/preview/covariance/covariance.py +51 -16
  33. sklearnex/preview/covariance/tests/test_covariance.py +18 -5
  34. sklearnex/spmd/__init__.py +1 -0
  35. sklearnex/{preview/linear_model → spmd/covariance}/__init__.py +5 -5
  36. sklearnex/spmd/covariance/covariance.py +21 -0
  37. sklearnex/spmd/ensemble/forest.py +4 -12
  38. sklearnex/spmd/linear_model/__init__.py +2 -1
  39. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  40. sklearnex/svm/_common.py +4 -7
  41. sklearnex/svm/nusvc.py +74 -55
  42. sklearnex/svm/nusvr.py +9 -56
  43. sklearnex/svm/svc.py +74 -56
  44. sklearnex/svm/svr.py +6 -53
  45. sklearnex/tests/_utils.py +164 -0
  46. sklearnex/tests/test_memory_usage.py +9 -7
  47. sklearnex/tests/test_monkeypatch.py +179 -138
  48. sklearnex/tests/test_n_jobs_support.py +77 -9
  49. sklearnex/tests/test_parallel.py +6 -8
  50. sklearnex/tests/test_patching.py +338 -89
  51. sklearnex/utils/__init__.py +2 -1
  52. sklearnex/utils/_namespace.py +97 -0
  53. scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
  54. sklearnex/neighbors/lof.py +0 -436
  55. sklearnex/preview/decomposition/pca.py +0 -376
  56. sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -42
  57. sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  58. sklearnex/tests/_models_info.py +0 -170
  59. sklearnex/tests/utils/_launch_algorithms.py +0 -118
  60. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
  61. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
  62. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
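
Several estimators leave the preview namespace in this release (see entries 13, 20, and 55 above), so they should now be importable from their stable module paths. A minimal usage sketch; the import paths below are inferred from the file moves, not from documentation shipped in this diff:

import numpy as np

# Import paths inferred from the preview -> stable file moves above; treat
# them as assumptions rather than documented API for this exact release.
from sklearnex.decomposition import PCA
from sklearnex.linear_model import LogisticRegression

X = np.random.rand(100, 6)
y = (X[:, 0] > 0.5).astype(int)

LogisticRegression().fit(X, y)  # oneDAL-backed logistic regression, no longer preview-only
PCA(n_components=2).fit(X)      # PCA implementation folded into sklearnex/decomposition/pca.py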
sklearnex/tests/test_n_jobs_support.py (entry 48 above)
@@ -14,18 +14,86 @@
 # limitations under the License.
 # ==============================================================================
 
+import inspect
+import logging
+from multiprocessing import cpu_count
+
 import pytest
+from sklearn.base import BaseEstimator
+from sklearn.datasets import make_classification
+
+from sklearnex.dispatcher import get_patch_map
+from sklearnex.svm import SVC, NuSVC
+
+ESTIMATORS = set(
+    filter(
+        lambda x: inspect.isclass(x) and issubclass(x, BaseEstimator),
+        [value[0][0][2] for value in get_patch_map().values()],
+    )
+)
+
+X, Y = make_classification(n_samples=40, n_features=4, random_state=42)
+
+
+@pytest.mark.parametrize("estimator_class", ESTIMATORS)
+@pytest.mark.parametrize("n_jobs", [None, -1, 1, 2])
+def test_n_jobs_support(caplog, estimator_class, n_jobs):
+    def check_estimator_doc(estimator):
+        if estimator.__doc__ is not None:
+            assert "n_jobs" in estimator.__doc__
+
+    def check_n_jobs_entry_in_logs(caplog, function_name, n_jobs):
+        for rec in caplog.records:
+            if function_name in rec.message and "threads" in rec.message:
+                expected_n_jobs = n_jobs if n_jobs > 0 else cpu_count() + 1 + n_jobs
+                logging.info(f"{function_name}: setting {expected_n_jobs} threads")
+                if f"{function_name}: setting {expected_n_jobs} threads" in rec.message:
+                    return True
+        # False if n_jobs is set and not found in logs
+        return n_jobs is None
 
-from sklearnex.cluster import KMeans
-from sklearnex.linear_model import ElasticNet, Lasso, Ridge
-from sklearnex.svm import SVC, SVR, NuSVC, NuSVR
+    def check_method(*args, method, caplog):
+        method(*args)
+        assert check_n_jobs_entry_in_logs(caplog, method.__name__, n_jobs)
 
-estimators = [KMeans, SVC, SVR, NuSVC, NuSVR, Lasso, Ridge, ElasticNet]
+    def check_methods_decoration(estimator):
+        funcs = {
+            i: getattr(estimator, i)
+            for i in dir(estimator)
+            if hasattr(estimator, i) and callable(getattr(estimator, i))
+        }
 
+        for func_name, func in funcs.items():
+            assert hasattr(func, "__onedal_n_jobs_decorated__") == (
+                func_name in estimator._n_jobs_supported_onedal_methods
+            ), f"{estimator}.{func_name} n_jobs decoration does not match {estimator} n_jobs supported methods"
 
-@pytest.mark.parametrize("estimator", estimators)
-def test_n_jobs_support(estimator):
-    # use `n_jobs` parameter where original sklearn doesn't expect it
-    estimator(n_jobs=1)
+    caplog.set_level(logging.DEBUG, logger="sklearnex")
+    estimator_kwargs = {"n_jobs": n_jobs}
+    # by default, [Nu]SVC.predict_proba is restricted by @available_if decorator
+    if estimator_class in [SVC, NuSVC]:
+        estimator_kwargs["probability"] = True
+    estimator_instance = estimator_class(**estimator_kwargs)
     # check `n_jobs` parameter doc entry
-    assert "n_jobs" in estimator.__doc__
+    check_estimator_doc(estimator_class)
+    check_estimator_doc(estimator_instance)
+    # check `n_jobs` log entry for supported methods
+    # `fit` call is required before other methods
+    check_method(X, Y, method=estimator_instance.fit, caplog=caplog)
+    for method_name in estimator_instance._n_jobs_supported_onedal_methods:
+        if method_name == "fit":
+            continue
+        method = getattr(estimator_instance, method_name)
+        argdict = inspect.signature(method).parameters
+        argnum = len(
+            [i for i in argdict if argdict[i].default == inspect.Parameter.empty]
+        )
+        if argnum == 0:
+            check_method(method=method, caplog=caplog)
+        elif argnum == 1:
+            check_method(X, method=method, caplog=caplog)
+        else:
+            check_method(X, Y, method=method, caplog=caplog)
+    # check if correct methods were decorated
+    check_methods_decoration(estimator_class)
+    check_methods_decoration(estimator_instance)
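
The rewritten test above no longer hard-codes eight estimators: it walks every BaseEstimator subclass in the patch map, reads DEBUG records from the "sklearnex" logger, and expects a "<method>: setting N threads" entry, where N falls back to cpu_count() + 1 + n_jobs for negative values. A standalone sketch of the same log-based check, using SVC only as an example estimator:

import logging
from multiprocessing import cpu_count

import numpy as np

from sklearnex.svm import SVC

logging.basicConfig()
logging.getLogger("sklearnex").setLevel(logging.DEBUG)

n_jobs = -1
expected_threads = n_jobs if n_jobs > 0 else cpu_count() + 1 + n_jobs

X = np.random.rand(100, 4)
y = (X[:, 0] > 0.5).astype(int)
SVC(n_jobs=n_jobs).fit(X, y)
# If n_jobs dispatching is active for SVC.fit, a DEBUG record of the form
# "fit: setting <expected_threads> threads" appears on the "sklearnex" logger.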
sklearnex/tests/test_parallel.py (entry 49 above)
@@ -15,13 +15,7 @@
 # ==============================================================================
 import pytest
 
-from sklearnex import config_context, patch_sklearn
-
-patch_sklearn()
-
-from sklearn.datasets import make_classification
-from sklearn.ensemble import BaggingClassifier
-from sklearn.svm import SVC
+from sklearnex import config_context
 
 try:
     import dpctl
@@ -38,7 +32,11 @@ except (ImportError, ModuleNotFoundError):
     "to see raised 'SyclQueueCreationError'. "
     "'dpctl' module is required for test.",
 )
-def test_config_context_in_parallel():
+def test_config_context_in_parallel(with_sklearnex):
+    from sklearn.datasets import make_classification
+    from sklearn.ensemble import BaggingClassifier
+    from sklearn.svm import SVC
+
     x, y = make_classification(random_state=42)
     try:
         with config_context(target_offload="gpu", allow_fallback_to_host=False):
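
test_config_context_in_parallel now receives a with_sklearnex fixture instead of calling patch_sklearn() at module import time; the fixture itself lives in the new sklearnex/conftest.py (entry 9 above), whose body is not shown in this diff. A plausible sketch of such a fixture, offered only as an assumption about its shape:

import pytest

from sklearnex import patch_sklearn, unpatch_sklearn


@pytest.fixture
def with_sklearnex():
    # Patch scikit-learn for the duration of a single test, then restore it.
    patch_sklearn()
    yield
    unpatch_sklearn()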
sklearnex/tests/test_patching.py (entry 50 above)
@@ -14,109 +14,358 @@
 # limitations under the License.
 # ==============================================================================
 
+
+import importlib
+import inspect
+import logging
 import os
-import pathlib
 import re
-import subprocess
 import sys
-from inspect import isclass
+from inspect import signature
 
+import numpy as np
+import numpy.random as nprnd
 import pytest
-from _models_info import TO_SKIP
 from sklearn.base import BaseEstimator
 
-from sklearnex import get_patch_map, is_patched_instance, patch_sklearn, unpatch_sklearn
-
-
-def get_branch(s):
-    if len(s) == 0:
-        return "NO INFO"
-    for i in s:
-        if "failed to run accelerated version, fallback to original Scikit-learn" in i:
-            return "was in OPT, but go in Scikit"
-    for i in s:
-        if "running accelerated version" in i:
-            return "OPT"
-    return "Scikit"
-
-
-def run_parse(mas, result):
-    name, dtype = mas[0].split()
-    temp = []
-    INFO_POS = 16
-    for i in range(1, len(mas)):
-        mas[i] = mas[i][INFO_POS:]  # remove 'SKLEARNEX INFO: '
-        if not mas[i].startswith("sklearn"):
-            ind = name + " " + dtype + " " + mas[i]
-            result[ind] = get_branch(temp)
-            temp.clear()
-        else:
-            temp.append(mas[i])
-
-
-def get_result_log():
-    os.environ["SKLEARNEX_VERBOSE"] = "INFO"
-    absolute_path = str(pathlib.Path(__file__).parent.absolute())
-    try:
-        process = subprocess.check_output(
-            [sys.executable, absolute_path + "/utils/_launch_algorithms.py"]
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.tests.utils._dataframes_support import (
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
+from sklearnex import is_patched_instance
+from sklearnex.dispatcher import _is_preview_enabled
+from sklearnex.metrics import pairwise_distances, roc_auc_score
+from sklearnex.tests._utils import (
+    DTYPES,
+    PATCHED_FUNCTIONS,
+    PATCHED_MODELS,
+    SPECIAL_INSTANCES,
+    UNPATCHED_FUNCTIONS,
+    UNPATCHED_MODELS,
+    gen_dataset,
+    gen_models_info,
+)
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
+@pytest.mark.parametrize("metric", ["cosine", "correlation"])
+def test_pairwise_distances_patching(caplog, dataframe, queue, dtype, metric):
+    with caplog.at_level(logging.WARNING, logger="sklearnex"):
+        if dtype == np.float16 and queue and not queue.sycl_device.has_aspect_fp16:
+            pytest.skip("Hardware does not support fp16 SYCL testing")
+        elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
+            pytest.skip("Hardware does not support fp64 SYCL testing")
+        elif queue and queue.sycl_device.is_gpu:
+            pytest.skip("pairwise_distances does not support GPU queues")
+
+        rng = nprnd.default_rng()
+        X = _convert_to_dataframe(
+            rng.random(size=1000).reshape(1, -1),
+            sycl_queue=queue,
+            target_df=dataframe,
+            dtype=dtype,
+        )
+
+        _ = pairwise_distances(X, metric=metric)
+    assert all(
+        [
+            "running accelerated version" in i.message
+            or "fallback to original Scikit-learn" in i.message
+            for i in caplog.records
+        ]
+    ), f"sklearnex patching issue in pairwise_distances with log: \n{caplog.text}"
+
+
+@pytest.mark.parametrize(
+    "dtype", [i for i in DTYPES if "32" in i.__name__ or "64" in i.__name__]
+)
+@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
+def test_roc_auc_score_patching(caplog, dataframe, queue, dtype):
+    if dtype in [np.uint32, np.uint64] and sys.platform == "win32":
+        pytest.skip("Windows issue with unsigned ints")
+    elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
+        pytest.skip("Hardware does not support fp64 SYCL testing")
+
+    with caplog.at_level(logging.WARNING, logger="sklearnex"):
+        rng = nprnd.default_rng()
+        X = _convert_to_dataframe(
+            rng.integers(2, size=1000),
+            sycl_queue=queue,
+            target_df=dataframe,
+            dtype=dtype,
         )
-    except subprocess.CalledProcessError as e:
-        print(e)
-        exit(1)
-    mas = []
-    result = {}
-    for i in process.decode().split("\n"):
-        if i.startswith("SKLEARNEX WARNING"):
-            continue
-        if not i.startswith("SKLEARNEX INFO") and len(mas) != 0:
-            run_parse(mas, result)
-            mas.clear()
-            mas.append(i.strip())
-        else:
-            mas.append(i.strip())
-    del os.environ["SKLEARNEX_VERBOSE"]
-    return result
-
-
-result_log = get_result_log()
-
-
-@pytest.mark.parametrize("configuration", result_log)
-def test_patching(configuration):
-    if "OPT" in result_log[configuration]:
-        return
-    for skip in TO_SKIP:
-        if re.search(skip, configuration) is not None:
-            pytest.skip("SKIPPED", allow_module_level=False)
-    raise ValueError("Test patching failed: " + configuration)
-
-
-def _load_all_models(patched):
-    if patched:
-        patch_sklearn()
-
-    models = []
-    for patch_infos in get_patch_map().values():
-        maybe_class = getattr(patch_infos[0][0][0], patch_infos[0][0][1])
-        if (
-            maybe_class is not None
-            and isclass(maybe_class)
-            and issubclass(maybe_class, BaseEstimator)
+        y = _convert_to_dataframe(
+            rng.integers(2, size=1000),
+            sycl_queue=queue,
+            target_df=dataframe,
+            dtype=dtype,
+        )
+
+        _ = roc_auc_score(X, y)
+    assert all(
+        [
+            "running accelerated version" in i.message
+            or "fallback to original Scikit-learn" in i.message
+            for i in caplog.records
+        ]
+    ), f"sklearnex patching issue in roc_auc_score with log: \n{caplog.text}"
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
+@pytest.mark.parametrize("estimator, method", gen_models_info(PATCHED_MODELS))
+def test_standard_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
+    with caplog.at_level(logging.WARNING, logger="sklearnex"):
+        est = PATCHED_MODELS[estimator]()
+
+        if queue:
+            if dtype == np.float16 and not queue.sycl_device.has_aspect_fp16:
+                pytest.skip("Hardware does not support fp16 SYCL testing")
+            elif dtype == np.float64 and not queue.sycl_device.has_aspect_fp64:
+                pytest.skip("Hardware does not support fp64 SYCL testing")
+            elif queue.sycl_device.is_gpu and estimator in [
+                "KMeans",
+                "ElasticNet",
+                "Lasso",
+                "Ridge",
+            ]:
+                pytest.skip(f"{estimator} does not support GPU queues")
+
+        if estimator == "TSNE" and method == "fit_transform":
+            pytest.skip("TSNE.fit_transform is too slow for common testing")
+        elif (
+            estimator == "Ridge"
+            and method in ["predict", "score"]
+            and sys.platform == "win32"
+            and dtype in [np.uint32, np.uint64]
        ):
-            models.append(maybe_class())
+            pytest.skip("Windows segmentation fault for Ridge.predict for unsigned ints")
+        elif method and not hasattr(est, method):
+            pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")
+
+        X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)
+        est.fit(X, y)
+
+        if method:
+            if method != "score":
+                getattr(est, method)(X)
+            else:
+                est.score(X, y)
+    assert all(
+        [
+            "running accelerated version" in i.message
+            or "fallback to original Scikit-learn" in i.message
+            for i in caplog.records
+        ]
+    ), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
+@pytest.mark.parametrize("estimator, method", gen_models_info(SPECIAL_INSTANCES))
+def test_special_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
+    # prepare logging
+
+    with caplog.at_level(logging.WARNING, logger="sklearnex"):
+        est = SPECIAL_INSTANCES[estimator]
+
+        # Its not possible to get the dpnp/dpctl arrays to be in the proper dtype
+        if dtype == np.float16 and queue and not queue.sycl_device.has_aspect_fp16:
+            pytest.skip("Hardware does not support fp16 SYCL testing")
+        elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
+            pytest.skip("Hardware does not support fp64 SYCL testing")
+
+        X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)
+        est.fit(X, y)
+
+        if method and not hasattr(est, method):
+            pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")
+
+        if method:
+            if method != "score":
+                getattr(est, method)(X)
+            else:
+                est.score(X, y)
+
+    assert all(
+        [
+            "running accelerated version" in i.message
+            or "fallback to original Scikit-learn" in i.message
+            for i in caplog.records
+        ]
+    ), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
+
+
+@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
+def test_standard_estimator_signatures(estimator):
+    est = PATCHED_MODELS[estimator]()
+    unpatched_est = UNPATCHED_MODELS[estimator]()
 
-    if patched:
-        unpatch_sklearn()
+    # all public sklearn methods should have signature matches in sklearnex
 
-    return models
+    unpatched_est_methods = [
+        i
+        for i in dir(unpatched_est)
+        if not i.startswith("_") and not i.endswith("_") and hasattr(unpatched_est, i)
+    ]
+    for method in unpatched_est_methods:
+        est_method = getattr(est, method)
+        unpatched_est_method = getattr(unpatched_est, method)
+        if callable(unpatched_est_method):
+            regex = rf"(?:sklearn|daal4py)\S*{estimator}"  # needed due to differences in module structure
+            patched_sig = re.sub(regex, estimator, str(signature(est_method)))
+            unpatched_sig = re.sub(regex, estimator, str(signature(unpatched_est_method)))
+            assert (
+                patched_sig == unpatched_sig
+            ), f"Signature of {estimator}.{method} does not match sklearn"
 
 
-PATCHED_MODELS = _load_all_models(patched=True)
-UNPATCHED_MODELS = _load_all_models(patched=False)
+@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
+def test_standard_estimator_init_signatures(estimator):
+    # Several estimators have additional parameters that are user-accessible
+    # which are sklearnex-specific. They will fail and are removed from tests.
+    # remove n_jobs due to estimator patching for sklearnex (known deviation)
+    patched_sig = str(signature(PATCHED_MODELS[estimator].__init__))
+    unpatched_sig = str(signature(UNPATCHED_MODELS[estimator].__init__))
 
+    # Sklearnex allows for positional kwargs and n_jobs, when sklearn doesn't
+    for kwarg in ["n_jobs=None", "*"]:
+        patched_sig = patched_sig.replace(", " + kwarg, "")
+        unpatched_sig = unpatched_sig.replace(", " + kwarg, "")
 
-@pytest.mark.parametrize(("patched", "unpatched"), zip(PATCHED_MODELS, UNPATCHED_MODELS))
-def test_is_patched_instance(patched, unpatched):
+    # Special sklearnex-specific kwargs are removed from signatures here
+    if estimator in [
+        "RandomForestRegressor",
+        "RandomForestClassifier",
+        "ExtraTreesRegressor",
+        "ExtraTreesClassifier",
+    ]:
+        for kwarg in ["min_bin_size=1", "max_bins=256"]:
+            patched_sig = patched_sig.replace(", " + kwarg, "")
+
+    assert (
+        patched_sig == unpatched_sig
+    ), f"Signature of {estimator}.__init__ does not match sklearn"
+
+
+@pytest.mark.parametrize(
+    "function",
+    [
+        i
+        for i in UNPATCHED_FUNCTIONS.keys()
+        if i not in ["train_test_split", "set_config", "config_context"]
+    ],
+)
+def test_patched_function_signatures(function):
+    # certain functions are dropped from the test
+    # as they add functionality to the underlying sklearn function
+    if not sklearn_check_version("1.1") and function == "_assert_all_finite":
+        pytest.skip("Sklearn versioning not added to _assert_all_finite")
+    func = PATCHED_FUNCTIONS[function]
+    unpatched_func = UNPATCHED_FUNCTIONS[function]
+
+    if callable(unpatched_func):
+        assert str(signature(func)) == str(
+            signature(unpatched_func)
+        ), f"Signature of {func} does not match sklearn"
+
+
+def test_patch_map_match():
+    # This rule applies to functions and classes which are out of preview.
+    # Items listed in a matching submodule's __all__ attribute should be
+    # in get_patch_map. There should not be any missing or additional elements.
+
+    def list_all_attr(string):
+        try:
+            modules = set(importlib.import_module(string).__all__)
+        except ModuleNotFoundError:
+            modules = set([None])
+        return modules
+
+    if _is_preview_enabled():
+        pytest.skip("preview sklearnex has been activated")
+    patched = {**PATCHED_MODELS, **PATCHED_FUNCTIONS}
+
+    sklearnex__all__ = list_all_attr("sklearnex")
+    sklearn__all__ = list_all_attr("sklearn")
+
+    module_map = {i: i for i in sklearnex__all__.intersection(sklearn__all__)}
+
+    # _assert_all_finite patches an internal sklearn function which isn't
+    # exposed via __all__ in sklearn. It is a special case where this rule
+    # is not applied (e.g. it is grandfathered in).
+    del patched["_assert_all_finite"]
+
+    # remove all scikit-learn-intelex-only estimators
+    for i in patched.copy():
+        if i not in UNPATCHED_MODELS and i not in UNPATCHED_FUNCTIONS:
+            del patched[i]
+
+    for module in module_map:
+        sklearn_module__all__ = list_all_attr("sklearn." + module_map[module])
+        sklearnex_module__all__ = list_all_attr("sklearnex." + module)
+        intersect = sklearnex_module__all__.intersection(sklearn_module__all__)
+
+        for i in intersect:
+            if i:
+                del patched[i]
+            else:
+                del patched[module]
+    assert patched == {}, f"{patched.keys()} were not properly patched"
+
+
+@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
+def test_is_patched_instance(estimator):
+    patched = PATCHED_MODELS[estimator]
+    unpatched = UNPATCHED_MODELS[estimator]
     assert is_patched_instance(patched), f"{patched} is a patched instance"
     assert not is_patched_instance(unpatched), f"{unpatched} is an unpatched instance"
+
+
+@pytest.mark.parametrize("estimator", PATCHED_MODELS.keys())
+def test_if_estimator_inherits_sklearn(estimator):
+    est = PATCHED_MODELS[estimator]
+    if estimator in UNPATCHED_MODELS:
+        assert issubclass(
+            est, UNPATCHED_MODELS[estimator]
+        ), f"{estimator} does not inherit from the patched sklearn estimator"
+    else:
+        assert issubclass(est, BaseEstimator)
+
+
+@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
+def test_docstring_patching_match(estimator):
+    patched = PATCHED_MODELS[estimator]
+    unpatched = UNPATCHED_MODELS[estimator]
+    patched_docstrings = {
+        i: getattr(patched, i).__doc__
+        for i in dir(patched)
+        if not i.startswith("_") and not i.endswith("_") and hasattr(patched, i)
+    }
+    unpatched_docstrings = {
+        i: getattr(unpatched, i).__doc__
+        for i in dir(unpatched)
+        if not i.startswith("_") and not i.endswith("_") and hasattr(unpatched, i)
+    }
+
+    # check class docstring match if a docstring is available
+
+    assert (patched.__doc__ is None) == (unpatched.__doc__ is None)
+
+    # check class attribute docstrings
+
+    for i in unpatched_docstrings:
+        assert (patched_docstrings[i] is None) == (unpatched_docstrings[i] is None)
+
+
+@pytest.mark.parametrize("member", ["_onedal_cpu_supported", "_onedal_gpu_supported"])
+@pytest.mark.parametrize(
+    "name",
+    [i for i in PATCHED_MODELS.keys() if "sklearnex" in PATCHED_MODELS[i].__module__],
+)
+def test_onedal_supported_member(name, member):
+    patched = PATCHED_MODELS[name]
+    sig = str(inspect.signature(getattr(patched, member)))
+    assert "(self, method_name, *data)" == sig
sklearnex/utils/__init__.py (entry 51 above)
@@ -14,6 +14,7 @@
 # limitations under the License.
 # ===============================================================================
 
+from ._namespace import get_namespace
 from .validation import _assert_all_finite
 
-__all__ = ["_assert_all_finite"]
+__all__ = ["get_namespace", "_assert_all_finite"]
sklearnex/utils/_namespace.py (entry 52 above, new file)
@@ -0,0 +1,97 @@
+# ==============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+
+from daal4py.sklearn._utils import sklearn_check_version
+
+from .._device_offload import dpnp_available
+
+if sklearn_check_version("1.2"):
+    from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
+
+if dpnp_available:
+    import dpnp
+
+
+def get_namespace(*arrays):
+    """Get namespace of arrays.
+
+    Introspect `arrays` arguments and return their common Array API
+    compatible namespace object, if any. NumPy 1.22 and later can
+    construct such containers using the `numpy.array_api` namespace
+    for instance.
+
+    This function will return the namespace of SYCL-related arrays
+    which define the __sycl_usm_array_interface__ attribute
+    regardless of array_api support, the configuration of
+    array_api_dispatch, or scikit-learn version.
+
+    See: https://numpy.org/neps/nep-0047-array-api-standard.html
+
+    If `arrays` are regular numpy arrays, an instance of the
+    `_NumPyApiWrapper` compatibility wrapper is returned instead.
+
+    Namespace support is not enabled by default. To enabled it
+    call:
+
+        sklearn.set_config(array_api_dispatch=True)
+
+    or:
+
+        with sklearn.config_context(array_api_dispatch=True):
+            # your code here
+
+    Otherwise an instance of the `_NumPyApiWrapper`
+    compatibility wrapper is always returned irrespective of
+    the fact that arrays implement the `__array_namespace__`
+    protocol or not.
+
+    Parameters
+    ----------
+    *arrays : array objects
+        Array objects.
+
+    Returns
+    -------
+    namespace : module
+        Namespace shared by array objects.
+
+    is_array_api : bool
+        True of the arrays are containers that implement the Array API spec.
+    """
+
+    # sycl support designed to work regardless of array_api_dispatch sklearn global value
+    sycl_type = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")}
+
+    if len(sycl_type) > 1:
+        raise ValueError(f"Multiple SYCL types for array inputs: {sycl_type}")
+
+    if sycl_type:
+
+        (X,) = sycl_type.values()
+
+        if hasattr(X, "__array_namespace__"):
+            return X.__array_namespace__(), True
+        elif dpnp_available and isinstance(X, dpnp.ndarray):
+            return dpnp, False
+        else:
+            raise ValueError(f"SYCL type not recognized: {sycl_type}")
+
+    elif sklearn_check_version("1.2"):
+        return sklearn_get_namespace(*arrays)
+    else:
+        return np, True
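
The new get_namespace helper mirrors scikit-learn's array-API utility but also resolves the namespace of SYCL USM arrays independently of the array_api_dispatch setting. A short usage sketch; the commented dpctl branch assumes dpctl.tensor is installed:

import numpy as np

from sklearnex.utils import get_namespace

# Plain NumPy input: defers to scikit-learn's get_namespace on sklearn >= 1.2
# (returning its NumPy wrapper), or falls back to (numpy, True) on older versions.
xp, is_array_api = get_namespace(np.arange(10))

# SYCL USM input: the namespace comes from __array_namespace__ (or dpnp),
# regardless of sklearn's array_api_dispatch configuration.
# import dpctl.tensor as dpt
# xp, is_array_api = get_namespace(dpt.arange(10))  # -> (dpt namespace, True)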