scikit-learn-intelex 2024.2.0__py312-none-win_amd64.whl → 2024.3.0__py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (107) hide show
  1. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__init__.py +9 -7
  2. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -1
  3. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/conftest.py +63 -0
  4. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +338 -0
  5. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +22 -8
  6. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +72 -41
  7. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +10 -14
  8. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +15 -19
  9. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +13 -2
  10. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -2
  11. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +39 -2
  12. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +7 -9
  13. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +6 -9
  14. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +5 -8
  15. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -5
  16. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  17. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  18. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +4 -0
  19. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +4 -0
  20. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +155 -0
  21. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +8 -3
  22. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +268 -0
  23. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  24. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +361 -0
  25. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/METADATA +2 -2
  26. scikit_learn_intelex-2024.3.0.dist-info/RECORD +98 -0
  27. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  28. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  29. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -19
  30. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -374
  31. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -170
  32. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -240
  33. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -136
  34. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -118
  35. scikit_learn_intelex-2024.2.0.dist-info/RECORD +0 -101
  36. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  37. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  38. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  39. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  40. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  41. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  42. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  43. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  44. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  45. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  46. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  47. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -0
  48. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -0
  49. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  50. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  51. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  52. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  53. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  54. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  55. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  56. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -0
  57. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  58. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  59. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -0
  60. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  61. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  62. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  63. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  64. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  65. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  66. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  67. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  69. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  70. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  71. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  72. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  73. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  74. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -0
  75. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  76. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
  77. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  78. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  79. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  80. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  81. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  83. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  84. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  86. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  87. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  88. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  91. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  92. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  93. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  94. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
  96. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +3 -3
  97. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +0 -0
  98. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  99. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  100. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
  101. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  102. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  104. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  105. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/LICENSE.txt +0 -0
  106. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/WHEEL +0 -0
  107. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
+
17
18
  import gc
18
19
  import logging
19
20
  import tracemalloc
@@ -30,7 +31,6 @@ from sklearn.model_selection import KFold
30
31
  from sklearnex import get_patch_map
31
32
  from sklearnex.metrics import pairwise_distances, roc_auc_score
32
33
  from sklearnex.model_selection import train_test_split
33
- from sklearnex.preview.decomposition import PCA as PreviewPCA
34
34
  from sklearnex.utils import _assert_all_finite
35
35
 
36
36
 
@@ -75,6 +75,8 @@ class RocAucEstimator:
75
75
 
76
76
 
77
77
  # add all daal4py estimators enabled in patching (except banned)
78
+
79
+
78
80
  def get_patched_estimators(ban_list, output_list):
79
81
  patched_estimators = get_patch_map().values()
80
82
  for listing in patched_estimators:
@@ -96,7 +98,6 @@ def remove_duplicated_estimators(estimators_list):
96
98
 
97
99
  BANNED_ESTIMATORS = ("TSNE",) # too slow for using in testing on common data size
98
100
  estimators = [
99
- PreviewPCA,
100
101
  TrainTestSplitEstimator,
101
102
  FiniteCheckEstimator,
102
103
  CosineDistancesEstimator,
@@ -153,6 +154,7 @@ def split_train_inference(kf, x, y, estimator):
153
154
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]
154
155
  # TODO: add parameters for all estimators to prevent
155
156
  # fallback to stock scikit-learn with default parameters
157
+
156
158
  alg = estimator()
157
159
  alg.fit(x_train, y_train)
158
160
  if hasattr(alg, "predict"):
@@ -163,7 +165,6 @@ def split_train_inference(kf, x, y, estimator):
163
165
  alg.kneighbors(x_test)
164
166
  del alg, x_train, x_test, y_train, y_test
165
167
  mem_tracks.append(tracemalloc.get_traced_memory()[0])
166
-
167
168
  return mem_tracks
168
169
 
169
170
 
@@ -215,6 +216,10 @@ def _kfold_function_template(estimator, data_transform_function, data_shape):
215
216
  )
216
217
 
217
218
 
219
+ # disable fallback check as logging impacts memory use
220
+
221
+
222
+ @pytest.mark.allow_sklearn_fallback
218
223
  @pytest.mark.parametrize("data_transform_function", data_transforms)
219
224
  @pytest.mark.parametrize("estimator", estimators)
220
225
  @pytest.mark.parametrize("data_shape", data_shapes)
