scikit-learn-intelex 2024.0.1__py311-none-manylinux1_x86_64.whl → 2024.4.0__py311-none-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the package's registry page for more details.

Files changed (89)
  1. {scikit_learn_intelex-2024.0.1.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
  2. scikit_learn_intelex-2024.4.0.dist-info/RECORD +101 -0
  3. sklearnex/__init__.py +11 -7
  4. sklearnex/__main__.py +0 -1
  5. sklearnex/_device_offload.py +31 -4
  6. sklearnex/_utils.py +15 -1
  7. sklearnex/basic_statistics/__init__.py +2 -2
  8. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  9. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
  10. sklearnex/cluster/__init__.py +0 -1
  11. sklearnex/cluster/dbscan.py +5 -2
  12. sklearnex/cluster/k_means.py +0 -1
  13. sklearnex/cluster/tests/test_dbscan.py +0 -1
  14. sklearnex/cluster/tests/test_kmeans.py +0 -1
  15. sklearnex/conftest.py +63 -0
  16. sklearnex/covariance/__init__.py +19 -0
  17. sklearnex/covariance/incremental_covariance.py +130 -0
  18. sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  19. sklearnex/decomposition/__init__.py +0 -1
  20. sklearnex/decomposition/pca.py +319 -2
  21. sklearnex/decomposition/tests/test_pca.py +34 -6
  22. sklearnex/dispatcher.py +93 -28
  23. sklearnex/ensemble/__init__.py +0 -1
  24. sklearnex/ensemble/_forest.py +93 -89
  25. sklearnex/ensemble/tests/test_forest.py +15 -20
  26. sklearnex/glob/__main__.py +0 -1
  27. sklearnex/glob/dispatcher.py +0 -1
  28. sklearnex/linear_model/__init__.py +1 -3
  29. sklearnex/linear_model/coordinate_descent.py +0 -1
  30. sklearnex/linear_model/linear.py +275 -332
  31. sklearnex/linear_model/logistic_path.py +0 -1
  32. sklearnex/linear_model/logistic_regression.py +385 -0
  33. sklearnex/linear_model/ridge.py +0 -1
  34. sklearnex/linear_model/tests/test_linear.py +47 -7
  35. sklearnex/linear_model/tests/test_logreg.py +70 -8
  36. sklearnex/manifold/__init__.py +0 -1
  37. sklearnex/manifold/t_sne.py +0 -1
  38. sklearnex/manifold/tests/test_tsne.py +0 -1
  39. sklearnex/metrics/__init__.py +0 -1
  40. sklearnex/metrics/pairwise.py +0 -1
  41. sklearnex/metrics/ranking.py +0 -1
  42. sklearnex/metrics/tests/test_metrics.py +0 -1
  43. sklearnex/model_selection/__init__.py +0 -1
  44. sklearnex/model_selection/split.py +0 -1
  45. sklearnex/model_selection/tests/test_model_selection.py +0 -1
  46. sklearnex/neighbors/__init__.py +1 -2
  47. sklearnex/neighbors/_lof.py +221 -0
  48. sklearnex/neighbors/common.py +5 -3
  49. sklearnex/neighbors/knn_classification.py +47 -133
  50. sklearnex/neighbors/knn_regression.py +20 -129
  51. sklearnex/neighbors/knn_unsupervised.py +15 -89
  52. sklearnex/neighbors/tests/test_neighbors.py +12 -17
  53. sklearnex/preview/__init__.py +1 -2
  54. sklearnex/preview/cluster/__init__.py +0 -1
  55. sklearnex/preview/cluster/k_means.py +7 -74
  56. sklearnex/preview/{decomposition → covariance}/__init__.py +19 -20
  57. sklearnex/preview/covariance/covariance.py +133 -0
  58. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  59. sklearnex/spmd/__init__.py +1 -0
  60. sklearnex/spmd/covariance/__init__.py +19 -0
  61. sklearnex/spmd/covariance/covariance.py +21 -0
  62. sklearnex/spmd/ensemble/forest.py +4 -12
  63. sklearnex/spmd/linear_model/__init__.py +2 -1
  64. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  65. sklearnex/svm/__init__.py +0 -1
  66. sklearnex/svm/_common.py +4 -7
  67. sklearnex/svm/nusvc.py +73 -49
  68. sklearnex/svm/nusvr.py +8 -52
  69. sklearnex/svm/svc.py +74 -51
  70. sklearnex/svm/svr.py +5 -49
  71. sklearnex/svm/tests/test_svm.py +0 -1
  72. sklearnex/tests/_utils.py +164 -0
  73. sklearnex/tests/test_memory_usage.py +9 -7
  74. sklearnex/tests/test_monkeypatch.py +192 -134
  75. sklearnex/tests/test_n_jobs_support.py +99 -0
  76. sklearnex/tests/test_parallel.py +6 -8
  77. sklearnex/tests/test_patching.py +338 -89
  78. sklearnex/utils/__init__.py +2 -1
  79. sklearnex/utils/_namespace.py +97 -0
  80. sklearnex/utils/validation.py +0 -1
  81. scikit_learn_intelex-2024.0.1.dist-info/RECORD +0 -90
  82. sklearnex/neighbors/lof.py +0 -437
  83. sklearnex/preview/decomposition/pca.py +0 -376
  84. sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -38
  85. sklearnex/tests/_models_info.py +0 -170
  86. sklearnex/tests/utils/_launch_algorithms.py +0 -118
  87. {scikit_learn_intelex-2024.0.1.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
  88. {scikit_learn_intelex-2024.0.1.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
  89. {scikit_learn_intelex-2024.0.1.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,164 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from inspect import isclass
18
+
19
+ import numpy as np
20
+ from sklearn.base import (
21
+ BaseEstimator,
22
+ ClassifierMixin,
23
+ ClusterMixin,
24
+ OutlierMixin,
25
+ RegressorMixin,
26
+ TransformerMixin,
27
+ )
28
+ from sklearn.datasets import load_diabetes, load_iris
29
+ from sklearn.neighbors._base import KNeighborsMixin
30
+
31
+ from onedal.tests.utils._dataframes_support import _convert_to_dataframe
32
+ from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
33
+ from sklearnex.linear_model import LogisticRegression
34
+ from sklearnex.neighbors import (
35
+ KNeighborsClassifier,
36
+ KNeighborsRegressor,
37
+ LocalOutlierFactor,
38
+ NearestNeighbors,
39
+ )
40
+ from sklearnex.svm import SVC, NuSVC
41
+
42
+
43
def _load_all_models(with_sklearnex=True, estimator=True):
    """Collect the objects named in the sklearnex patch map.

    While the objects are looked up, the global patch state is forced to match
    ``with_sklearnex``. Afterwards the caller's original patch state is
    restored, even if the lookup raises.

    Parameters
    ----------
    with_sklearnex : bool
        If True, look the objects up while sklearn is patched; otherwise look
        them up while it is unpatched.
    estimator : bool
        If True, keep only classes deriving from ``BaseEstimator``; if False,
        keep only objects that are not classes.

    Returns
    -------
    dict
        Mapping from the attribute name recorded in the patch map to the
        object currently found under that attribute.
    """
    previous_state = sklearn_is_patched(return_map=True)
    was_patched = any(previous_state.values())
    try:
        if with_sklearnex:
            patch_sklearn()
        elif was_patched:
            unpatch_sklearn()

        found = {}
        for patch_infos in get_patch_map().values():
            owner, attr_name = patch_infos[0][0][0], patch_infos[0][0][1]
            obj = getattr(owner, attr_name, None)
            if obj is None or isclass(obj) != estimator:
                continue
            if estimator and not issubclass(obj, BaseEstimator):
                continue
            found[attr_name] = obj
    finally:
        if with_sklearnex:
            unpatch_sklearn()
        # Whichever branch ran above, sklearn is unpatched here; re-patch
        # exactly the entries that were patched on entry.
        if was_patched:
            patch_sklearn(
                name=[key for key, flag in previous_state.items() if flag]
            )

    return found
68
+
69
+
70
# Estimator classes from the patch map, collected with and without patching
# (evaluated left to right, so the call order is patched, then unpatched).
PATCHED_MODELS, UNPATCHED_MODELS = (
    _load_all_models(with_sklearnex=True),
    _load_all_models(with_sklearnex=False),
)

# Non-class patch-map entries (e.g. functions), collected the same way.
PATCHED_FUNCTIONS, UNPATCHED_FUNCTIONS = (
    _load_all_models(with_sklearnex=True, estimator=False),
    _load_all_models(with_sklearnex=False, estimator=False),
)

# Each entry is [mixin class, method names to test for estimators inheriting
# that mixin, dataset type used by gen_dataset (None means no preference)].
mixin_map = [
    [
        ClassifierMixin,
        ["decision_function", "predict", "predict_proba", "predict_log_proba", "score"],
        "classification",
    ],
    [RegressorMixin, ["predict", "score"], "regression"],
    [ClusterMixin, ["fit_predict"], "classification"],
    [TransformerMixin, ["fit_transform", "transform", "score"], "classification"],
    [OutlierMixin, ["fit_predict", "predict"], "classification"],
    [KNeighborsMixin, ["kneighbors"], None],
]


# Estimators configured with non-default parameters, keyed by str(estimator).
SPECIAL_INSTANCES = {
    str(instance): instance
    for instance in (
        LocalOutlierFactor(novelty=True),
        SVC(probability=True),
        NuSVC(probability=True),
        KNeighborsClassifier(algorithm="brute"),
        KNeighborsRegressor(algorithm="brute"),
        NearestNeighbors(algorithm="brute"),
        LogisticRegression(solver="newton-cg"),
    )
}
102
+
103
+
104
def gen_models_info(algorithms):
    """Expand estimator names into ``[name, method]`` pairs for parametrization.

    Parameters
    ----------
    algorithms : iterable of str
        Each entry must be a key of ``PATCHED_MODELS`` or ``SPECIAL_INSTANCES``.

    Returns
    -------
    list of list
        For each name, one ``[name, method]`` pair per public method of the
        estimator class that ``mixin_map`` lists for a mixin the class
        inherits from. Methods are given in sorted order. If no such method
        exists, a single ``[name, None]`` pair is emitted instead, so
        estimators without a known mixin can still have ``fit`` tested.

    Raises
    ------
    KeyError
        If a name is in neither ``PATCHED_MODELS`` nor ``SPECIAL_INSTANCES``.
    """
    output = []
    for name in algorithms:
        if name in PATCHED_MODELS:
            est = PATCHED_MODELS[name]
        elif name in SPECIAL_INSTANCES:
            est = SPECIAL_INSTANCES[name].__class__
        else:
            raise KeyError(f"Unrecognized sklearnex estimator: {name}")

        # Names starting or ending with an underscore are not considered.
        candidates = {
            attr
            for attr in dir(est)
            if not attr.startswith("_") and not attr.endswith("_")
        }

        methods = set()
        for mixin, mixin_methods, _ in mixin_map:
            if issubclass(est, mixin):
                methods |= candidates & set(mixin_methods)

        if methods:
            # Sort the methods: str hashing is randomized per interpreter, so
            # raw set order can differ between runs and between pytest-xdist
            # workers, and all workers must collect identical test lists.
            output += [[name, method] for method in sorted(methods)]
        else:
            output.append([name, None])
    return output
130
+
131
+
132
def gen_dataset(estimator, queue=None, target_df=None, dtype=np.float64):
    """Build an ``(X, y)`` pair from a scikit-learn toy dataset for ``estimator``.

    The dataset type is the one attached to the last mixin in ``mixin_map``
    that the estimator class inherits from. "regression" selects the diabetes
    data; "classification", or no match at all, selects the iris data.

    Parameters
    ----------
    estimator : estimator instance or class
        Its class name is looked up in ``PATCHED_MODELS``. If the name is not
        registered there (e.g. some ``SPECIAL_INSTANCES`` entries), the
        estimator's own class is inspected instead.
    queue : optional
        Forwarded to ``_convert_to_dataframe`` as ``sycl_queue``.
    target_df : optional
        Forwarded to ``_convert_to_dataframe``.
    dtype : numpy dtype, default=np.float64
        Forwarded to ``_convert_to_dataframe`` for both ``X`` and ``y``.

    Returns
    -------
    tuple
        ``(X, y)`` as returned by ``_convert_to_dataframe``.

    Raises
    ------
    ValueError
        If ``mixin_map`` yields a dataset type other than "classification" or
        "regression".
    """
    est_class = estimator if isclass(estimator) else estimator.__class__
    # Prefer the patched class registered under this name; fall back to the
    # estimator's own class rather than failing with a bare KeyError.
    est = PATCHED_MODELS.get(est_class.__name__, est_class)

    dataset = None
    for mixin, _, data in mixin_map:
        # No break: if several mixins match, the last one listed decides.
        if issubclass(est, mixin) and data is not None:
            dataset = data

    if dataset == "classification" or dataset is None:
        X, y = load_iris(return_X_y=True)
    elif dataset == "regression":
        X, y = load_diabetes(return_X_y=True)
    else:
        raise ValueError("Unknown dataset type")

    X = _convert_to_dataframe(X, sycl_queue=queue, target_df=target_df, dtype=dtype)
    y = _convert_to_dataframe(y, sycl_queue=queue, target_df=target_df, dtype=dtype)
    return X, y
150
+
151
+
152
# Input dtypes used by tests, in a fixed order: signed integers, floating
# point, then unsigned integers.
DTYPES = (
    [np.int8, np.int16, np.int32, np.int64]
    + [np.float16, np.float32, np.float64]
    + [np.uint8, np.uint16, np.uint32, np.uint64]
)
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
+
17
18
  import gc
18
19
  import logging
19
20
  import tracemalloc
@@ -30,7 +31,6 @@ from sklearn.model_selection import KFold
30
31
  from sklearnex import get_patch_map
31
32
  from sklearnex.metrics import pairwise_distances, roc_auc_score
32
33
  from sklearnex.model_selection import train_test_split
33
- from sklearnex.preview.decomposition import PCA as PreviewPCA
34
34
  from sklearnex.utils import _assert_all_finite
35
35
 
36
36
 
@@ -75,6 +75,8 @@ class RocAucEstimator:
75
75
 
76
76
 
77
77
  # add all daal4py estimators enabled in patching (except banned)
78
+
79
+
78
80
  def get_patched_estimators(ban_list, output_list):
79
81
  patched_estimators = get_patch_map().values()
80
82
  for listing in patched_estimators:
@@ -94,12 +96,8 @@ def remove_duplicated_estimators(estimators_list):
94
96
  return estimators_map.values()
95
97
 
96
98
 
97
- BANNED_ESTIMATORS = (
98
- "LocalOutlierFactor", # fails on ndarray_c for sklearn > 1.0
99
- "TSNE", # too slow for using in testing on common data size
100
- )
99
+ BANNED_ESTIMATORS = ("TSNE",) # too slow for using in testing on common data size
101
100
  estimators = [
102
- PreviewPCA,
103
101
  TrainTestSplitEstimator,
104
102
  FiniteCheckEstimator,
105
103
  CosineDistancesEstimator,
@@ -156,6 +154,7 @@ def split_train_inference(kf, x, y, estimator):
156
154
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]
157
155
  # TODO: add parameters for all estimators to prevent
158
156
  # fallback to stock scikit-learn with default parameters
157
+
159
158
  alg = estimator()
160
159
  alg.fit(x_train, y_train)
161
160
  if hasattr(alg, "predict"):
@@ -166,7 +165,6 @@ def split_train_inference(kf, x, y, estimator):
166
165
  alg.kneighbors(x_test)
167
166
  del alg, x_train, x_test, y_train, y_test
168
167
  mem_tracks.append(tracemalloc.get_traced_memory()[0])
169
-
170
168
  return mem_tracks
171
169
 
172
170
 
@@ -218,6 +216,10 @@ def _kfold_function_template(estimator, data_transform_function, data_shape):
218
216
  )
219
217
 
220
218
 
219
+ # disable fallback check as logging impacts memory use
220
+
221
+
222
+ @pytest.mark.allow_sklearn_fallback
221
223
  @pytest.mark.parametrize("data_transform_function", data_transforms)
222
224
  @pytest.mark.parametrize("estimator", estimators)
223
225
  @pytest.mark.parametrize("data_shape", data_shapes)
@@ -17,6 +17,12 @@
17
17
  import sklearnex
18
18
  from daal4py.sklearn._utils import daal_check_version
19
19
 
20
+ # General use of patch_sklearn and unpatch_sklearn in pytest is not recommended.
21
+ # It changes global state and can impact the operation of other tests. This file
22
+ # specifically tests patch_sklearn and unpatch_sklearn and is exempt from this.
23
+ # If sklearnex patching is necessary in testing, use the 'with_sklearnex' pytest
24
+ # fixture.
25
+
20
26
 
21
27
  def test_monkey_patching():
22
28
  _tokens = sklearnex.get_patch_names()
@@ -27,129 +33,170 @@ def test_monkey_patching():
27
33
  for c in v:
28
34
  _classes.append(c[0])
29
35
 
30
- sklearnex.patch_sklearn()
31
-
32
- for i, _ in enumerate(_tokens):
33
- t = _tokens[i]
34
- p = _classes[i][0]
35
- n = _classes[i][1]
36
-
37
- class_module = getattr(p, n).__module__
38
- assert class_module.startswith("daal4py") or class_module.startswith(
39
- "sklearnex"
40
- ), "Patching has completed with error."
41
-
42
- for i, _ in enumerate(_tokens):
43
- t = _tokens[i]
44
- p = _classes[i][0]
45
- n = _classes[i][1]
46
-
47
- sklearnex.unpatch_sklearn(t)
48
- class_module = getattr(p, n).__module__
49
- assert class_module.startswith("sklearn"), "Unpatching has completed with error."
50
-
51
- sklearnex.unpatch_sklearn()
52
-
53
- for i, _ in enumerate(_tokens):
54
- t = _tokens[i]
55
- p = _classes[i][0]
56
- n = _classes[i][1]
57
-
58
- class_module = getattr(p, n).__module__
59
- assert class_module.startswith("sklearn"), "Unpatching has completed with error."
60
-
61
- sklearnex.unpatch_sklearn()
62
-
63
- for i, _ in enumerate(_tokens):
64
- t = _tokens[i]
65
- p = _classes[i][0]
66
- n = _classes[i][1]
67
-
68
- sklearnex.patch_sklearn(t)
69
-
70
- class_module = getattr(p, n).__module__
71
- assert class_module.startswith("daal4py") or class_module.startswith(
72
- "sklearnex"
73
- ), "Patching has completed with error."
74
-
75
- sklearnex.unpatch_sklearn()
36
+ try:
37
+ sklearnex.patch_sklearn()
38
+
39
+ for i, _ in enumerate(_tokens):
40
+ t = _tokens[i]
41
+ p = _classes[i][0]
42
+ n = _classes[i][1]
43
+
44
+ class_module = getattr(p, n).__module__
45
+ assert class_module.startswith("daal4py") or class_module.startswith(
46
+ "sklearnex"
47
+ ), "Patching has completed with error."
48
+
49
+ for i, _ in enumerate(_tokens):
50
+ t = _tokens[i]
51
+ p = _classes[i][0]
52
+ n = _classes[i][1]
53
+
54
+ sklearnex.unpatch_sklearn(t)
55
+ sklearn_class = getattr(p, n, None)
56
+ if sklearn_class is not None:
57
+ sklearn_class = sklearn_class.__module__
58
+ assert sklearn_class is None or sklearn_class.startswith(
59
+ "sklearn"
60
+ ), "Unpatching has completed with error."
61
+
62
+ finally:
63
+ sklearnex.unpatch_sklearn()
64
+
65
+ try:
66
+ for i, _ in enumerate(_tokens):
67
+ t = _tokens[i]
68
+ p = _classes[i][0]
69
+ n = _classes[i][1]
70
+
71
+ sklearn_class = getattr(p, n, None)
72
+ if sklearn_class is not None:
73
+ sklearn_class = sklearn_class.__module__
74
+ assert sklearn_class is None or sklearn_class.startswith(
75
+ "sklearn"
76
+ ), "Unpatching has completed with error."
77
+
78
+ finally:
79
+ sklearnex.unpatch_sklearn()
80
+
81
+ try:
82
+ for i, _ in enumerate(_tokens):
83
+ t = _tokens[i]
84
+ p = _classes[i][0]
85
+ n = _classes[i][1]
86
+
87
+ sklearnex.patch_sklearn(t)
88
+
89
+ class_module = getattr(p, n).__module__
90
+ assert class_module.startswith("daal4py") or class_module.startswith(
91
+ "sklearnex"
92
+ ), "Patching has completed with error."
93
+ finally:
94
+ sklearnex.unpatch_sklearn()
76
95
 
77
96
 
78
97
  def test_patch_by_list_simple():
79
- sklearnex.patch_sklearn(["LogisticRegression"])
80
-
81
- from sklearn.ensemble import RandomForestRegressor
82
- from sklearn.linear_model import LogisticRegression
83
- from sklearn.neighbors import KNeighborsRegressor
84
- from sklearn.svm import SVC
98
+ try:
99
+ sklearnex.patch_sklearn(["LogisticRegression"])
85
100
 
86
- assert RandomForestRegressor.__module__.startswith("sklearn")
87
- assert KNeighborsRegressor.__module__.startswith("sklearn")
88
- assert LogisticRegression.__module__.startswith("daal4py")
89
- assert SVC.__module__.startswith("sklearn")
101
+ from sklearn.ensemble import RandomForestRegressor
102
+ from sklearn.linear_model import LogisticRegression
103
+ from sklearn.neighbors import KNeighborsRegressor
104
+ from sklearn.svm import SVC
90
105
 
91
- sklearnex.unpatch_sklearn()
106
+ assert RandomForestRegressor.__module__.startswith("sklearn")
107
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
108
+ if daal_check_version((2024, "P", 1)):
109
+ assert LogisticRegression.__module__.startswith("sklearnex")
110
+ else:
111
+ assert LogisticRegression.__module__.startswith("daal4py")
112
+ assert SVC.__module__.startswith("sklearn")
113
+ finally:
114
+ sklearnex.unpatch_sklearn()
92
115
 
93
116
 
94
117
  def test_patch_by_list_many_estimators():
95
- sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
118
+ try:
119
+ sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
96
120
 
97
- from sklearn.ensemble import RandomForestRegressor
98
- from sklearn.linear_model import LogisticRegression
99
- from sklearn.neighbors import KNeighborsRegressor
100
- from sklearn.svm import SVC
121
+ from sklearn.ensemble import RandomForestRegressor
122
+ from sklearn.linear_model import LogisticRegression
123
+ from sklearn.neighbors import KNeighborsRegressor
124
+ from sklearn.svm import SVC
101
125
 
102
- assert RandomForestRegressor.__module__.startswith("sklearn")
103
- assert KNeighborsRegressor.__module__.startswith("sklearn")
104
- assert LogisticRegression.__module__.startswith("daal4py")
105
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
126
+ assert RandomForestRegressor.__module__.startswith("sklearn")
127
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
128
+ if daal_check_version((2024, "P", 1)):
129
+ assert LogisticRegression.__module__.startswith("sklearnex")
130
+ else:
131
+ assert LogisticRegression.__module__.startswith("daal4py")
132
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
133
+ "sklearnex"
134
+ )
106
135
 
107
- sklearnex.unpatch_sklearn()
136
+ finally:
137
+ sklearnex.unpatch_sklearn()
108
138
 
109
139
 
110
140
  def test_unpatch_by_list_many_estimators():
111
- sklearnex.patch_sklearn()
141
+ try:
142
+ sklearnex.patch_sklearn()
112
143
 
113
- from sklearn.ensemble import RandomForestRegressor
114
- from sklearn.linear_model import LogisticRegression
115
- from sklearn.neighbors import KNeighborsRegressor
116
- from sklearn.svm import SVC
144
+ from sklearn.ensemble import RandomForestRegressor
145
+ from sklearn.linear_model import LogisticRegression
146
+ from sklearn.neighbors import KNeighborsRegressor
147
+ from sklearn.svm import SVC
117
148
 
118
- assert RandomForestRegressor.__module__.startswith("sklearnex")
119
- assert KNeighborsRegressor.__module__.startswith(
120
- "daal4py"
121
- ) or KNeighborsRegressor.__module__.startswith("sklearnex")
122
- assert LogisticRegression.__module__.startswith("daal4py")
123
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
149
+ assert RandomForestRegressor.__module__.startswith("sklearnex")
150
+ assert KNeighborsRegressor.__module__.startswith(
151
+ "daal4py"
152
+ ) or KNeighborsRegressor.__module__.startswith("sklearnex")
153
+ if daal_check_version((2024, "P", 1)):
154
+ assert LogisticRegression.__module__.startswith("sklearnex")
155
+ else:
156
+ assert LogisticRegression.__module__.startswith("daal4py")
157
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
158
+ "sklearnex"
159
+ )
124
160
 
125
- sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
161
+ sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
126
162
 
127
- from sklearn.ensemble import RandomForestRegressor
128
- from sklearn.linear_model import LogisticRegression
129
- from sklearn.neighbors import KNeighborsRegressor
130
- from sklearn.svm import SVC
163
+ from sklearn.ensemble import RandomForestRegressor
164
+ from sklearn.linear_model import LogisticRegression
165
+ from sklearn.neighbors import KNeighborsRegressor
166
+ from sklearn.svm import SVC
131
167
 
132
- assert RandomForestRegressor.__module__.startswith("sklearn")
133
- assert KNeighborsRegressor.__module__.startswith("sklearn")
134
- assert LogisticRegression.__module__.startswith("daal4py")
135
- assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith("sklearnex")
168
+ assert RandomForestRegressor.__module__.startswith("sklearn")
169
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
170
+ if daal_check_version((2024, "P", 1)):
171
+ assert LogisticRegression.__module__.startswith("sklearnex")
172
+ else:
173
+ assert LogisticRegression.__module__.startswith("daal4py")
174
+
175
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
176
+ "sklearnex"
177
+ )
178
+ finally:
179
+ sklearnex.unpatch_sklearn()
136
180
 
