scikit-learn-intelex 2024.1.0__py310-none-manylinux1_x86_64.whl → 2024.3.0__py310-none-manylinux1_x86_64.whl

This diff compares the contents of two publicly available package versions, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (51)
  1. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/METADATA +2 -2
  2. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/RECORD +45 -44
  3. sklearnex/__init__.py +9 -7
  4. sklearnex/cluster/dbscan.py +6 -4
  5. sklearnex/conftest.py +63 -0
  6. sklearnex/{preview/decomposition → covariance}/__init__.py +19 -19
  7. sklearnex/covariance/incremental_covariance.py +130 -0
  8. sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  9. sklearnex/decomposition/pca.py +322 -1
  10. sklearnex/decomposition/tests/test_pca.py +34 -5
  11. sklearnex/dispatcher.py +91 -59
  12. sklearnex/ensemble/_forest.py +15 -24
  13. sklearnex/ensemble/tests/test_forest.py +15 -19
  14. sklearnex/linear_model/__init__.py +1 -2
  15. sklearnex/linear_model/linear.py +3 -10
  16. sklearnex/{preview/linear_model → linear_model}/logistic_regression.py +32 -40
  17. sklearnex/linear_model/tests/test_logreg.py +70 -7
  18. sklearnex/neighbors/__init__.py +1 -1
  19. sklearnex/neighbors/_lof.py +204 -0
  20. sklearnex/neighbors/knn_classification.py +13 -18
  21. sklearnex/neighbors/knn_regression.py +12 -17
  22. sklearnex/neighbors/knn_unsupervised.py +10 -15
  23. sklearnex/neighbors/tests/test_neighbors.py +12 -16
  24. sklearnex/preview/__init__.py +1 -1
  25. sklearnex/preview/cluster/k_means.py +3 -8
  26. sklearnex/preview/covariance/covariance.py +46 -12
  27. sklearnex/spmd/__init__.py +1 -0
  28. sklearnex/{preview/linear_model → spmd/covariance}/__init__.py +5 -5
  29. sklearnex/spmd/covariance/covariance.py +21 -0
  30. sklearnex/spmd/ensemble/forest.py +4 -12
  31. sklearnex/spmd/linear_model/__init__.py +2 -1
  32. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  33. sklearnex/svm/nusvc.py +9 -6
  34. sklearnex/svm/nusvr.py +6 -7
  35. sklearnex/svm/svc.py +9 -6
  36. sklearnex/svm/svr.py +3 -4
  37. sklearnex/tests/_utils.py +155 -0
  38. sklearnex/tests/test_memory_usage.py +9 -7
  39. sklearnex/tests/test_monkeypatch.py +179 -138
  40. sklearnex/tests/test_n_jobs_support.py +71 -9
  41. sklearnex/tests/test_parallel.py +6 -8
  42. sklearnex/tests/test_patching.py +321 -82
  43. sklearnex/neighbors/lof.py +0 -436
  44. sklearnex/preview/decomposition/pca.py +0 -376
  45. sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -42
  46. sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  47. sklearnex/tests/_models_info.py +0 -170
  48. sklearnex/tests/utils/_launch_algorithms.py +0 -118
  49. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/LICENSE.txt +0 -0
  50. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/WHEEL +0 -0
  51. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/top_level.txt +0 -0
sklearnex/tests/test_monkeypatch.py
@@ -17,6 +17,12 @@
17
17
  import sklearnex
18
18
  from daal4py.sklearn._utils import daal_check_version
19
19
 
20
+ # General use of patch_sklearn and unpatch_sklearn in pytest is not recommended.
21
+ # It changes global state and can impact the operation of other tests. This file
22
+ # specifically tests patch_sklearn and unpatch_sklearn and is exempt from this.
23
+ # If sklearnex patching is necessary in testing, use the 'with_sklearnex' pytest
24
+ # fixture.
25
+
20
26
 
21
27
  def test_monkey_patching():
22
28
  _tokens = sklearnex.get_patch_names()
@@ -27,129 +33,170 @@ def test_monkey_patching():
27
33
  for c in v:
28
34
  _classes.append(c[0])
29
35
 
