scikit-learn-intelex 2024.1.0__py311-none-manylinux1_x86_64.whl → 2024.4.0__py311-none-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the registry's listing for this release for more details.

Files changed (62)
  1. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
  2. scikit_learn_intelex-2024.4.0.dist-info/RECORD +101 -0
  3. sklearnex/__init__.py +9 -7
  4. sklearnex/_device_offload.py +31 -4
  5. sklearnex/basic_statistics/__init__.py +2 -1
  6. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  7. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
  8. sklearnex/cluster/dbscan.py +6 -4
  9. sklearnex/conftest.py +63 -0
  10. sklearnex/{preview/decomposition → covariance}/__init__.py +19 -19
  11. sklearnex/covariance/incremental_covariance.py +130 -0
  12. sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  13. sklearnex/decomposition/pca.py +319 -1
  14. sklearnex/decomposition/tests/test_pca.py +34 -5
  15. sklearnex/dispatcher.py +93 -61
  16. sklearnex/ensemble/_forest.py +81 -97
  17. sklearnex/ensemble/tests/test_forest.py +15 -19
  18. sklearnex/linear_model/__init__.py +1 -2
  19. sklearnex/linear_model/linear.py +275 -347
  20. sklearnex/{preview/linear_model → linear_model}/logistic_regression.py +83 -50
  21. sklearnex/linear_model/tests/test_linear.py +40 -5
  22. sklearnex/linear_model/tests/test_logreg.py +70 -7
  23. sklearnex/neighbors/__init__.py +1 -1
  24. sklearnex/neighbors/_lof.py +221 -0
  25. sklearnex/neighbors/common.py +4 -1
  26. sklearnex/neighbors/knn_classification.py +47 -137
  27. sklearnex/neighbors/knn_regression.py +20 -132
  28. sklearnex/neighbors/knn_unsupervised.py +16 -93
  29. sklearnex/neighbors/tests/test_neighbors.py +12 -16
  30. sklearnex/preview/__init__.py +1 -1
  31. sklearnex/preview/cluster/k_means.py +8 -81
  32. sklearnex/preview/covariance/covariance.py +51 -16
  33. sklearnex/preview/covariance/tests/test_covariance.py +18 -5
  34. sklearnex/spmd/__init__.py +1 -0
  35. sklearnex/{preview/linear_model → spmd/covariance}/__init__.py +5 -5
  36. sklearnex/spmd/covariance/covariance.py +21 -0
  37. sklearnex/spmd/ensemble/forest.py +4 -12
  38. sklearnex/spmd/linear_model/__init__.py +2 -1
  39. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  40. sklearnex/svm/_common.py +4 -7
  41. sklearnex/svm/nusvc.py +74 -55
  42. sklearnex/svm/nusvr.py +9 -56
  43. sklearnex/svm/svc.py +74 -56
  44. sklearnex/svm/svr.py +6 -53
  45. sklearnex/tests/_utils.py +164 -0
  46. sklearnex/tests/test_memory_usage.py +9 -7
  47. sklearnex/tests/test_monkeypatch.py +179 -138
  48. sklearnex/tests/test_n_jobs_support.py +77 -9
  49. sklearnex/tests/test_parallel.py +6 -8
  50. sklearnex/tests/test_patching.py +338 -89
  51. sklearnex/utils/__init__.py +2 -1
  52. sklearnex/utils/_namespace.py +97 -0
  53. scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
  54. sklearnex/neighbors/lof.py +0 -436
  55. sklearnex/preview/decomposition/pca.py +0 -376
  56. sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -42
  57. sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  58. sklearnex/tests/_models_info.py +0 -170
  59. sklearnex/tests/utils/_launch_algorithms.py +0 -118
  60. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
  61. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
  62. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,164 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from inspect import isclass