137
181
 
138
182
  def test_patching_checker():
139
183
  for name in [None, "SVC", "PCA"]:
140
- sklearnex.patch_sklearn(name=name)
141
- assert sklearnex.sklearn_is_patched(name=name)
142
-
143
- sklearnex.unpatch_sklearn(name=name)
144
- assert not sklearnex.sklearn_is_patched(name=name)
184
+ try:
185
+ sklearnex.patch_sklearn(name=name)
186
+ assert sklearnex.sklearn_is_patched(name=name)
187
+
188
+ finally:
189
+ sklearnex.unpatch_sklearn(name=name)
190
+ assert not sklearnex.sklearn_is_patched(name=name)
191
+ try:
192
+ sklearnex.patch_sklearn()
193
+ patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
194
+ assert len(patching_status_map) == len(sklearnex.get_patch_names())
195
+ for status in patching_status_map.values():
196
+ assert status
197
+ finally:
198
+ sklearnex.unpatch_sklearn()
145
199
 
146
- sklearnex.patch_sklearn()
147
- patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
148
- assert len(patching_status_map) == len(sklearnex.get_patch_names())
149
- for status in patching_status_map.values():
150
- assert status
151
-
152
- sklearnex.unpatch_sklearn()
153
200
  patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
154
201
  assert len(patching_status_map) == len(sklearnex.get_patch_names())