30
- sklearnex.patch_sklearn()
31
-
32
- for i, _ in enumerate(_tokens):
33
- t = _tokens[i]
34
- p = _classes[i][0]
35
- n = _classes[i][1]
36
-
37
- class_module = getattr(p, n).__module__
38
- assert class_module.startswith("daal4py") or class_module.startswith(
39
- "sklearnex"
40
- ), "Patching has completed with error."
41
-
42
- for i, _ in enumerate(_tokens):
43
- t = _tokens[i]
44
- p = _classes[i][0]
45
- n = _classes[i][1]
46
-
47
- sklearnex.unpatch_sklearn(t)
48
- class_module = getattr(p, n).__module__
49
- assert class_module.startswith("sklearn"), "Unpatching has completed with error."
50
-
51
- sklearnex.unpatch_sklearn()
52
-
53
- for i, _ in enumerate(_tokens):
54
- t = _tokens[i]
55
- p = _classes[i][0]
56
- n = _classes[i][1]
57
-
58
- class_module = getattr(p, n).__module__
59
- assert class_module.startswith("sklearn"), "Unpatching has completed with error."
60
-
61
- sklearnex.unpatch_sklearn()
62
-
63
- for i, _ in enumerate(_tokens):
64
- t = _tokens[i]
65
- p = _classes[i][0]
66
- n = _classes[i][1]
67
-
68
- sklearnex.patch_sklearn(t)
69
-
70
- class_module = getattr(p, n).__module__
71
- assert class_module.startswith("daal4py") or class_module.startswith(
72
- "sklearnex"
73
- ), "Patching has completed with error."
74
-
75
- sklearnex.unpatch_sklearn()
36
+ try:
37
+ sklearnex.patch_sklearn()
38
+
39
+ for i, _ in enumerate(_tokens):
40
+ t = _tokens[i]
41
+ p = _classes[i][0]
42
+ n = _classes[i][1]
43
+
44
+ class_module = getattr(p, n).__module__
45
+ assert class_module.startswith("daal4py") or class_module.startswith(
46
+ "sklearnex"
47
+ ), "Patching has completed with error."
48
+
49
+ for i, _ in enumerate(_tokens):
50
+ t = _tokens[i]
51
+ p = _classes[i][0]
52
+ n = _classes[i][1]
53
+
54
+ sklearnex.unpatch_sklearn(t)
55
+ sklearn_class = getattr(p, n, None)
56
+ if sklearn_class is not None:
57
+ sklearn_class = sklearn_class.__module__
58
+ assert sklearn_class is None or sklearn_class.startswith(
59
+ "sklearn"
60
+ ), "Unpatching has completed with error."
61
+
62
+ finally:
63
+ sklearnex.unpatch_sklearn()
64
+
65
+ try:
66
+ for i, _ in enumerate(_tokens):
67
+ t = _tokens[i]
68
+ p = _classes[i][0]
69
+ n = _classes[i][1]
70
+
71
+ sklearn_class = getattr(p, n, None)
72
+ if sklearn_class is not None:
73
+ sklearn_class = sklearn_class.__module__
74
+ assert sklearn_class is None or sklearn_class.startswith(
75
+ "sklearn"
76
+ ), "Unpatching has completed with error."
77
+
78
+ finally:
79
+ sklearnex.unpatch_sklearn()
80
+
81
+ try:
82
+ for i, _ in enumerate(_tokens):
83
+ t = _tokens[i]
84
+ p = _classes[i][0]
85
+ n = _classes[i][1]
86
+
87
+ sklearnex.patch_sklearn(t)
88
+
89
+ class_module = getattr(p, n).__module__
90
+ assert class_module.startswith("daal4py") or class_module.startswith(
91
+ "sklearnex"
92
+ ), "Patching has completed with error."
93
+ finally:
94
+ sklearnex.unpatch_sklearn()
76
95
 
77
96
 
78
97
  def test_patch_by_list_simple():
79
- sklearnex.patch_sklearn(["LogisticRegression"])
98
+ try:
99
+ sklearnex.patch_sklearn(["LogisticRegression"])
80
100
 