18
+
19
+ import numpy as np
20
+ from sklearn.base import (
21
+ BaseEstimator,
22
+ ClassifierMixin,
23
+ ClusterMixin,
24
+ OutlierMixin,
25
+ RegressorMixin,
26
+ TransformerMixin,
27
+ )
28
+ from sklearn.datasets import load_diabetes, load_iris
29
+ from sklearn.neighbors._base import KNeighborsMixin
30
+
31
+ from onedal.tests.utils._dataframes_support import _convert_to_dataframe
32
+ from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
33
+ from sklearnex.linear_model import LogisticRegression
34
+ from sklearnex.neighbors import (
35
+ KNeighborsClassifier,
36
+ KNeighborsRegressor,
37
+ LocalOutlierFactor,
38
+ NearestNeighbors,
39
+ )
40
+ from sklearnex.svm import SVC, NuSVC
41
+
42
+
43
def _load_all_models(with_sklearnex=True, estimator=True):
    """Collect the objects named in the sklearnex patch map, keyed by name.

    While the objects are looked up, the sklearn patch state is forced to
    match ``with_sklearnex``. Afterwards it is restored to the set of patches
    that were active on entry, even if the lookup raises.

    Parameters
    ----------
    with_sklearnex : bool
        If True, look the objects up with sklearn fully patched (via
        ``patch_sklearn()``); otherwise look them up with sklearn unpatched.
    estimator : bool
        If True, keep only classes deriving from ``BaseEstimator``.
        If False, keep only objects that are not classes.

    Returns
    -------
    dict
        Maps each patched attribute name to the object currently bound to
        that name on its owner.
    """
    previous_state = sklearn_is_patched(return_map=True)
    was_patched = any(previous_state.values())
    try:
        if with_sklearnex:
            patch_sklearn()
        elif was_patched:
            unpatch_sklearn()

        found = {}
        for patch_infos in get_patch_map().values():
            # The leading (owner, attribute name) pair of the first entry
            # locates the object; the owner is presumably the sklearn module.
            owner, attr_name = patch_infos[0][0][0], patch_infos[0][0][1]
            obj = getattr(owner, attr_name, None)
            if obj is None or isclass(obj) != estimator:
                continue
            if estimator and not issubclass(obj, BaseEstimator):
                continue
            found[attr_name] = obj
    finally:
        if with_sklearnex:
            unpatch_sklearn()
        # sklearn is now unpatched on every path; re-apply only the patches
        # that were active on entry.
        # NOTE(review): only the patched names are restored -- options such
        # as ``preview`` given to the caller's original patch_sklearn() call
        # are not reapplied. Confirm this is acceptable.
        if was_patched:
            patch_sklearn(
                name=[name for name, active in previous_state.items() if active]
            )

    return found
68
+
69
+
70
# Name -> object maps built by _load_all_models: estimator classes (first
# pair) and non-class objects (second pair), each captured once with sklearn
# patched by sklearnex and once unpatched.
PATCHED_MODELS = _load_all_models(with_sklearnex=True)
UNPATCHED_MODELS = _load_all_models(with_sklearnex=False)

PATCHED_FUNCTIONS = _load_all_models(with_sklearnex=True, estimator=False)
UNPATCHED_FUNCTIONS = _load_all_models(with_sklearnex=False, estimator=False)

# Each entry is [mixin class, candidate public method names, dataset kind].
# gen_models_info uses the method names; gen_dataset uses the dataset kind,
# where None means "no preference".
mixin_map = [
    [
        ClassifierMixin,
        ["decision_function", "predict", "predict_proba", "predict_log_proba", "score"],
        "classification",
    ],
    [RegressorMixin, ["predict", "score"], "regression"],
    [ClusterMixin, ["fit_predict"], "classification"],
    [TransformerMixin, ["fit_transform", "transform", "score"], "classification"],
    [OutlierMixin, ["fit_predict", "predict"], "classification"],
    [KNeighborsMixin, ["kneighbors"], None],
]