155
202
  for status in patching_status_map.values():
@@ -164,47 +211,58 @@ def test_preview_namespace():
164
211
  from sklearn.linear_model import LinearRegression
165
212
  from sklearn.svm import SVC
166
213
 
167
- return LinearRegression(), PCA(), DBSCAN(), SVC(), RandomForestClassifier()
214
+ return (
215
+ LinearRegression(),
216
+ PCA(),
217
+ DBSCAN(),
218
+ SVC(),
219
+ RandomForestClassifier(),
220
+ )
168
221
 
169
- # BUG: previous patching tests force PCA to be patched with daal4py.
170
- # This unpatching returns behavior to expected
171
- sklearnex.unpatch_sklearn()
172
- # behavior with enabled preview
173
- sklearnex.patch_sklearn(preview=True)
174
- assert sklearnex.dispatcher._is_preview_enabled()
222
+ from sklearnex.dispatcher import _is_preview_enabled
175
223
 
176
- lr, pca, dbscan, svc, rfc = get_estimators()
177
- assert "sklearnex" in rfc.__module__
224
+ try:
225
+ sklearnex.patch_sklearn(preview=True)
226
+
227
+ assert _is_preview_enabled()
178
228
 
179
- if daal_check_version((2023, "P", 100)):
180
- assert "sklearnex" in lr.__module__
181
- else:
182
- assert "daal4py" in lr.__module__
229
+ lr, pca, dbscan, svc, rfc = get_estimators()
230
+ assert "sklearnex" in rfc.__module__
183
231
 