81
- from sklearn.ensemble import RandomForestRegressor
82
- from sklearn.linear_model import LogisticRegression
83
- from sklearn.neighbors import KNeighborsRegressor
84
- from sklearn.svm import SVC
85
-
86
- assert RandomForestRegressor.__module__.startswith("sklearn")
87
- assert KNeighborsRegressor.__module__.startswith("sklearn")
88
- assert LogisticRegression.__module__.startswith("daal4py")
89
- assert SVC.__module__.startswith("sklearn")
101
+ from sklearn.ensemble import RandomForestRegressor
102
+ from sklearn.linear_model import LogisticRegression
103
+ from sklearn.neighbors import KNeighborsRegressor
104
+ from sklearn.svm import SVC
90
105
 
91
- sklearnex.unpatch_sklearn()
106
+ assert RandomForestRegressor.__module__.startswith("sklearn")
107
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
108
+ if daal_check_version((2024, "P", 1)):
109
+ assert LogisticRegression.__module__.startswith("sklearnex")
110
+ else:
111
+ assert LogisticRegression.__module__.startswith("daal4py")
112
+ assert SVC.__module__.startswith("sklearn")
113
+ finally:
114
+ sklearnex.unpatch_sklearn()
92
115
 
93
116
 
94
117
  def test_patch_by_list_many_estimators():
95
- sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
118
+ try:
119
+ sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
96
120
 
97
- from sklearn.ensemble import RandomForestRegressor
98
- from sklearn.linear_model import LogisticRegression
99
- from sklearn.neighbors import KNeighborsRegressor
100
- from sklearn.svm import SVC
121
+ from sklearn.ensemble import RandomForestRegressor
122
+ from sklearn.linear_model import LogisticRegression
123
+ from sklearn.neighbors import KNeighborsRegressor
124
+ from sklearn.svm import SVC
101
125
 
102
- assert RandomForestRegressor.__module__.startswith("sklearn")
103
- assert KNeighborsRegressor.__module__.startswith("sklearn")
104
- assert LogisticRegression.__module__.startswith("daal4py")
105
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
126
+ assert RandomForestRegressor.__module__.startswith("sklearn")
127
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
128
+ if daal_check_version((2024, "P", 1)):
129
+ assert LogisticRegression.__module__.startswith("sklearnex")
130
+ else:
131
+ assert LogisticRegression.__module__.startswith("daal4py")
132
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
133
+ "sklearnex"
134
+ )
106
135
 
107
- sklearnex.unpatch_sklearn()
136
+ finally:
137
+ sklearnex.unpatch_sklearn()
108
138
 
109
139
 
110
140
  def test_unpatch_by_list_many_estimators():
111
- sklearnex.patch_sklearn()
141
+ try:
142
+ sklearnex.patch_sklearn()
143
+
144
+ from sklearn.ensemble import RandomForestRegressor
145
+ from sklearn.linear_model import LogisticRegression
146
+ from sklearn.neighbors import KNeighborsRegressor
147
+ from sklearn.svm import SVC
112
148
 
113
- from sklearn.ensemble import RandomForestRegressor
114
- from sklearn.linear_model import LogisticRegression
115
- from sklearn.neighbors import KNeighborsRegressor
116
- from sklearn.svm import SVC
149
+ assert RandomForestRegressor.__module__.startswith("sklearnex")
150
+ assert KNeighborsRegressor.__module__.startswith(
151
+ "daal4py"
152
+ ) or KNeighborsRegressor.__module__.startswith("sklearnex")
153
+ if daal_check_version((2024, "P", 1)):
154
+ assert LogisticRegression.__module__.startswith("sklearnex")
155
+ else:
156
+ assert LogisticRegression.__module__.startswith("daal4py")
157
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
158
+ "sklearnex"
159
+ )
117
160
 
118
- assert RandomForestRegressor.__module__.startswith("sklearnex")
119
- assert KNeighborsRegressor.__module__.startswith(
120
- "daal4py"
121
- ) or KNeighborsRegressor.__module__.startswith("sklearnex")
122
- assert LogisticRegression.__module__.startswith("daal4py")
123
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
161
+ sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
124
162
 
125
- sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
163
+ from sklearn.ensemble import RandomForestRegressor
164
+ from sklearn.linear_model import LogisticRegression
165
+ from sklearn.neighbors import KNeighborsRegressor
166
+ from sklearn.svm import SVC
126
167
 
