scikit-learn-intelex 2024.2.0__py312-none-win_amd64.whl → 2024.3.0__py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (107) hide show
  1. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__init__.py +9 -7
  2. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -1
  3. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/conftest.py +63 -0
  4. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +338 -0
  5. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +22 -8
  6. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +72 -41
  7. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +10 -14
  8. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +15 -19
  9. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +13 -2
  10. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -2
  11. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +39 -2
  12. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +7 -9
  13. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +6 -9
  14. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +5 -8
  15. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -5
  16. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  17. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  18. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +4 -0
  19. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +4 -0
  20. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +155 -0
  21. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +8 -3
  22. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +268 -0
  23. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  24. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +361 -0
  25. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/METADATA +2 -2
  26. scikit_learn_intelex-2024.3.0.dist-info/RECORD +98 -0
  27. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  28. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  29. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -19
  30. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -374
  31. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -170
  32. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -240
  33. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -136
  34. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -118
  35. scikit_learn_intelex-2024.2.0.dist-info/RECORD +0 -101
  36. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  37. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  38. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  39. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  40. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  41. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  42. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  43. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  44. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  45. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  46. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  47. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -0
  48. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -0
  49. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  50. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  51. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  52. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  53. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  54. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  55. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  56. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -0
  57. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  58. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  59. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -0
  60. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  61. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  62. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  63. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  64. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  65. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  66. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  67. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  69. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  70. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  71. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  72. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  73. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  74. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -0
  75. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  76. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
  77. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  78. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  79. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  80. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  81. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  83. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  84. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  86. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  87. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  88. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  91. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  92. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  93. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  94. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
  96. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +3 -3
  97. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +0 -0
  98. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  99. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  100. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
  101. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  102. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  104. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  105. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/LICENSE.txt +0 -0
  106. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/WHEEL +0 -0
  107. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  # ==============================================================================
2
2
  # Copyright 2021 Intel Corporation
3
+ # Copyright 2024 Fujitsu Limited
3
4
  #
4
5
  # Licensed under the Apache License, Version 2.0 (the "License");
5
6
  # you may not use this file except in compliance with the License.
@@ -14,7 +15,7 @@
14
15
  # limitations under the License.
15
16
  # ==============================================================================
16
17
 
17
- from onedal.common.hyperparameters import get_hyperparameters
18
+ import os
18
19
 
19
20
  from . import utils
20
21
  from ._config import config_context, get_config, set_config
@@ -41,21 +42,22 @@ __all__ = [
41
42
  "linear_model",
42
43
  "manifold",
43
44
  "metrics",
45
+ "model_selection",
44
46
  "neighbors",
45
47
  "patch_sklearn",
46
48
  "set_config",
47
49
  "sklearn_is_patched",
48
- "sklearn_is_patchedget_patch_map",
49
50
  "svm",
50
51
  "unpatch_sklearn",
51
52
  "utils",
52
53
  ]
54
+ onedal_iface_flag = os.environ.get("OFF_ONEDAL_IFACE", "0")
55
+ if onedal_iface_flag == "0":
56
+ from onedal import _is_spmd_backend
57
+ from onedal.common.hyperparameters import get_hyperparameters
53
58
 
54
-
55
- from onedal import _is_dpc_backend
56
-
57
- if _is_dpc_backend:
58
- __all__.append("spmd")
59
+ if _is_spmd_backend:
60
+ __all__.append("spmd")
59
61
 
60
62
 
61
63
  from ._utils import set_sklearn_ex_verbose
@@ -26,7 +26,7 @@ from daal4py.sklearn._n_jobs_support import control_n_jobs
26
26
  from daal4py.sklearn._utils import sklearn_check_version
27
27
  from onedal.cluster import DBSCAN as onedal_DBSCAN
28
28
 
29
- from .._device_offload import dispatch, wrap_output_data
29
+ from .._device_offload import dispatch
30
30
  from .._utils import PatchingConditionsChain
31
31
 