# Estimator instances configured with non-default constructor arguments,
# keyed by their string representation.
SPECIAL_INSTANCES = {
    str(special): special
    for special in (
        LocalOutlierFactor(novelty=True),
        SVC(probability=True),
        NuSVC(probability=True),
        KNeighborsClassifier(algorithm="brute"),
        KNeighborsRegressor(algorithm="brute"),
        NearestNeighbors(algorithm="brute"),
        LogisticRegression(solver="newton-cg"),
    )
}
102
+
103
+
104
def gen_models_info(algorithms):
    """Expand estimator names into ``[name, method]`` test parameters.

    Each name is looked up in PATCHED_MODELS first and then in
    SPECIAL_INSTANCES, whose keys are ``str(instance)``. For every
    ``mixin_map`` entry whose mixin the estimator class inherits from, the
    method names listed there that the class exposes are collected. Only
    attributes with neither a leading nor a trailing underscore count.

    Parameters
    ----------
    algorithms : iterable of str
        Keys of PATCHED_MODELS or SPECIAL_INSTANCES.

    Returns
    -------
    list of [str, str or None]
        One ``[name, method]`` pair per collected method, with methods in
        sorted order. This keeps the result, and any pytest parametrization
        built from it, identical across interpreter processes. If no method
        is collected, a single ``[name, None]`` pair is emitted so that
        estimators without a known mixin can still have ``fit`` tested.

    Raises
    ------
    KeyError
        If a name is in neither PATCHED_MODELS nor SPECIAL_INSTANCES.
    """
    output = []
    for name in algorithms:
        if name in PATCHED_MODELS:
            est = PATCHED_MODELS[name]
        elif name in SPECIAL_INSTANCES:
            est = SPECIAL_INSTANCES[name].__class__
        else:
            raise KeyError(f"Unrecognized sklearnex estimator: {name}")

        candidates = {
            attr
            for attr in dir(est)
            if not attr.startswith("_") and not attr.endswith("_")
        }
        methods = set()
        for mixin, mixin_methods, _ in mixin_map:
            if issubclass(est, mixin):
                methods |= candidates & set(mixin_methods)

        # Iteration order of a set of strings depends on hash randomization
        # (PYTHONHASHSEED), so sort to keep the output order stable across
        # processes (e.g. pytest-xdist workers must collect identical tests).
        if methods:
            output.extend([name, method] for method in sorted(methods))
        else:
            output.append([name, None])
    return output
130
+
131
+
132
def gen_dataset(estimator, queue=None, target_df=None, dtype=np.float64):
    """Load a small toy dataset matching ``estimator`` and convert it.

    The estimator's class name is looked up in PATCHED_MODELS. The dataset
    kind is taken from the last ``mixin_map`` entry that meets two
    conditions: the class inherits from that entry's mixin, and the entry's
    kind is not None.

    - ``"regression"`` loads the diabetes dataset.
    - ``"classification"``, or no matching entry, loads iris.

    Parameters
    ----------
    estimator : object
        Estimator instance; only its class name is used.
    queue : optional
        Forwarded to ``_convert_to_dataframe`` as ``sycl_queue``.
    target_df : optional
        Forwarded to ``_convert_to_dataframe``.
    dtype : default np.float64
        Forwarded to ``_convert_to_dataframe``.

    Returns
    -------
    tuple
        ``(X, y)``, each as returned by ``_convert_to_dataframe``.

    Raises
    ------
    KeyError
        If the estimator's class name is not in PATCHED_MODELS.
    ValueError
        If the selected dataset kind is neither "classification" nor
        "regression".
    """
    est = PATCHED_MODELS[estimator.__class__.__name__]

    matching_kinds = [
        kind
        for mixin, _, kind in mixin_map
        if issubclass(est, mixin) and kind is not None
    ]
    # later entries take precedence over earlier ones
    dataset = matching_kinds[-1] if matching_kinds else None

    if dataset == "regression":
        X, y = load_diabetes(return_X_y=True)
    elif dataset is None or dataset == "classification":
        X, y = load_iris(return_X_y=True)
    else:
        raise ValueError("Unknown dataset type")

    X, y = (
        _convert_to_dataframe(arr, sycl_queue=queue, target_df=target_df, dtype=dtype)
        for arr in (X, y)
    )
    return X, y