127
- from sklearn.ensemble import RandomForestRegressor
128
- from sklearn.linear_model import LogisticRegression
129
- from sklearn.neighbors import KNeighborsRegressor
130
- from sklearn.svm import SVC
168
+ assert RandomForestRegressor.__module__.startswith("sklearn")
169
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
170
+ if daal_check_version((2024, "P", 1)):
171
+ assert LogisticRegression.__module__.startswith("sklearnex")
172
+ else:
173
+ assert LogisticRegression.__module__.startswith("daal4py")
131
174
 
132
- assert RandomForestRegressor.__module__.startswith("sklearn")
133
- assert KNeighborsRegressor.__module__.startswith("sklearn")
134
- assert LogisticRegression.__module__.startswith("daal4py")
135
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
175
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
176
+ "sklearnex"
177
+ )
178
+ finally:
179
+ sklearnex.unpatch_sklearn()
136
180
 
137
181
 
138
182
  def test_patching_checker():
139
183
  for name in [None, "SVC", "PCA"]:
140
- sklearnex.patch_sklearn(name=name)
141
- assert sklearnex.sklearn_is_patched(name=name)
142
-
143
- sklearnex.unpatch_sklearn(name=name)
144
- assert not sklearnex.sklearn_is_patched(name=name)
145
-
146
- sklearnex.patch_sklearn()
147
- patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
148
- assert len(patching_status_map) == len(sklearnex.get_patch_names())
149
- for status in patching_status_map.values():
150
- assert status
184
+ try:
185
+ sklearnex.patch_sklearn(name=name)
186
+ assert sklearnex.sklearn_is_patched(name=name)
187
+
188
+ finally:
189
+ sklearnex.unpatch_sklearn(name=name)
190
+ assert not sklearnex.sklearn_is_patched(name=name)
191
+ try:
192
+ sklearnex.patch_sklearn()
193
+ patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
194
+ assert len(patching_status_map) == len(sklearnex.get_patch_names())
195
+ for status in patching_status_map.values():
196
+ assert status
197
+ finally:
198
+ sklearnex.unpatch_sklearn()
151
199
 
152
- sklearnex.unpatch_sklearn()
153
200
  patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
154
201
  assert len(patching_status_map) == len(sklearnex.get_patch_names())
155
202
  for status in patching_status_map.values():
@@ -161,67 +208,61 @@ def test_preview_namespace():
161
208
  from sklearn.cluster import DBSCAN
162
209
  from sklearn.decomposition import PCA
163
210
  from sklearn.ensemble import RandomForestClassifier
164
- from sklearn.linear_model import LinearRegression, LogisticRegression
211
+ from sklearn.linear_model import LinearRegression
165
212
  from sklearn.svm import SVC
166
213
 
167
214
  return (
168
215
  LinearRegression(),
169
- LogisticRegression(),
170
216
  PCA(),
171
217
  DBSCAN(),
172
218
  SVC(),
173
219
  RandomForestClassifier(),
174
220
  )
175
221
 
176
- # BUG: previous patching tests force PCA to be patched with daal4py.
177
- # This unpatching returns behavior to expected
178
- sklearnex.unpatch_sklearn()
179
- # behavior with enabled preview
180
- sklearnex.patch_sklearn(preview=True)
181
222
  from sklearnex.dispatcher import _is_preview_enabled
182
223
 
183
- assert _is_preview_enabled()
224
+ try:
225
+ sklearnex.patch_sklearn(preview=True)
226
+
227
+ assert _is_preview_enabled()
184
228
 
185
- lr, log_reg, pca, dbscan, svc, rfc = get_estimators()
186
- assert "sklearnex" in rfc.__module__
229
+ lr, pca, dbscan, svc, rfc = get_estimators()
230
+ assert "sklearnex" in rfc.__module__
187
231
 
188
- if daal_check_version((2023, "P", 100)):
189
- assert "sklearnex" in lr.__module__
190
- else:
191
- assert "daal4py" in lr.__module__
232
+ if daal_check_version((2023, "P", 100)):
233
+ assert "sklearnex" in lr.__module__
234
+ else:
235
+ assert "daal4py" in lr.__module__
192
236
 