@@ -0,0 +1,268 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import sklearnex
18
+ from daal4py.sklearn._utils import daal_check_version
19
+
20
+ # General use of patch_sklearn and unpatch_sklearn in pytest is not recommended.
21
+ # It changes global state and can impact the operation of other tests. This file
22
+ # specifically tests patch_sklearn and unpatch_sklearn and is exempt from this.
23
+ # If sklearnex patching is necessary in testing, use the 'with_sklearnex' pytest
24
+ # fixture.
25
+
26
+
27
+ def test_monkey_patching():
28
+ _tokens = sklearnex.get_patch_names()
29
+ _values = sklearnex.get_patch_map().values()
30
+ _classes = list()
31
+
32
+ for v in _values:
33
+ for c in v:
34
+ _classes.append(c[0])
35
+
36
+ try:
37
+ sklearnex.patch_sklearn()
38
+
39
+ for i, _ in enumerate(_tokens):
40
+ t = _tokens[i]
41
+ p = _classes[i][0]
42
+ n = _classes[i][1]
43
+
44
+ class_module = getattr(p, n).__module__
45
+ assert class_module.startswith("daal4py") or class_module.startswith(
46
+ "sklearnex"
47
+ ), "Patching has completed with error."
48
+
49
+ for i, _ in enumerate(_tokens):
50
+ t = _tokens[i]
51
+ p = _classes[i][0]
52
+ n = _classes[i][1]
53
+
54
+ sklearnex.unpatch_sklearn(t)
55
+ sklearn_class = getattr(p, n, None)
56
+ if sklearn_class is not None:
57
+ sklearn_class = sklearn_class.__module__
58
+ assert sklearn_class is None or sklearn_class.startswith(
59
+ "sklearn"
60
+ ), "Unpatching has completed with error."
61
+
62
+ finally:
63
+ sklearnex.unpatch_sklearn()
64
+
65
+ try:
66
+ for i, _ in enumerate(_tokens):
67
+ t = _tokens[i]
68
+ p = _classes[i][0]
69
+ n = _classes[i][1]
70
+
71
+ sklearn_class = getattr(p, n, None)
72
+ if sklearn_class is not None:
73
+ sklearn_class = sklearn_class.__module__
74
+ assert sklearn_class is None or sklearn_class.startswith(
75
+ "sklearn"
76
+ ), "Unpatching has completed with error."
77
+
78
+ finally:
79
+ sklearnex.unpatch_sklearn()
80
+
81
+ try:
82
+ for i, _ in enumerate(_tokens):
83
+ t = _tokens[i]
84
+ p = _classes[i][0]
85
+ n = _classes[i][1]
86
+
87
+ sklearnex.patch_sklearn(t)
88
+
89
+ class_module = getattr(p, n).__module__
90
+ assert class_module.startswith("daal4py") or class_module.startswith(
91
+ "sklearnex"
92
+ ), "Patching has completed with error."
93
+ finally:
94
+ sklearnex.unpatch_sklearn()
95
+
96
+
97
+ def test_patch_by_list_simple():
98
+ try:
99
+ sklearnex.patch_sklearn(["LogisticRegression"])
100
+
101
+ from sklearn.ensemble import RandomForestRegressor
102
+ from sklearn.linear_model import LogisticRegression
103
+ from sklearn.neighbors import KNeighborsRegressor
104
+ from sklearn.svm import SVC
105
+
106
+ assert RandomForestRegressor.__module__.startswith("sklearn")
107
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
108
+ if daal_check_version((2024, "P", 1)):
109
+ assert LogisticRegression.__module__.startswith("sklearnex")
110
+ else:
111
+ assert LogisticRegression.__module__.startswith("daal4py")
112
+ assert SVC.__module__.startswith("sklearn")
113
+ finally:
114
+ sklearnex.unpatch_sklearn()
115
+
116
+
117
+ def test_patch_by_list_many_estimators():
118
+ try:
119
+ sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
120
+
121
+ from sklearn.ensemble import RandomForestRegressor
122
+ from sklearn.linear_model import LogisticRegression
123
+ from sklearn.neighbors import KNeighborsRegressor
124
+ from sklearn.svm import SVC
125
+
126
+ assert RandomForestRegressor.__module__.startswith("sklearn")
127
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
128
+ if daal_check_version((2024, "P", 1)):
129
+ assert LogisticRegression.__module__.startswith("sklearnex")
130
+ else:
131
+ assert LogisticRegression.__module__.startswith("daal4py")
132
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
133
+ "sklearnex"
134
+ )
135
+
136
+ finally:
137
+ sklearnex.unpatch_sklearn()
138
+
139
+
140
+ def test_unpatch_by_list_many_estimators():
141
+ try:
142
+ sklearnex.patch_sklearn()
143
+
144
+ from sklearn.ensemble import RandomForestRegressor
145
+ from sklearn.linear_model import LogisticRegression
146
+ from sklearn.neighbors import KNeighborsRegressor
147
+ from sklearn.svm import SVC
148
+
149
+ assert RandomForestRegressor.__module__.startswith("sklearnex")
150
+ assert KNeighborsRegressor.__module__.startswith(
151
+ "daal4py"
152
+ ) or KNeighborsRegressor.__module__.startswith("sklearnex")
153
+ if daal_check_version((2024, "P", 1)):
154
+ assert LogisticRegression.__module__.startswith("sklearnex")
155
+ else:
156
+ assert LogisticRegression.__module__.startswith("daal4py")
157
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
158
+ "sklearnex"
159
+ )
160
+
161
+ sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
162
+
163
+ from sklearn.ensemble import RandomForestRegressor
164
+ from sklearn.linear_model import LogisticRegression
165
+ from sklearn.neighbors import KNeighborsRegressor
166
+ from sklearn.svm import SVC
167
+
168
+ assert RandomForestRegressor.__module__.startswith("sklearn")
169
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
170
+ if daal_check_version((2024, "P", 1)):
171
+ assert LogisticRegression.__module__.startswith("sklearnex")
172
+ else:
173
+ assert LogisticRegression.__module__.startswith("daal4py")
174
+
175
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
176
+ "sklearnex"
177
+ )
178
+ finally:
179
+ sklearnex.unpatch_sklearn()
180
+
181
+
182
+ def test_patching_checker():
183
+ for name in [None, "SVC", "PCA"]:
184
+ try:
185
+ sklearnex.patch_sklearn(name=name)
186
+ assert sklearnex.sklearn_is_patched(name=name)
187
+
188
+ finally:
189
+ sklearnex.unpatch_sklearn(name=name)
190
+ assert not sklearnex.sklearn_is_patched(name=name)
191
+ try:
192
+ sklearnex.patch_sklearn()
193
+ patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
194
+ assert len(patching_status_map) == len(sklearnex.get_patch_names())
195
+ for status in patching_status_map.values():
196
+ assert status
197
+ finally:
198
+ sklearnex.unpatch_sklearn()
199
+
200
+ patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
201
+ assert len(patching_status_map) == len(sklearnex.get_patch_names())
202
+ for status in patching_status_map.values():
203
+ assert not status
204
+
205
+
206
+ def test_preview_namespace():
207
+ def get_estimators():
208
+ from sklearn.cluster import DBSCAN
209
+ from sklearn.decomposition import PCA
210
+ from sklearn.ensemble import RandomForestClassifier
211
+ from sklearn.linear_model import LinearRegression
212
+ from sklearn.svm import SVC
213
+
214
+ return (
215
+ LinearRegression(),
216
+ PCA(),
217
+ DBSCAN(),
218
+ SVC(),
219
+ RandomForestClassifier(),
220
+ )
221
+
222
+ from sklearnex.dispatcher import _is_preview_enabled
223
+
224
+ try:
225
+ sklearnex.patch_sklearn(preview=True)
226
+
227
+ assert _is_preview_enabled()
228
+
229
+ lr, pca, dbscan, svc, rfc = get_estimators()
230
+ assert "sklearnex" in rfc.__module__
231
+
232
+ if daal_check_version((2023, "P", 100)):
233
+ assert "sklearnex" in lr.__module__
234
+ else:
235
+ assert "daal4py" in lr.__module__
236
+
237
+ assert "sklearnex" in pca.__module__
238
+ assert "sklearnex" in dbscan.__module__
239
+ assert "sklearnex" in svc.__module__
240
+
241
+ finally:
242
+ sklearnex.unpatch_sklearn()
243
+
244
+ # no patching behavior
245
+ lr, pca, dbscan, svc, rfc = get_estimators()
246
+ assert "sklearn." in lr.__module__ and "daal4py" not in lr.__module__
247
+ assert "sklearn." in pca.__module__ and "daal4py" not in pca.__module__
248
+ assert "sklearn." in dbscan.__module__ and "daal4py" not in dbscan.__module__
249
+ assert "sklearn." in svc.__module__ and "daal4py" not in svc.__module__
250
+ assert "sklearn." in rfc.__module__ and "daal4py" not in rfc.__module__
251
+
252
+ # default patching behavior
253
+ try:
254
+ sklearnex.patch_sklearn()
255
+ assert not _is_preview_enabled()
256
+
257
+ lr, pca, dbscan, svc, rfc = get_estimators()
258
+ if daal_check_version((2023, "P", 100)):
259
+ assert "sklearnex" in lr.__module__
260
+ else:
261
+ assert "daal4py" in lr.__module__
262
+
263
+ assert "sklearnex" in pca.__module__
264
+ assert "sklearnex" in rfc.__module__
265
+ assert "sklearnex" in dbscan.__module__
266
+ assert "sklearnex" in svc.__module__
267
+ finally:
268
+ sklearnex.unpatch_sklearn()
@@ -15,13 +15,7 @@
15
15
  # ==============================================================================
16
16
  import pytest
17
17
 
18
- from sklearnex import config_context, patch_sklearn
19
-
20
- patch_sklearn()
21
-
22
- from sklearn.datasets import make_classification
23
- from sklearn.ensemble import BaggingClassifier
24
- from sklearn.svm import SVC
18
+ from sklearnex import config_context
25
19
 
26
20
  try:
27
21
  import dpctl
@@ -38,7 +32,11 @@ except (ImportError, ModuleNotFoundError):
38
32
  "to see raised 'SyclQueueCreationError'. "
39
33
  "'dpctl' module is required for test.",
40
34
  )
41
- def test_config_context_in_parallel():
35
+ def test_config_context_in_parallel(with_sklearnex):
36
+ from sklearn.datasets import make_classification
37
+ from sklearn.ensemble import BaggingClassifier
38
+ from sklearn.svm import SVC
39
+
42
40
  x, y = make_classification(random_state=42)
43
41
  try:
44
42
  with config_context(target_offload="gpu", allow_fallback_to_host=False):