150
+
151
+
152
# numpy scalar types, in order: signed integers, floating point, then
# unsigned integers.
DTYPES = (
    [np.int8, np.int16, np.int32, np.int64]
    + [np.float16, np.float32, np.float64]
    + [np.uint8, np.uint16, np.uint32, np.uint64]
)
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
+
17
18
  import gc
18
19
  import logging
19
20
  import tracemalloc
@@ -30,7 +31,6 @@ from sklearn.model_selection import KFold
30
31
  from sklearnex import get_patch_map
31
32
  from sklearnex.metrics import pairwise_distances, roc_auc_score
32
33
  from sklearnex.model_selection import train_test_split
33
- from sklearnex.preview.decomposition import PCA as PreviewPCA
34
34
  from sklearnex.utils import _assert_all_finite
35
35
 
36
36
 
@@ -75,6 +75,8 @@ class RocAucEstimator:
75
75
 
76
76
 
77
77
  # add all daal4py estimators enabled in patching (except banned)
78
+
79
+
78
80
  def get_patched_estimators(ban_list, output_list):
79
81
  patched_estimators = get_patch_map().values()
80
82
  for listing in patched_estimators:
@@ -94,12 +96,8 @@ def remove_duplicated_estimators(estimators_list):
94
96
  return estimators_map.values()
95
97
 
96
98
 
97
- BANNED_ESTIMATORS = (
98
- "LocalOutlierFactor", # fails on ndarray_c for sklearn > 1.0
99
- "TSNE", # too slow for using in testing on common data size
100
- )
99
+ BANNED_ESTIMATORS = ("TSNE",) # too slow for using in testing on common data size
101
100
  estimators = [
102
- PreviewPCA,
103
101
  TrainTestSplitEstimator,
104
102
  FiniteCheckEstimator,
105
103
  CosineDistancesEstimator,
@@ -156,6 +154,7 @@ def split_train_inference(kf, x, y, estimator):
156
154
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]
157
155
  # TODO: add parameters for all estimators to prevent
158
156
  # fallback to stock scikit-learn with default parameters
157
+
159
158
  alg = estimator()
160
159
  alg.fit(x_train, y_train)
161
160
  if hasattr(alg, "predict"):
@@ -166,7 +165,6 @@ def split_train_inference(kf, x, y, estimator):
166
165
  alg.kneighbors(x_test)
167
166
  del alg, x_train, x_test, y_train, y_test
168
167
  mem_tracks.append(tracemalloc.get_traced_memory()[0])
169
-
170
168
  return mem_tracks
171
169
 
172
170
 
@@ -218,6 +216,10 @@ def _kfold_function_template(estimator, data_transform_function, data_shape):
218
216
  )
219
217
 
220
218
 
219
+ # disable fallback check as logging impacts memory use
220
+
221
+
222
+ @pytest.mark.allow_sklearn_fallback
221
223
  @pytest.mark.parametrize("data_transform_function", data_transforms)
222
224
  @pytest.mark.parametrize("estimator", estimators)
223
225
  @pytest.mark.parametrize("data_shape", data_shapes)
@@ -17,6 +17,12 @@
17
17
  import sklearnex
18
18
  from daal4py.sklearn._utils import daal_check_version
19
19
 
20
+ # General use of patch_sklearn and unpatch_sklearn in pytest is not recommended.
21
+ # It changes global state and can impact the operation of other tests. This file
22
+ # specifically tests patch_sklearn and unpatch_sklearn and is exempt from this.
23
+ # If sklearnex patching is necessary in testing, use the 'with_sklearnex' pytest
24
+ # fixture.
25
+
20
26
 
21
27
  def test_monkey_patching():