193
- if daal_check_version((2024, "P", 1)):
194
- assert "sklearnex" in log_reg.__module__
195
- else:
196
- assert "daal4py" in log_reg.__module__
237
+ assert "sklearnex" in pca.__module__
238
+ assert "sklearnex" in dbscan.__module__
239
+ assert "sklearnex" in svc.__module__
197
240
 
198
- assert "sklearnex.preview" in pca.__module__
199
- assert "sklearnex" in dbscan.__module__
200
- assert "sklearnex" in svc.__module__
201
- sklearnex.unpatch_sklearn()
241
+ finally:
242
+ sklearnex.unpatch_sklearn()
202
243
 
203
244
  # no patching behavior
204
- lr, log_reg, pca, dbscan, svc, rfc = get_estimators()
245
+ lr, pca, dbscan, svc, rfc = get_estimators()
205
246
  assert "sklearn." in lr.__module__ and "daal4py" not in lr.__module__
206
- assert "sklearn." in log_reg.__module__ and "daal4py" not in log_reg.__module__
207
247
  assert "sklearn." in pca.__module__ and "daal4py" not in pca.__module__
208
248
  assert "sklearn." in dbscan.__module__ and "daal4py" not in dbscan.__module__
209
249
  assert "sklearn." in svc.__module__ and "daal4py" not in svc.__module__
210
250
  assert "sklearn." in rfc.__module__ and "daal4py" not in rfc.__module__
211
251
 
212
252
  # default patching behavior
213
- sklearnex.patch_sklearn()
214
- assert not _is_preview_enabled()
215
-
216
- lr, log_reg, pca, dbscan, svc, rfc = get_estimators()
217
- if daal_check_version((2023, "P", 100)):
218
- assert "sklearnex" in lr.__module__
219
- else:
220
- assert "daal4py" in lr.__module__
221
-
222
- assert "daal4py" in log_reg.__module__
223
- assert "daal4py" in pca.__module__
224
- assert "sklearnex" in rfc.__module__
225
- assert "sklearnex" in dbscan.__module__
226
- assert "sklearnex" in svc.__module__
227
- sklearnex.unpatch_sklearn()
253
+ try:
254
+ sklearnex.patch_sklearn()
255
+ assert not _is_preview_enabled()
256
+
257
+ lr, pca, dbscan, svc, rfc = get_estimators()
258
+ if daal_check_version((2023, "P", 100)):
259
+ assert "sklearnex" in lr.__module__
260
+ else:
261
+ assert "daal4py" in lr.__module__
262
+
263
+ assert "sklearnex" in pca.__module__
264
+ assert "sklearnex" in rfc.__module__
265
+ assert "sklearnex" in dbscan.__module__
266
+ assert "sklearnex" in svc.__module__
267
+ finally:
268
+ sklearnex.unpatch_sklearn()
sklearnex/tests/test_n_jobs_support.py
@@ -14,18 +14,80 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
+ import inspect
18
+ import logging
19
+ from multiprocessing import cpu_count
20
+
17
21
  import pytest
22
+ from sklearn.base import BaseEstimator
23
+ from sklearn.datasets import make_classification
24
+
25
+ from sklearnex.dispatcher import get_patch_map
26
+ from sklearnex.svm import SVC, NuSVC
27
+
28
+ ESTIMATORS = set(
29
+ filter(
30
+ lambda x: inspect.isclass(x) and issubclass(x, BaseEstimator),
31
+ [value[0][0][2] for value in get_patch_map().values()],
32
+ )
33
+ )
34
+
35
+ X, Y = make_classification(n_samples=40, n_features=4, random_state=42)
36
+
37
+
38
+ @pytest.mark.parametrize("estimator_class", ESTIMATORS)
39
+ @pytest.mark.parametrize("n_jobs", [None, -1, 1, 2])
40
+ def test_n_jobs_support(caplog, estimator_class, n_jobs):
41
+ def check_estimator_doc(estimator):
42
+ if estimator.__doc__ is not None:
43
+ assert "n_jobs" in estimator.__doc__
44
+
45
+ def check_n_jobs_entry_in_logs(caplog, function_name, n_jobs):
46
+ for rec in caplog.records:
47
+ if function_name in rec.message and "threads" in rec.message:
48
+ expected_n_jobs = n_jobs if n_jobs > 0 else cpu_count() + 1 + n_jobs
49
+ logging.info(f"{function_name}: setting {expected_n_jobs} threads")
50
+ if f"{function_name}: setting {expected_n_jobs} threads" in rec.message:
51
+ return True
52
+ # False if n_jobs is set and not found in logs
53
+ return n_jobs is None
18
54
 