184
- assert "sklearnex.preview" in pca.__module__
185
- assert "sklearnex" in dbscan.__module__
186
- assert "sklearnex" in svc.__module__
187
- sklearnex.unpatch_sklearn()
232
+ if daal_check_version((2023, "P", 100)):
233
+ assert "sklearnex" in lr.__module__
234
+ else:
235
+ assert "daal4py" in lr.__module__
236
+
237
+ assert "sklearnex" in pca.__module__
238
+ assert "sklearnex" in dbscan.__module__
239
+ assert "sklearnex" in svc.__module__
240
+
241
+ finally:
242
+ sklearnex.unpatch_sklearn()
188
243
 
189
244
  # no patching behavior
190
245
  lr, pca, dbscan, svc, rfc = get_estimators()
191
- assert "sklearn." in lr.__module__
192
- assert "sklearn." in pca.__module__
193
- assert "sklearn." in dbscan.__module__
194
- assert "sklearn." in svc.__module__
195
- assert "sklearn." in rfc.__module__
246
+ assert "sklearn." in lr.__module__ and "daal4py" not in lr.__module__
247
+ assert "sklearn." in pca.__module__ and "daal4py" not in pca.__module__
248
+ assert "sklearn." in dbscan.__module__ and "daal4py" not in dbscan.__module__
249
+ assert "sklearn." in svc.__module__ and "daal4py" not in svc.__module__
250
+ assert "sklearn." in rfc.__module__ and "daal4py" not in rfc.__module__
196
251
 