22
28
  _tokens = sklearnex.get_patch_names()
@@ -27,129 +33,170 @@ def test_monkey_patching():
27
33
  for c in v:
28
34
  _classes.append(c[0])
29
35
 
30
- sklearnex.patch_sklearn()
31
-
32
- for i, _ in enumerate(_tokens):
33
- t = _tokens[i]
34
- p = _classes[i][0]
35
- n = _classes[i][1]
36
-
37
- class_module = getattr(p, n).__module__
38
- assert class_module.startswith("daal4py") or class_module.startswith(
39
- "sklearnex"
40
- ), "Patching has completed with error."
41
-
42
- for i, _ in enumerate(_tokens):
43
- t = _tokens[i]
44
- p = _classes[i][0]
45
- n = _classes[i][1]
46
-
47
- sklearnex.unpatch_sklearn(t)
48
- class_module = getattr(p, n).__module__
49
- assert class_module.startswith("sklearn"), "Unpatching has completed with error."
50
-
51
- sklearnex.unpatch_sklearn()
52
-
53
- for i, _ in enumerate(_tokens):
54
- t = _tokens[i]
55
- p = _classes[i][0]
56
- n = _classes[i][1]
57
-
58
- class_module = getattr(p, n).__module__
59
- assert class_module.startswith("sklearn"), "Unpatching has completed with error."
60
-
61
- sklearnex.unpatch_sklearn()
62
-
63
- for i, _ in enumerate(_tokens):
64
- t = _tokens[i]
65
- p = _classes[i][0]
66
- n = _classes[i][1]
67
-
68
- sklearnex.patch_sklearn(t)
69
-
70
- class_module = getattr(p, n).__module__
71
- assert class_module.startswith("daal4py") or class_module.startswith(
72
- "sklearnex"
73
- ), "Patching has completed with error."
74
-
75
- sklearnex.unpatch_sklearn()
36
+ try:
37
+ sklearnex.patch_sklearn()
38
+
39
+ for i, _ in enumerate(_tokens):
40
+ t = _tokens[i]
41
+ p = _classes[i][0]
42
+ n = _classes[i][1]
43
+
44
+ class_module = getattr(p, n).__module__
45
+ assert class_module.startswith("daal4py") or class_module.startswith(
46
+ "sklearnex"
47
+ ), "Patching has completed with error."
48
+
49
+ for i, _ in enumerate(_tokens):
50
+ t = _tokens[i]
51
+ p = _classes[i][0]
52
+ n = _classes[i][1]
53
+
54
+ sklearnex.unpatch_sklearn(t)
55
+ sklearn_class = getattr(p, n, None)
56
+ if sklearn_class is not None:
57
+ sklearn_class = sklearn_class.__module__
58
+ assert sklearn_class is None or sklearn_class.startswith(
59
+ "sklearn"
60
+ ), "Unpatching has completed with error."
61
+
62
+ finally:
63
+ sklearnex.unpatch_sklearn()
64
+
65
+ try:
66
+ for i, _ in enumerate(_tokens):
67
+ t = _tokens[i]
68
+ p = _classes[i][0]
69
+ n = _classes[i][1]
70
+
71
+ sklearn_class = getattr(p, n, None)
72
+ if sklearn_class is not None:
73
+ sklearn_class = sklearn_class.__module__
74
+ assert sklearn_class is None or sklearn_class.startswith(
75
+ "sklearn"
76
+ ), "Unpatching has completed with error."
77
+
78
+ finally:
79
+ sklearnex.unpatch_sklearn()
80
+
81
+ try:
82
+ for i, _ in enumerate(_tokens):
83
+ t = _tokens[i]
84
+ p = _classes[i][0]
85
+ n = _classes[i][1]
86
+
87
+ sklearnex.patch_sklearn(t)
88
+
89
+ class_module = getattr(p, n).__module__
90
+ assert class_module.startswith("daal4py") or class_module.startswith(
91
+ "sklearnex"
92
+ ), "Patching has completed with error."
93
+ finally:
94
+ sklearnex.unpatch_sklearn()
76
95
 