19
- from sklearnex.cluster import KMeans
20
- from sklearnex.linear_model import ElasticNet, Lasso, Ridge
21
- from sklearnex.svm import SVC, SVR, NuSVC, NuSVR
55
+ def check_method(*args, method, caplog):
56
+ method(*args)
57
+ assert check_n_jobs_entry_in_logs(caplog, method.__name__, n_jobs)
22
58
 
23
- estimators = [KMeans, SVC, SVR, NuSVC, NuSVR, Lasso, Ridge, ElasticNet]
59
+ def check_methods_decoration(estimator):
60
+ funcs = {
61
+ i: getattr(estimator, i)
62
+ for i in dir(estimator)
63
+ if hasattr(estimator, i) and callable(getattr(estimator, i))
64
+ }
24
65
 
66
+ for func_name, func in funcs.items():
67
+ assert hasattr(func, "__onedal_n_jobs_decorated__") == (
68
+ func_name in estimator._n_jobs_supported_onedal_methods
69
+ ), f"{estimator}.{func_name} n_jobs decoration does not match {estimator} n_jobs supported methods"
25
70
 
26
- @pytest.mark.parametrize("estimator", estimators)
27
- def test_n_jobs_support(estimator):
28
- # use `n_jobs` parameter where original sklearn doesn't expect it
29
- estimator(n_jobs=1)
71
+ caplog.set_level(logging.DEBUG, logger="sklearnex")
72
+ estimator_kwargs = {"n_jobs": n_jobs}
73
+ # by default, [Nu]SVC.predict_proba is restricted by @available_if decorator
74
+ if estimator_class in [SVC, NuSVC]:
75
+ estimator_kwargs["probability"] = True
76
+ estimator_instance = estimator_class(**estimator_kwargs)
30
77
  # check `n_jobs` parameter doc entry
31
- assert "n_jobs" in estimator.__doc__
78
+ check_estimator_doc(estimator_class)
79
+ check_estimator_doc(estimator_instance)
80
+ # check `n_jobs` log entry for supported methods
81
+ # `fit` call is required before other methods
82
+ check_method(X, Y, method=estimator_instance.fit, caplog=caplog)
83
+ for method_name in estimator_instance._n_jobs_supported_onedal_methods:
84
+ if method_name == "fit":
85
+ continue
86
+ method = getattr(estimator_instance, method_name)
87
+ if len(inspect.signature(method).parameters) == 0:
88
+ check_method(method=method, caplog=caplog)
89
+ else:
90
+ check_method(X, method=method, caplog=caplog)
91
+ # check if correct methods were decorated
92
+ check_methods_decoration(estimator_class)
93
+ check_methods_decoration(estimator_instance)
sklearnex/tests/test_parallel.py
@@ -15,13 +15,7 @@
15
15
  # ==============================================================================
16
16
  import pytest
17
17
 
18
- from sklearnex import config_context, patch_sklearn
19
-
20
- patch_sklearn()
21
-
22
- from sklearn.datasets import make_classification
23
- from sklearn.ensemble import BaggingClassifier
24
- from sklearn.svm import SVC
18
+ from sklearnex import config_context
25
19
 
26
20
  try:
27
21
  import dpctl
@@ -38,7 +32,11 @@ except (ImportError, ModuleNotFoundError):
38
32
  "to see raised 'SyclQueueCreationError'. "
39
33
  "'dpctl' module is required for test.",
40
34
  )
41
- def test_config_context_in_parallel():
35
+ def test_config_context_in_parallel(with_sklearnex):
36
+ from sklearn.datasets import make_classification
37
+ from sklearn.ensemble import BaggingClassifier
38
+ from sklearn.svm import SVC
39
+
42
40
  x, y = make_classification(random_state=42)
43
41
  try:
44
42
  with config_context(target_offload="gpu", allow_fallback_to_host=False):