197
252
  # default patching behavior
198
- sklearnex.patch_sklearn()
199
- assert not sklearnex.dispatcher._is_preview_enabled()
200
-
201
- lr, pca, dbscan, svc, rfc = get_estimators()
202
- if daal_check_version((2023, "P", 100)):
203
- assert "sklearnex" in lr.__module__
204
- else:
205
- assert "daal4py" in lr.__module__
206
- assert "daal4py" in pca.__module__
207
- assert "sklearnex" in rfc.__module__
208
- assert "sklearnex" in dbscan.__module__
209
- assert "sklearnex" in svc.__module__
210
- sklearnex.unpatch_sklearn()
253
+ try:
254
+ sklearnex.patch_sklearn()
255
+ assert not _is_preview_enabled()
256
+
257
+ lr, pca, dbscan, svc, rfc = get_estimators()
258
+ if daal_check_version((2023, "P", 100)):
259
+ assert "sklearnex" in lr.__module__
260
+ else:
261
+ assert "daal4py" in lr.__module__
262
+
263
+ assert "sklearnex" in pca.__module__
264
+ assert "sklearnex" in rfc.__module__
265
+ assert "sklearnex" in dbscan.__module__
266
+ assert "sklearnex" in svc.__module__
267
+ finally:
268
+ sklearnex.unpatch_sklearn()