77
96
 
78
97
  def test_patch_by_list_simple():
79
- sklearnex.patch_sklearn(["LogisticRegression"])
98
+ try:
99
+ sklearnex.patch_sklearn(["LogisticRegression"])
80
100
 
81
- from sklearn.ensemble import RandomForestRegressor
82
- from sklearn.linear_model import LogisticRegression
83
- from sklearn.neighbors import KNeighborsRegressor
84
- from sklearn.svm import SVC
85
-
86
- assert RandomForestRegressor.__module__.startswith("sklearn")
87
- assert KNeighborsRegressor.__module__.startswith("sklearn")
88
- assert LogisticRegression.__module__.startswith("daal4py")
89
- assert SVC.__module__.startswith("sklearn")
101
+ from sklearn.ensemble import RandomForestRegressor
102
+ from sklearn.linear_model import LogisticRegression
103
+ from sklearn.neighbors import KNeighborsRegressor
104
+ from sklearn.svm import SVC
90
105
 
91
- sklearnex.unpatch_sklearn()
106
+ assert RandomForestRegressor.__module__.startswith("sklearn")
107
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
108
+ if daal_check_version((2024, "P", 1)):
109
+ assert LogisticRegression.__module__.startswith("sklearnex")
110
+ else:
111
+ assert LogisticRegression.__module__.startswith("daal4py")
112
+ assert SVC.__module__.startswith("sklearn")
113
+ finally:
114
+ sklearnex.unpatch_sklearn()
92
115
 
93
116
 
94
117
  def test_patch_by_list_many_estimators():
95
- sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
118
+ try:
119
+ sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
96
120
 
97
- from sklearn.ensemble import RandomForestRegressor
98
- from sklearn.linear_model import LogisticRegression
99
- from sklearn.neighbors import KNeighborsRegressor
100
- from sklearn.svm import SVC
121
+ from sklearn.ensemble import RandomForestRegressor
122
+ from sklearn.linear_model import LogisticRegression
123
+ from sklearn.neighbors import KNeighborsRegressor
124
+ from sklearn.svm import SVC
101
125
 
102
- assert RandomForestRegressor.__module__.startswith("sklearn")
103
- assert KNeighborsRegressor.__module__.startswith("sklearn")
104
- assert LogisticRegression.__module__.startswith("daal4py")
105
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
126
+ assert RandomForestRegressor.__module__.startswith("sklearn")
127
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
128
+ if daal_check_version((2024, "P", 1)):
129
+ assert LogisticRegression.__module__.startswith("sklearnex")
130
+ else:
131
+ assert LogisticRegression.__module__.startswith("daal4py")
132
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
133
+ "sklearnex"
134
+ )
106
135
 
107
- sklearnex.unpatch_sklearn()
136
+ finally:
137
+ sklearnex.unpatch_sklearn()
108
138
 
109
139
 
110
140
  def test_unpatch_by_list_many_estimators():
111
- sklearnex.patch_sklearn()
141
+ try:
142
+ sklearnex.patch_sklearn()
143
+
144
+ from sklearn.ensemble import RandomForestRegressor
145
+ from sklearn.linear_model import LogisticRegression
146
+ from sklearn.neighbors import KNeighborsRegressor
147
+ from sklearn.svm import SVC
112
148
 
113
- from sklearn.ensemble import RandomForestRegressor
114
- from sklearn.linear_model import LogisticRegression
115
- from sklearn.neighbors import KNeighborsRegressor
116
- from sklearn.svm import SVC
149
+ assert RandomForestRegressor.__module__.startswith("sklearnex")
150
+ assert KNeighborsRegressor.__module__.startswith(
151
+ "daal4py"
152
+ ) or KNeighborsRegressor.__module__.startswith("sklearnex")
153
+ if daal_check_version((2024, "P", 1)):
154
+ assert LogisticRegression.__module__.startswith("sklearnex")
155
+ else:
156
+ assert LogisticRegression.__module__.startswith("daal4py")
157
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
158
+ "sklearnex"
159
+ )
117
160
 