32
32
  if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
@@ -186,3 +186,5 @@ class DBSCAN(sklearn_DBSCAN, BaseDBSCAN):
186
186
  )
187
187
 
188
188
  return self
189
+
190
+ fit.__doc__ = sklearn_DBSCAN.fit.__doc__
@@ -0,0 +1,63 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import io
18
+ import logging
19
+
20
+ import pytest
21
+
22
+ from sklearnex import patch_sklearn, unpatch_sklearn
23
+
24
+
25
+ def pytest_configure(config):
26
+ config.addinivalue_line(
27
+ "markers", "allow_sklearn_fallback: mark test to not check for sklearnex usage"
28
+ )
29
+
30
+
31
+ @pytest.hookimpl(hookwrapper=True)
32
+ def pytest_runtest_call(item):
33
+ # setup logger to check for sklearn fallback
34
+ if not item.get_closest_marker("allow_sklearn_fallback"):
35
+ log_stream = io.StringIO()
36
+ log_handler = logging.StreamHandler(log_stream)
37
+ sklearnex_logger = logging.getLogger("sklearnex")
38
+ level = sklearnex_logger.level
39
+ sklearnex_stderr_handler = sklearnex_logger.handlers
40
+ sklearnex_logger.handlers = []
41
+ sklearnex_logger.addHandler(log_handler)
42
+ sklearnex_logger.setLevel(logging.INFO)
43
+ log_handler.setLevel(logging.INFO)
44
+
45
+ yield
46
+
47
+ sklearnex_logger.handlers = sklearnex_stderr_handler
48
+ sklearnex_logger.setLevel(level)
49
+ sklearnex_logger.removeHandler(log_handler)
50
+ text = log_stream.getvalue()
51
+ if "fallback to original Scikit-learn" in text:
52
+ raise TypeError(
53
+ f"test did not properly evaluate sklearnex functionality and fell back to sklearn:\n{text}"
54
+ )
55
+ else:
56
+ yield
57
+
58
+
59
+ @pytest.fixture
60
+ def with_sklearnex():
61
+ patch_sklearn()
62
+ yield
63
+ unpatch_sklearn()
@@ -0,0 +1,338 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import logging
18
+
19
+ from daal4py.sklearn._utils import daal_check_version
20
+
21
+ if daal_check_version((2024, "P", 100)):
22
+ import numbers
23
+ from math import sqrt
24
+
25
+ import numpy as np
26
+ from scipy.sparse import issparse
27
+ from sklearn.utils.validation import check_is_fitted
28
+
29
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
30
+ from daal4py.sklearn._utils import sklearn_check_version
31
+
32
+ from .._device_offload import dispatch, wrap_output_data
33
+ from .._utils import PatchingConditionsChain
34
+
35
+ if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
36
+ from sklearn.utils import check_scalar
37
+
38
+ from sklearn.decomposition import PCA as sklearn_PCA
39
+
40
+ from onedal.decomposition import PCA as onedal_PCA
41
+
42
+ @control_n_jobs(decorated_methods=["fit", "transform", "fit_transform"])
43
+ class PCA(sklearn_PCA):
44
+ __doc__ = sklearn_PCA.__doc__
45
+
46
+ if sklearn_check_version("1.2"):
47
+ _parameter_constraints: dict = {**sklearn_PCA._parameter_constraints}
48
+
49
+ if sklearn_check_version("1.1"):
50
+
51
+ def __init__(
52
+ self,
53
+ n_components=None,
54
+ *,
55
+ copy=True,
56
+ whiten=False,
57
+ svd_solver="auto",
58
+ tol=0.0,
59
+ iterated_power="auto",
60
+ n_oversamples=10,
61
+ power_iteration_normalizer="auto",
62
+ random_state=None,
63
+ ):
64
+ self.n_components = n_components
65
+ self.copy = copy
66
+ self.whiten = whiten
67
+ self.svd_solver = svd_solver
68
+ self.tol = tol
69
+ self.iterated_power = iterated_power
70
+ self.n_oversamples = n_oversamples
71
+ self.power_iteration_normalizer = power_iteration_normalizer
72
+ self.random_state = random_state
73
+
74
+ else:
75
+
76
+ def __init__(
77
+ self,
78
+ n_components=None,
79
+ copy=True,
80
+ whiten=False,
81
+ svd_solver="auto",
82
+ tol=0.0,
83
+ iterated_power="auto",
84
+ random_state=None,
85
+ ):
86
+ self.n_components = n_components
87
+ self.copy = copy
88
+ self.whiten = whiten
89
+ self.svd_solver = svd_solver
90
+ self.tol = tol
91
+ self.iterated_power = iterated_power
92
+ self.random_state = random_state
93
+
94
+ def fit(self, X, y=None):
95
+ self._fit(X)
96
+ return self
97
+
98
+ def _fit(self, X):
99
+ if sklearn_check_version("1.2"):
100
+ self._validate_params()
101
+ elif sklearn_check_version("1.1"):
102
+ check_scalar(
103
+ self.n_oversamples,
104
+ "n_oversamples",
105
+ min_val=1,
106
+ target_type=numbers.Integral,
107
+ )
108
+
109
+ U, S, Vt = dispatch(
110
+ self,
111
+ "fit",
112
+ {
113
+ "onedal": self.__class__._onedal_fit,
114
+ "sklearn": sklearn_PCA._fit,
115
+ },
116
+ X,
117
+ )
118
+ return U, S, Vt
119
+
120
+ def _onedal_fit(self, X, queue=None):
121
+ X = self._validate_data(
122
+ X,
123
+ dtype=[np.float64, np.float32],
124
+ ensure_2d=True,
125
+ copy=self.copy,
126
+ )
127
+
128
+ onedal_params = {
129
+ "n_components": self.n_components,
130
+ "is_deterministic": True,
131
+ "method": "cov",
132
+ "whiten": self.whiten,
133
+ }
134
+ self._onedal_estimator = onedal_PCA(**onedal_params)
135
+ self._onedal_estimator.fit(X, queue=queue)
136
+ self._save_attributes()
137
+
138
+ U = None
139
+ S = self.singular_values_
140
+ Vt = self.components_
141
+
142
+ return U, S, Vt
143
+
144
+ @wrap_output_data
145
+ def transform(self, X):
146
+ return dispatch(
147
+ self,
148
+ "transform",
149
+ {
150
+ "onedal": self.__class__._onedal_transform,
151
+ "sklearn": sklearn_PCA.transform,
152
+ },
153
+ X,
154
+ )
155
+
156
+ def _onedal_transform(self, X, queue=None):
157
+ check_is_fitted(self)
158
+ X = self._validate_data(
159
+ X,
160
+ dtype=[np.float64, np.float32],
161
+ reset=False,
162
+ )
163
+ self._validate_n_features_in_after_fitting(X)
164
+ if sklearn_check_version("1.0"):
165
+ self._check_feature_names(X, reset=False)
166
+
167
+ return self._onedal_estimator.predict(X, queue=queue)
168
+
169
+ @wrap_output_data
170
+ def fit_transform(self, X, y=None):
171
+ U, S, Vt = self._fit(X)
172
+ if U is None:
173
+ # oneDAL PCA was fit
174
+ X_transformed = self._onedal_transform(X)
175
+ return X_transformed
176
+ else:
177
+ # Scikit-learn PCA was fit
178
+ U = U[:, : self.n_components_]
179
+
180
+ if self.whiten:
181
+ U *= sqrt(X.shape[0] - 1)
182
+ else:
183
+ U *= S[: self.n_components_]
184
+
185
+ return U
186
+
187
+ def _onedal_supported(self, method_name, X):
188
+ class_name = self.__class__.__name__
189
+ patching_status = PatchingConditionsChain(
190
+ f"sklearn.decomposition.{class_name}.{method_name}"
191
+ )
192
+
193
+ if method_name == "fit":
194
+ shape_tuple, _is_shape_compatible = self._get_shape_compatibility(X)
195
+ patching_status.and_conditions(
196
+ [
197
+ (
198
+ _is_shape_compatible,
199
+ "Data shape is not compatible.",
200
+ ),
201
+ (
202
+ self._is_solver_compatible_with_onedal(shape_tuple),
203
+ f"Only 'full' svd solver is supported.",
204
+ ),
205
+ (not issparse(X), "oneDAL PCA does not support sparse data"),
206
+ ]
207
+ )
208
+ return patching_status
209
+
210
+ if method_name == "transform":
211
+ patching_status.and_conditions(
212
+ [
213
+ (
214
+ hasattr(self, "_onedal_estimator"),
215
+ "oneDAL model was not trained",
216
+ ),
217
+ ]
218
+ )
219
+ return patching_status
220
+
221
+ raise RuntimeError(
222
+ f"Unknown method {method_name} in {self.__class__.__name__}"
223
+ )
224
+
225
+ def _onedal_cpu_supported(self, method_name, *data):
226
+ return self._onedal_supported(method_name, *data)
227
+
228
+ def _onedal_gpu_supported(self, method_name, *data):
229
+ return self._onedal_supported(method_name, *data)
230
+
231
+ def _get_shape_compatibility(self, X):
232
+ _is_shape_compatible = False
233
+ _empty_shape = (0, 0)
234
+ if hasattr(X, "shape"):
235
+ shape_tuple = X.shape
236
+ if len(shape_tuple) == 1:
237
+ shape_tuple = (1, shape_tuple[0])
238
+ elif isinstance(X, list):
239
+ if np.ndim(X) == 1:
240
+ shape_tuple = (1, len(X))
241
+ elif np.ndim(X) == 2:
242
+ shape_tuple = (len(X), len(X[0]))
243
+ else:
244
+ return _empty_shape, _is_shape_compatible
245
+
246
+ if shape_tuple[0] > 0 and shape_tuple[1] > 0 and len(shape_tuple) == 2:
247
+ _is_shape_compatible = shape_tuple[1] / shape_tuple[0] < 2
248
+
249
+ return shape_tuple, _is_shape_compatible
250
+
251
+ def _is_solver_compatible_with_onedal(self, shape_tuple):
252
+ self._fit_svd_solver = self.svd_solver
253
+ n_sf_min = min(shape_tuple)
254
+ n_components = n_sf_min if self.n_components is None else self.n_components
255
+
256
+ if self._fit_svd_solver == "auto":
257
+ if sklearn_check_version("1.1"):
258
+ if max(shape_tuple) <= 500 or n_components == "mle":
259
+ self._fit_svd_solver = "full"
260
+ elif 1 <= n_components < 0.8 * n_sf_min:
261
+ self._fit_svd_solver = "randomized"
262
+ else:
263
+ self._fit_svd_solver = "full"
264
+ else:
265
+ if n_components == "mle":
266
+ self._fit_svd_solver = "full"
267
+ else:
268
+ # check if sklearnex is faster than randomized sklearn
269
+ # Refer to daal4py
270
+ regression_coefs = np.array(
271
+ [
272
+ [
273
+ 9.779873e-11,
274
+ shape_tuple[0] * shape_tuple[1] * n_components,
275
+ ],
276
+ [
277
+ -1.122062e-11,
278
+ shape_tuple[0] * shape_tuple[1] * shape_tuple[1],
279
+ ],
280
+ [1.127905e-09, shape_tuple[0] ** 2],
281
+ ]
282
+ )
283
+ if (
284
+ n_components >= 1
285
+ and np.dot(regression_coefs[:, 0], regression_coefs[:, 1])
286
+ <= 0
287
+ ):
288
+ self._fit_svd_solver = "randomized"
289
+ else:
290
+ self._fit_svd_solver = "full"
291
+
292
+ if self._fit_svd_solver == "full":
293
+ return True
294
+ else:
295
+ return False
296
+
297
+ def _save_attributes(self):
298
+ self.n_samples_ = self._onedal_estimator.n_samples_
299
+ if sklearn_check_version("1.2"):
300
+ self.n_features_in_ = self._onedal_estimator.n_features_
301
+ elif sklearn_check_version("0.24"):
302
+ self.n_features_ = self._onedal_estimator.n_features_
303
+ self.n_features_in_ = self._onedal_estimator.n_features_
304
+ else:
305
+ self.n_features_ = self._onedal_estimator.n_features_
306
+ self.n_components_ = self._onedal_estimator.n_components_
307
+ self.components_ = self._onedal_estimator.components_
308
+ self.mean_ = self._onedal_estimator.mean_
309
+ self.singular_values_ = self._onedal_estimator.singular_values_
310
+ self.explained_variance_ = self._onedal_estimator.explained_variance_.ravel()
311
+ self.explained_variance_ratio_ = (
312
+ self._onedal_estimator.explained_variance_ratio_
313
+ )
314
+ self.noise_variance_ = self._onedal_estimator.noise_variance_
315
+
316
+ def _validate_n_features_in_after_fitting(self, X):
317
+ if sklearn_check_version("1.2"):
318
+ expected_n_features = self.n_features_in_
319
+ else:
320
+ expected_n_features = self.n_features_
321
+ if X.shape[1] != expected_n_features:
322
+ raise ValueError(
323
+ (
324
+ f"X has {X.shape[1]} features, "
325
+ f"but PCA is expecting {expected_n_features} features as input"
326
+ )
327
+ )
328
+
329
+ fit.__doc__ = sklearn_PCA.fit.__doc__
330
+ transform.__doc__ = sklearn_PCA.transform.__doc__
331
+ fit_transform.__doc__ = sklearn_PCA.fit_transform.__doc__
332
+
333
+ else:
334
+ from daal4py.sklearn.decomposition import PCA
335
+
336
+ logging.warning(
337
+ "Sklearnex PCA requires oneDAL version >= 2024.1.0 but it was not found"
338
+ )
@@ -27,16 +27,30 @@ from onedal.tests.utils._dataframes_support import (
27
27
 
28
28
 
29
29
  @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
30
- @pytest.mark.parametrize("macro_block", [None, 1024])
31
- def test_sklearnex_import(dataframe, queue, macro_block):
32
- from sklearnex.preview.decomposition import PCA
30
+ def test_sklearnex_import(dataframe, queue):
31
+ from sklearnex.decomposition import PCA
33
32
 
34
33
  X = [[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]
35
34
  X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
35
+ X_transformed_expected = [
36
+ [-1.38340578, -0.2935787],
37
+ [-2.22189802, 0.25133484],
38
+ [-3.6053038, -0.04224385],
39
+ [1.38340578, 0.2935787],
40
+ [2.22189802, -0.25133484],
41
+ [3.6053038, 0.04224385],
42
+ ]
43
+
36
44
  pca = PCA(n_components=2, svd_solver="full")
37
- if daal_check_version((2024, "P", 0)) and macro_block is not None:
38
- pca.get_hyperparameters("fit").cpu_macro_block = macro_block
39
45
  pca.fit(X)
40
- assert "sklearnex" in pca.__module__
41
- assert hasattr(pca, "_onedal_estimator")
42
- assert_allclose(_as_numpy(pca.singular_values_), [6.30061232, 0.54980396])
46
+ X_transformed = pca.transform(X)
47
+ X_fit_transformed = PCA(n_components=2, svd_solver="full").fit_transform(X)
48
+
49
+ if daal_check_version((2024, "P", 100)):
50
+ assert "sklearnex" in pca.__module__
51
+ assert hasattr(pca, "_onedal_estimator")
52
+ else:
53
+ assert "daal4py" in pca.__module__
54
+ assert_allclose([6.30061232, 0.54980396], _as_numpy(pca.singular_values_))
55
+ assert_allclose(X_transformed_expected, _as_numpy(X_transformed))
56
+ assert_allclose(X_transformed_expected, _as_numpy(X_fit_transformed))
@@ -1,5 +1,6 @@
1
1
  # ==============================================================================
2
2
  # Copyright 2021 Intel Corporation
3
+ # Copyright 2024 Fujitsu Limited
3
4
  #
4
5
  # Licensed under the Apache License, Version 2.0 (the "License");
5
6
  # you may not use this file except in compliance with the License.
@@ -22,7 +23,7 @@ from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
22
23
 
23
24
 
24
25
  def _is_new_patching_available():
25
- return os.environ.get("OFF_ONEDAL_IFACE") is None and daal_check_version(
26
+ return os.environ.get("OFF_ONEDAL_IFACE", "0") == "0" and daal_check_version(
26
27
  (2021, "P", 300)
27
28
  )
28
29
 
@@ -32,16 +33,66 @@ def _is_preview_enabled():
32
33
 
33
34
 
34
35
  @lru_cache(maxsize=None)
35
- def get_patch_map():
36
+ def get_patch_map_core(preview=False):
37
+ if preview:
38
+ # use recursion to guarantee that state of preview
39
+ # and non-preview maps are done at the same time.
40
+ # The two lru_cache dicts are actually one underneath.
41
+ # Preview is always secondary. Both sklearnex patch
42
+ # maps are referring to the daal4py dict unless the
43
+ # key has been replaced. Use with caution.
44
+ mapping = get_patch_map_core().copy()
45
+
46
+ if _is_new_patching_available():
47
+ import sklearn.covariance as covariance_module
48
+
49
+ # Preview classes for patching
50
+ from .preview.cluster import KMeans as KMeans_sklearnex
51
+ from .preview.covariance import (
52
+ EmpiricalCovariance as EmpiricalCovariance_sklearnex,
53
+ )
54
+
55
+ # Since the state of the lru_cache without preview cannot be
56
+ # guaranteed to not have already enabled sklearnex algorithms
57
+ # when preview is used, setting the mapping element[1] to None
58
+ # should NOT be done. This may lose track of the unpatched
59
+ # sklearn estimator or function.
60
+ # KMeans
61
+ cluster_module, _, _ = mapping["kmeans"][0][0]
62
+ sklearn_obj = mapping["kmeans"][0][1]
63
+ mapping.pop("kmeans")
64
+ mapping["kmeans"] = [
65
+ [(cluster_module, "kmeans", KMeans_sklearnex), sklearn_obj]
66
+ ]
67
+
68
+ # Covariance
69
+ mapping["empiricalcovariance"] = [
70
+ [
71
+ (
72
+ covariance_module,
73
+ "EmpiricalCovariance",
74
+ EmpiricalCovariance_sklearnex,
75
+ ),
76
+ None,
77
+ ]
78
+ ]
79
+ return mapping
80
+
36
81
  from daal4py.sklearn.monkeypatch.dispatcher import _get_map_of_algorithms
37
82
 
83
+ # NOTE: this is a shallow copy of a dict, modification is dangerous
38
84
  mapping = _get_map_of_algorithms().copy()
39
85
 
86
+ # NOTE: daal4py _get_map_of_algorithms and
87
+ # get_patch_map/get_patch_map_core should not be used concurrently.
88
+ # The setting of elements to None below may cause loss of state
89
+ # when interacting with sklearn. The value stored under a dictionary
90
+ # key must not be mutated in place but replaced entirely, otherwise
91
+ # the shared state gets corrupted. This is why pop is used before reassigning.
40
92
  if _is_new_patching_available():
41
93
  # Scikit-learn* modules
42
94
  import sklearn as base_module
43
95
  import sklearn.cluster as cluster_module
44
- import sklearn.covariance as covariance_module
45
96
  import sklearn.decomposition as decomposition_module
46
97
  import sklearn.ensemble as ensemble_module
47
98
  import sklearn.linear_model as linear_model_module
@@ -64,6 +115,7 @@ def get_patch_map():
64
115
  from .utils.parallel import _FuncWrapperOld as _FuncWrapper_sklearnex
65
116
 
66
117
  from .cluster import DBSCAN as DBSCAN_sklearnex
118
+ from .decomposition import PCA as PCA_sklearnex
67
119
  from .ensemble import ExtraTreesClassifier as ExtraTreesClassifier_sklearnex
68
120
  from .ensemble import ExtraTreesRegressor as ExtraTreesRegressor_sklearnex
69
121
  from .ensemble import RandomForestClassifier as RandomForestClassifier_sklearnex
@@ -74,53 +126,19 @@ def get_patch_map():
74
126
  from .neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearnex
75
127
  from .neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearnex
76
128
  from .neighbors import NearestNeighbors as NearestNeighbors_sklearnex
77
-
78
- # Preview classes for patching
79
- from .preview.cluster import KMeans as KMeans_sklearnex
80
- from .preview.covariance import (
81
- EmpiricalCovariance as EmpiricalCovariance_sklearnex,
82
- )
83
- from .preview.decomposition import PCA as PCA_sklearnex
84
129
  from .svm import SVC as SVC_sklearnex
85
130
  from .svm import SVR as SVR_sklearnex
86
131
  from .svm import NuSVC as NuSVC_sklearnex
87
132
  from .svm import NuSVR as NuSVR_sklearnex
88
133
 
89
- # Patch for mapping
90
- if _is_preview_enabled():
91
- # PCA
92
- mapping.pop("pca")
93
- mapping["pca"] = [[(decomposition_module, "PCA", PCA_sklearnex), None]]
94
-
95
- # KMeans
96
- mapping.pop("kmeans")
97
- mapping["kmeans"] = [
98
- [
99
- (
100
- cluster_module,
101
- "KMeans",
102
- KMeans_sklearnex,
103
- ),
104
- None,
105
- ]
106
- ]
107
-
108
- # Covariance
109
- mapping["empiricalcovariance"] = [
110
- [
111
- (
112
- covariance_module,
113
- "EmpiricalCovariance",
114
- EmpiricalCovariance_sklearnex,
115
- ),
116
- None,
117
- ]
118
- ]
119
-
120
134
  # DBSCAN
121
135
  mapping.pop("dbscan")
122
136
  mapping["dbscan"] = [[(cluster_module, "DBSCAN", DBSCAN_sklearnex), None]]
123
137
 
138
+ # PCA
139
+ mapping.pop("pca")
140
+ mapping["pca"] = [[(decomposition_module, "PCA", PCA_sklearnex), None]]
141
+
124
142
  # SVM
125
143
  mapping.pop("svm")
126
144
  mapping.pop("svc")
@@ -276,6 +294,19 @@ def get_patch_map():
276
294
  return mapping
277
295
 
278
296
 
297
# Resolving the preview flag here, outside the cache, and forwarding it as an
# argument keeps the preview and non-preview maps in separate cache entries.
def get_patch_map():
    """Return the sklearnex patch map for the current preview setting.

    This wrapper itself is not cached. On every call it reads
    ``_is_preview_enabled()`` and forwards the result as the ``preview``
    argument of ``get_patch_map_core``, whose ``lru_cache`` includes that
    argument in its key.
    """
    return get_patch_map_core(preview=_is_preview_enabled())


# Expose the cache-management API of the underlying lru_cache-wrapped
# function on the wrapper, so callers can clear or inspect the cache through
# get_patch_map.
get_patch_map.cache_clear = get_patch_map_core.cache_clear
get_patch_map.cache_info = get_patch_map_core.cache_info
308
+
309
+
279
310
def get_patch_names():
    """Return the keys of the current patch map as a list, in map order."""
    patch_map = get_patch_map()
    # Iterating a dict yields its keys in insertion order.
    return list(patch_map)
281
312