118
- assert RandomForestRegressor.__module__.startswith("sklearnex")
119
- assert KNeighborsRegressor.__module__.startswith(
120
- "daal4py"
121
- ) or KNeighborsRegressor.__module__.startswith("sklearnex")
122
- assert LogisticRegression.__module__.startswith("daal4py")
123
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
161
+ sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
124
162
 
125
- sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
163
+ from sklearn.ensemble import RandomForestRegressor
164
+ from sklearn.linear_model import LogisticRegression
165
+ from sklearn.neighbors import KNeighborsRegressor
166
+ from sklearn.svm import SVC
126
167
 
127
- from sklearn.ensemble import RandomForestRegressor
128
- from sklearn.linear_model import LogisticRegression
129
- from sklearn.neighbors import KNeighborsRegressor
130
- from sklearn.svm import SVC
168
+ assert RandomForestRegressor.__module__.startswith("sklearn")
169
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
170
+ if daal_check_version((2024, "P", 1)):
171
+ assert LogisticRegression.__module__.startswith("sklearnex")
172
+ else:
173
+ assert LogisticRegression.__module__.startswith("daal4py")
131
174
 
132
- assert RandomForestRegressor.__module__.startswith("sklearn")
133
- assert KNeighborsRegressor.__module__.startswith("sklearn")
134
- assert LogisticRegression.__module__.startswith("daal4py")
135
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
175
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
176
+ "sklearnex"
177
+ )
178
+ finally:
179
+ sklearnex.unpatch_sklearn()
136
180
 
137
181
 
138
182
  def test_patching_checker():
139
183
  for name in [None, "SVC", "PCA"]:
140
- sklearnex.patch_sklearn(name=name)
141
- assert sklearnex.sklearn_is_patched(name=name)
142
-
143
- sklearnex.unpatch_sklearn(name=name)
144
- assert not sklearnex.sklearn_is_patched(name=name)
145
-
146
- sklearnex.patch_sklearn()
147
- patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
148
- assert len(patching_status_map) == len(sklearnex.get_patch_names())
149
- for status in patching_status_map.values():
150
- assert status
184
+ try:
185
+ sklearnex.patch_sklearn(name=name)
186
+ assert sklearnex.sklearn_is_patched(name=name)
187
+
188
+ finally:
189
+ sklearnex.unpatch_sklearn(name=name)
190
+ assert not sklearnex.sklearn_is_patched(name=name)
191
+ try:
192
+ sklearnex.patch_sklearn()
193
+ patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
194
+ assert len(patching_status_map) == len(sklearnex.get_patch_names())
195
+ for status in patching_status_map.values():
196
+ assert status
197
+ finally:
198
+ sklearnex.unpatch_sklearn()
151
199
 
152
- sklearnex.unpatch_sklearn()
153
200
  patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
154
201
  assert len(patching_status_map) == len(sklearnex.get_patch_names())
155
202
  for status in patching_status_map.values():
@@ -161,67 +208,61 @@ def test_preview_namespace():
161
208
  from sklearn.cluster import DBSCAN
162
209
  from sklearn.decomposition import PCA
163
210
  from sklearn.ensemble import RandomForestClassifier
164
- from sklearn.linear_model import LinearRegression, LogisticRegression
211
+ from sklearn.linear_model import LinearRegression
165
212
  from sklearn.svm import SVC
166
213
 
167
214
  return (
168
215
  LinearRegression(),
169
- LogisticRegression(),
170
216
  PCA(),
171
217
  DBSCAN(),
172
218
  SVC(),
173
219
  RandomForestClassifier(),
174
220
  )
175
221
 
176
- # BUG: previous patching tests force PCA to be patched with daal4py.
177
- # This unpatching returns behavior to expected
178
- sklearnex.unpatch_sklearn()
179
- # behavior with enabled preview
180
- sklearnex.patch_sklearn(preview=True)
181
222
  from sklearnex.dispatcher import _is_preview_enabled
182
223
 
183
- assert _is_preview_enabled()
224
+ try:
225
+ sklearnex.patch_sklearn(preview=True)
226
+
227
+ assert _is_preview_enabled()
184
228
 
185
- lr, log_reg, pca, dbscan, svc, rfc = get_estimators()
186
- assert "sklearnex" in rfc.__module__
229
+ lr, pca, dbscan, svc, rfc = get_estimators()
230
+ assert "sklearnex" in rfc.__module__
187
231
 
188
- if daal_check_version((2023, "P", 100)):
189
- assert "sklearnex" in lr.__module__
190
- else:
191
- assert "daal4py" in lr.__module__
232
+ if daal_check_version((2023, "P", 100)):
233
+ assert "sklearnex" in lr.__module__
234
+ else:
235
+ assert "daal4py" in lr.__module__
192
236
 
193
- if daal_check_version((2024, "P", 1)):
194
- assert "sklearnex" in log_reg.__module__
195
- else:
196
- assert "daal4py" in log_reg.__module__
237
+ assert "sklearnex" in pca.__module__
238
+ assert "sklearnex" in dbscan.__module__
239
+ assert "sklearnex" in svc.__module__
197
240
 
198
- assert "sklearnex.preview" in pca.__module__
199
- assert "sklearnex" in dbscan.__module__
200
- assert "sklearnex" in svc.__module__
201
- sklearnex.unpatch_sklearn()
241
+ finally:
242
+ sklearnex.unpatch_sklearn()
202
243
 
203
244
  # no patching behavior
204
- lr, log_reg, pca, dbscan, svc, rfc = get_estimators()
245
+ lr, pca, dbscan, svc, rfc = get_estimators()
205
246
  assert "sklearn." in lr.__module__ and "daal4py" not in lr.__module__
206
- assert "sklearn." in log_reg.__module__ and "daal4py" not in log_reg.__module__
207
247
  assert "sklearn." in pca.__module__ and "daal4py" not in pca.__module__
208
248
  assert "sklearn." in dbscan.__module__ and "daal4py" not in dbscan.__module__
209
249
  assert "sklearn." in svc.__module__ and "daal4py" not in svc.__module__
210
250
  assert "sklearn." in rfc.__module__ and "daal4py" not in rfc.__module__
211
251
 
212
252
  # default patching behavior
213
- sklearnex.patch_sklearn()
214
- assert not _is_preview_enabled()
215
-
216
- lr, log_reg, pca, dbscan, svc, rfc = get_estimators()
217
- if daal_check_version((2023, "P", 100)):
218
- assert "sklearnex" in lr.__module__
219
- else:
220
- assert "daal4py" in lr.__module__
221
-
222
- assert "daal4py" in log_reg.__module__
223
- assert "daal4py" in pca.__module__
224
- assert "sklearnex" in rfc.__module__
225
- assert "sklearnex" in dbscan.__module__
226
- assert "sklearnex" in svc.__module__
227
- sklearnex.unpatch_sklearn()
253
+ try:
254
+ sklearnex.patch_sklearn()
255
+ assert not _is_preview_enabled()
256
+
257
+ lr, pca, dbscan, svc, rfc = get_estimators()
258
+ if daal_check_version((2023, "P", 100)):
259
+ assert "sklearnex" in lr.__module__
260
+ else:
261
+ assert "daal4py" in lr.__module__
262
+
263
+ assert "sklearnex" in pca.__module__
264
+ assert "sklearnex" in rfc.__module__
265
+ assert "sklearnex" in dbscan.__module__
266
+ assert "sklearnex" in svc.__module__
267
+ finally:
268
+ sklearnex.unpatch_sklearn()