scikit-learn-intelex 2025.4.0__py313-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (282) hide show
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-313-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-313-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +696 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +204 -0
  62. onedal/_onedal_py_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-313-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +175 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +242 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
  70. onedal/basic_statistics/tests/utils.py +50 -0
  71. onedal/cluster/__init__.py +27 -0
  72. onedal/cluster/dbscan.py +105 -0
  73. onedal/cluster/kmeans.py +557 -0
  74. onedal/cluster/kmeans_init.py +112 -0
  75. onedal/cluster/tests/test_dbscan.py +125 -0
  76. onedal/cluster/tests/test_kmeans.py +88 -0
  77. onedal/cluster/tests/test_kmeans_init.py +93 -0
  78. onedal/common/_base.py +38 -0
  79. onedal/common/_estimator_checks.py +47 -0
  80. onedal/common/_mixin.py +62 -0
  81. onedal/common/_policy.py +55 -0
  82. onedal/common/_spmd_policy.py +30 -0
  83. onedal/common/hyperparameters.py +125 -0
  84. onedal/common/tests/test_policy.py +76 -0
  85. onedal/common/tests/test_sycl.py +128 -0
  86. onedal/covariance/__init__.py +20 -0
  87. onedal/covariance/covariance.py +122 -0
  88. onedal/covariance/incremental_covariance.py +161 -0
  89. onedal/covariance/tests/test_covariance.py +50 -0
  90. onedal/covariance/tests/test_incremental_covariance.py +190 -0
  91. onedal/datatypes/__init__.py +19 -0
  92. onedal/datatypes/_data_conversion.py +121 -0
  93. onedal/datatypes/tests/common.py +126 -0
  94. onedal/datatypes/tests/test_data.py +475 -0
  95. onedal/decomposition/__init__.py +20 -0
  96. onedal/decomposition/incremental_pca.py +214 -0
  97. onedal/decomposition/pca.py +186 -0
  98. onedal/decomposition/tests/test_incremental_pca.py +285 -0
  99. onedal/ensemble/__init__.py +29 -0
  100. onedal/ensemble/forest.py +736 -0
  101. onedal/ensemble/tests/test_random_forest.py +97 -0
  102. onedal/linear_model/__init__.py +27 -0
  103. onedal/linear_model/incremental_linear_model.py +292 -0
  104. onedal/linear_model/linear_model.py +325 -0
  105. onedal/linear_model/logistic_regression.py +247 -0
  106. onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
  107. onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
  108. onedal/linear_model/tests/test_linear_regression.py +259 -0
  109. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  110. onedal/linear_model/tests/test_ridge.py +95 -0
  111. onedal/neighbors/__init__.py +19 -0
  112. onedal/neighbors/neighbors.py +763 -0
  113. onedal/neighbors/tests/test_knn_classification.py +49 -0
  114. onedal/primitives/__init__.py +27 -0
  115. onedal/primitives/get_tree.py +25 -0
  116. onedal/primitives/kernel_functions.py +152 -0
  117. onedal/primitives/tests/test_kernel_functions.py +159 -0
  118. onedal/spmd/__init__.py +25 -0
  119. onedal/spmd/_base.py +30 -0
  120. onedal/spmd/basic_statistics/__init__.py +20 -0
  121. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  122. onedal/spmd/basic_statistics/incremental_basic_statistics.py +71 -0
  123. onedal/spmd/cluster/__init__.py +28 -0
  124. onedal/spmd/cluster/dbscan.py +23 -0
  125. onedal/spmd/cluster/kmeans.py +56 -0
  126. onedal/spmd/covariance/__init__.py +20 -0
  127. onedal/spmd/covariance/covariance.py +26 -0
  128. onedal/spmd/covariance/incremental_covariance.py +83 -0
  129. onedal/spmd/decomposition/__init__.py +20 -0
  130. onedal/spmd/decomposition/incremental_pca.py +124 -0
  131. onedal/spmd/decomposition/pca.py +26 -0
  132. onedal/spmd/ensemble/__init__.py +19 -0
  133. onedal/spmd/ensemble/forest.py +28 -0
  134. onedal/spmd/linear_model/__init__.py +21 -0
  135. onedal/spmd/linear_model/incremental_linear_model.py +101 -0
  136. onedal/spmd/linear_model/linear_model.py +30 -0
  137. onedal/spmd/linear_model/logistic_regression.py +38 -0
  138. onedal/spmd/neighbors/__init__.py +19 -0
  139. onedal/spmd/neighbors/neighbors.py +75 -0
  140. onedal/svm/__init__.py +19 -0
  141. onedal/svm/svm.py +556 -0
  142. onedal/svm/tests/test_csr_svm.py +351 -0
  143. onedal/svm/tests/test_nusvc.py +204 -0
  144. onedal/svm/tests/test_nusvr.py +210 -0
  145. onedal/svm/tests/test_svc.py +176 -0
  146. onedal/svm/tests/test_svr.py +243 -0
  147. onedal/tests/test_common.py +57 -0
  148. onedal/tests/utils/_dataframes_support.py +162 -0
  149. onedal/tests/utils/_device_selection.py +102 -0
  150. onedal/utils/__init__.py +49 -0
  151. onedal/utils/_array_api.py +81 -0
  152. onedal/utils/_dpep_helpers.py +56 -0
  153. onedal/utils/tests/test_validation.py +142 -0
  154. onedal/utils/validation.py +464 -0
  155. scikit_learn_intelex-2025.4.0.dist-info/LICENSE.txt +202 -0
  156. scikit_learn_intelex-2025.4.0.dist-info/METADATA +190 -0
  157. scikit_learn_intelex-2025.4.0.dist-info/RECORD +282 -0
  158. scikit_learn_intelex-2025.4.0.dist-info/WHEEL +5 -0
  159. scikit_learn_intelex-2025.4.0.dist-info/top_level.txt +3 -0
  160. sklearnex/__init__.py +66 -0
  161. sklearnex/__main__.py +58 -0
  162. sklearnex/_config.py +116 -0
  163. sklearnex/_device_offload.py +126 -0
  164. sklearnex/_utils.py +177 -0
  165. sklearnex/basic_statistics/__init__.py +20 -0
  166. sklearnex/basic_statistics/basic_statistics.py +261 -0
  167. sklearnex/basic_statistics/incremental_basic_statistics.py +352 -0
  168. sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
  169. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
  170. sklearnex/cluster/__init__.py +20 -0
  171. sklearnex/cluster/dbscan.py +197 -0
  172. sklearnex/cluster/k_means.py +397 -0
  173. sklearnex/cluster/tests/test_dbscan.py +38 -0
  174. sklearnex/cluster/tests/test_kmeans.py +157 -0
  175. sklearnex/conftest.py +82 -0
  176. sklearnex/covariance/__init__.py +19 -0
  177. sklearnex/covariance/incremental_covariance.py +405 -0
  178. sklearnex/covariance/tests/test_incremental_covariance.py +287 -0
  179. sklearnex/decomposition/__init__.py +19 -0
  180. sklearnex/decomposition/pca.py +427 -0
  181. sklearnex/decomposition/tests/test_pca.py +58 -0
  182. sklearnex/dispatcher.py +534 -0
  183. sklearnex/doc/third-party-programs.txt +424 -0
  184. sklearnex/ensemble/__init__.py +29 -0
  185. sklearnex/ensemble/_forest.py +2029 -0
  186. sklearnex/ensemble/tests/test_forest.py +140 -0
  187. sklearnex/glob/__main__.py +72 -0
  188. sklearnex/glob/dispatcher.py +101 -0
  189. sklearnex/linear_model/__init__.py +32 -0
  190. sklearnex/linear_model/coordinate_descent.py +30 -0
  191. sklearnex/linear_model/incremental_linear.py +495 -0
  192. sklearnex/linear_model/incremental_ridge.py +432 -0
  193. sklearnex/linear_model/linear.py +346 -0
  194. sklearnex/linear_model/logistic_regression.py +415 -0
  195. sklearnex/linear_model/ridge.py +390 -0
  196. sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
  197. sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
  198. sklearnex/linear_model/tests/test_linear.py +142 -0
  199. sklearnex/linear_model/tests/test_logreg.py +134 -0
  200. sklearnex/linear_model/tests/test_ridge.py +256 -0
  201. sklearnex/manifold/__init__.py +19 -0
  202. sklearnex/manifold/t_sne.py +26 -0
  203. sklearnex/manifold/tests/test_tsne.py +250 -0
  204. sklearnex/metrics/__init__.py +23 -0
  205. sklearnex/metrics/pairwise.py +22 -0
  206. sklearnex/metrics/ranking.py +20 -0
  207. sklearnex/metrics/tests/test_metrics.py +39 -0
  208. sklearnex/model_selection/__init__.py +21 -0
  209. sklearnex/model_selection/split.py +22 -0
  210. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  211. sklearnex/neighbors/__init__.py +27 -0
  212. sklearnex/neighbors/_lof.py +236 -0
  213. sklearnex/neighbors/common.py +310 -0
  214. sklearnex/neighbors/knn_classification.py +231 -0
  215. sklearnex/neighbors/knn_regression.py +207 -0
  216. sklearnex/neighbors/knn_unsupervised.py +178 -0
  217. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  218. sklearnex/preview/__init__.py +17 -0
  219. sklearnex/preview/covariance/__init__.py +19 -0
  220. sklearnex/preview/covariance/covariance.py +142 -0
  221. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  222. sklearnex/preview/decomposition/__init__.py +19 -0
  223. sklearnex/preview/decomposition/incremental_pca.py +244 -0
  224. sklearnex/preview/decomposition/tests/test_incremental_pca.py +336 -0
  225. sklearnex/spmd/__init__.py +25 -0
  226. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  227. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  228. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  229. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  230. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +306 -0
  231. sklearnex/spmd/cluster/__init__.py +30 -0
  232. sklearnex/spmd/cluster/dbscan.py +50 -0
  233. sklearnex/spmd/cluster/kmeans.py +21 -0
  234. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  235. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +173 -0
  236. sklearnex/spmd/covariance/__init__.py +20 -0
  237. sklearnex/spmd/covariance/covariance.py +21 -0
  238. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  239. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  240. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  241. sklearnex/spmd/decomposition/__init__.py +20 -0
  242. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  243. sklearnex/spmd/decomposition/pca.py +21 -0
  244. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  245. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  246. sklearnex/spmd/ensemble/__init__.py +19 -0
  247. sklearnex/spmd/ensemble/forest.py +71 -0
  248. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  249. sklearnex/spmd/linear_model/__init__.py +21 -0
  250. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  251. sklearnex/spmd/linear_model/linear_model.py +21 -0
  252. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  253. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +331 -0
  254. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  255. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  256. sklearnex/spmd/neighbors/__init__.py +19 -0
  257. sklearnex/spmd/neighbors/neighbors.py +25 -0
  258. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  259. sklearnex/svm/__init__.py +29 -0
  260. sklearnex/svm/_common.py +339 -0
  261. sklearnex/svm/nusvc.py +371 -0
  262. sklearnex/svm/nusvr.py +170 -0
  263. sklearnex/svm/svc.py +399 -0
  264. sklearnex/svm/svr.py +167 -0
  265. sklearnex/svm/tests/test_svm.py +93 -0
  266. sklearnex/tests/test_common.py +491 -0
  267. sklearnex/tests/test_config.py +123 -0
  268. sklearnex/tests/test_hyperparameters.py +43 -0
  269. sklearnex/tests/test_memory_usage.py +347 -0
  270. sklearnex/tests/test_monkeypatch.py +269 -0
  271. sklearnex/tests/test_n_jobs_support.py +108 -0
  272. sklearnex/tests/test_parallel.py +48 -0
  273. sklearnex/tests/test_patching.py +377 -0
  274. sklearnex/tests/test_run_to_run_stability.py +326 -0
  275. sklearnex/tests/utils/__init__.py +48 -0
  276. sklearnex/tests/utils/base.py +436 -0
  277. sklearnex/tests/utils/spmd.py +198 -0
  278. sklearnex/utils/__init__.py +19 -0
  279. sklearnex/utils/_array_api.py +82 -0
  280. sklearnex/utils/parallel.py +59 -0
  281. sklearnex/utils/tests/test_validation.py +238 -0
  282. sklearnex/utils/validation.py +208 -0
@@ -0,0 +1,245 @@
1
+ # ==============================================================================
2
+ # Copyright 2014 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import functools
18
+ import os
19
+ import sys
20
+ import warnings
21
+ from typing import Any, Tuple
22
+
23
+ import numpy as np
24
+ from numpy.lib.recfunctions import require_fields
25
+ from sklearn import __version__ as sklearn_version
26
+
27
+ from daal4py import _get__daal_link_version__ as dv
28
+
29
+ DaalVersionTuple = Tuple[int, str, int]
30
+
31
+ import logging
32
+
33
+ try:
34
+ from packaging.version import Version
35
+ except ImportError:
36
+ from distutils.version import LooseVersion as Version
37
+
38
+ try:
39
+ from pandas import DataFrame
40
+ from pandas.core.dtypes.cast import find_common_type
41
+
42
+ pandas_is_imported = True
43
+ except (ImportError, ModuleNotFoundError):
44
+ pandas_is_imported = False
45
+
46
+
47
def set_idp_sklearn_verbose():
    """Configure root logging from the ``IDP_SKLEARN_VERBOSE`` env variable.

    When the variable is unset this is a no-op. Otherwise its upper-cased
    value is passed as ``level`` to ``logging.basicConfig`` (writing to
    stdout with a ``LEVEL: message`` format). Any exception raised while
    configuring logging (e.g. an unknown level name) is reported as a
    warning instead of being propagated.

    NOTE: ``logging.basicConfig`` does nothing if the root logger already
    has handlers, so the variable has no effect in that case.
    """
    requested_level = os.environ.get("IDP_SKLEARN_VERBOSE")
    if requested_level is None:
        return
    try:
        logging.basicConfig(
            stream=sys.stdout,
            format="%(levelname)s: %(message)s",
            level=requested_level.upper(),
        )
    except Exception:
        # Invalid level names surface here; keep import-time setup non-fatal.
        warnings.warn(
            'Unknown level "{}" for logging.\n'
            'Please, use one of "CRITICAL", "ERROR", '
            '"WARNING", "INFO", "DEBUG".'.format(requested_level)
        )
62
+
63
+
64
def get_daal_version() -> DaalVersionTuple:
    """Return the linked oneDAL version as ``(MAJOR, STATUS, MINOR+PATCH)``.

    Parsed from daal4py's link-version string by position: characters 0-3
    are the major version, 4-7 the minor+patch number and character 10 the
    status letter.
    """
    # Fetch the version string once instead of once per field.
    link_version = dv()
    major = int(link_version[0:4])
    status = str(link_version[10:11])
    minor_patch = int(link_version[4:8])
    return major, status, minor_patch
66
+
67
+
68
@functools.lru_cache(maxsize=256, typed=False)
def daal_check_version(
    required_version: Tuple[Any, ...],
    daal_version: Tuple[Any, ...] = get_daal_version(),
) -> bool:
    """Check daal version provided as (MAJOR, STATUS, MINOR+PATCH).

    Returns True when the STATUS fields match and the required
    (MAJOR, MINOR+PATCH) pair is <= the one in ``daal_version``.

    ``required_version`` may also be a tuple of such version tuples; the
    result is then True if any candidate is satisfied.

    Because results are memoized with ``functools.lru_cache``, every argument
    must be hashable: pass tuples (a list raises ``TypeError``). The default
    ``daal_version`` is computed once, when the module is imported.
    """
    if isinstance(required_version[0], (list, tuple)):
        # Several acceptable versions: succeed on the first satisfied one.
        for candidate in required_version:
            if daal_check_version(candidate, daal_version):
                return True
        return False

    wanted_major, wanted_status, wanted_patch = required_version
    have_major, have_status, have_patch = daal_version

    # Versions with different STATUS are never considered compatible.
    if have_status != wanted_status:
        return False

    # Lexicographic: a newer major wins outright, an equal major compares patch.
    return (wanted_major, wanted_patch) <= (have_major, have_patch)
96
+
97
+
98
def _package_check_version(version_to_check, available_version):
    """Return True if ``available_version`` >= ``version_to_check``.

    With ``packaging`` (whose ``Version`` has ``base_version``), pre/dev/post
    suffixes of the available version are stripped before comparing. With the
    fallback version class, the full strings are compared.
    """
    required = Version(version_to_check)
    if hasattr(required, "base_version"):
        available = Version(Version(available_version).base_version)
    else:
        # packaging module not available
        available = Version(available_version)
    return bool(available >= required)
106
+
107
+
108
@functools.lru_cache(maxsize=256, typed=False)
def sklearn_check_version(ver):
    """Return True if the installed scikit-learn is at least version ``ver``.

    Memoized, so ``ver`` must be hashable (normally a version string).
    """
    installed = sklearn_version
    return _package_check_version(ver, installed)
111
+
112
+
113
def parse_dtype(dt):
    """Map a float64/float32 dtype to the oneDAL fptype name.

    Returns ``"double"`` for double precision and ``"float"`` for single
    precision; any other dtype (including ``None``) raises ``ValueError``.
    """
    for candidate, fptype_name in ((np.double, "double"), (np.single, "float")):
        if dt == candidate:
            return fptype_name
    raise ValueError(f"Input array has unexpected dtype = {dt}")
119
+
120
+
121
def getFPType(X):
    """Return the oneDAL floating-point type name for ``X``.

    For a pandas DataFrame (when pandas is importable) the common dtype of
    all columns is used; otherwise ``X.dtype`` (or ``None`` if absent).
    The dtype is then mapped by ``parse_dtype``, which raises ``ValueError``
    for anything other than float64/float32.
    """
    if pandas_is_imported and isinstance(X, DataFrame):
        detected = find_common_type(X.dtypes.tolist())
    else:
        detected = getattr(X, "dtype", None)
    return parse_dtype(detected)
129
+
130
+
131
def make2d(X):
    """Promote scalars and 1-D ndarrays to 2-D; return anything else unchanged.

    A scalar becomes a ``(1, 1)`` array and a 1-D ndarray of length n becomes
    an ``(n, 1)`` column. Other inputs (2-D arrays, lists, ...) are returned
    as-is.
    """
    if np.isscalar(X):
        return np.asarray(X)[np.newaxis, np.newaxis]
    if not isinstance(X, np.ndarray) or X.ndim != 1:
        return X
    return X.reshape((X.size, 1))
137
+
138
+
139
def get_patch_message(s):
    """Return the human-readable log text for a dispatch outcome key.

    Accepted keys are ``"daal"``, ``"sklearn"`` and ``"sklearn_after_daal"``;
    anything else raises ``ValueError``.
    """
    # Pairs rather than a dict so that lookup uses ``==`` (unhashable input
    # still ends in ValueError rather than TypeError).
    outcomes = (
        ("daal", "running accelerated version on CPU"),
        ("sklearn", "fallback to original Scikit-learn"),
        (
            "sklearn_after_daal",
            "failed to run accelerated version, fallback to original Scikit-learn",
        ),
    )
    for key, text in outcomes:
        if s == key:
            return text
    raise ValueError(
        f"Invalid input - expected one of 'daal','sklearn',"
        f" 'sklearn_after_daal', got {s}"
    )
153
+
154
+
155
def is_DataFrame(X):
    """Return True if ``X`` is a pandas DataFrame; always False without pandas."""
    # Short-circuit keeps ``DataFrame`` unreferenced when pandas is missing.
    return pandas_is_imported and isinstance(X, DataFrame)
160
+
161
+
162
def get_dtype(X):
    """Return the dtype of ``X``.

    For a pandas DataFrame the common dtype of its columns
    (``find_common_type``) is returned. For anything else, ``X.dtype`` is
    returned, or ``None`` when ``X`` has no ``dtype`` attribute.

    Previously the non-DataFrame branch used ``X.dtype`` when pandas was
    importable (raising ``AttributeError`` for inputs without a dtype) but
    ``getattr(..., None)`` when it was not, so the result depended on whether
    pandas happened to be installed; both environments now agree.
    """
    # is_DataFrame() is False when pandas is absent, so find_common_type is
    # only referenced when it has been imported.
    if is_DataFrame(X):
        return find_common_type(list(X.dtypes))
    return getattr(X, "dtype", None)
167
+
168
+
169
def get_number_of_types(dataframe):
    """Return how many distinct entries ``dataframe.dtypes`` contains.

    Objects without a ``dtypes`` attribute, or whose dtypes cannot be put in
    a set (``TypeError``), count as having a single type.
    """
    column_dtypes = getattr(dataframe, "dtypes", None)
    try:
        distinct = set(column_dtypes)
    except TypeError:
        # Covers the missing-attribute case too: set(None) raises TypeError.
        return 1
    return len(distinct)
175
+
176
+
177
def check_tree_nodes(tree_nodes):
    """Adapt a tree-node structured array to the installed scikit-learn.

    On scikit-learn >= 1.3 the array is returned unchanged. On older versions
    the ``missing_go_to_left`` field (added in the sklearn>=1.3 node format)
    is removed by projecting onto a dtype with every other field, using
    ``numpy.lib.recfunctions.require_fields``.
    """
    if sklearn_check_version("1.3"):
        return tree_nodes

    # Conversion from the sklearn>=1.3 node dtype to the previous layout.
    dropped_field = "missing_go_to_left"
    legacy_fields = [
        (field_name, field_info[0])  # field_info is (dtype, offset[, title])
        for field_name, field_info in tree_nodes.dtype.fields.items()
        if field_name != dropped_field
    ]
    return require_fields(tree_nodes, np.dtype(legacy_fields))
196
+
197
+
198
class PatchingConditionsChain:
    """Track whether an accelerated code path may be used for one call site.

    Conditions are supplied as ``(condition, message)`` pairs. The running
    status ``patching_is_enabled`` starts as True and is combined with each
    batch via ``&`` (``and_conditions``) or ``|`` (``or_conditions``). The
    message of every falsy condition is recorded in ``messages`` so that
    ``write_log`` can explain a fallback. Logging goes to the ``"sklearnex"``
    logger.
    """

    def __init__(self, scope_name):
        self.scope_name = scope_name  # prefix for every log line
        self.patching_is_enabled = True
        self.messages = []  # reasons recorded for failed conditions
        self.logger = logging.getLogger("sklearnex")

    def _iter_conditions(self, conditions_and_messages):
        """Return the conditions in order, recording messages of falsy ones.

        Single pass, so generators of pairs are supported.
        """
        outcomes = []
        for passed, reason in conditions_and_messages:
            outcomes.append(passed)
            if not passed:
                self.messages.append(reason)
        return outcomes

    def and_conditions(self, conditions_and_messages, conditions_merging=all):
        """AND the status with ``conditions_merging`` of the conditions; return it."""
        merged = conditions_merging(self._iter_conditions(conditions_and_messages))
        self.patching_is_enabled = self.patching_is_enabled & merged
        return self.patching_is_enabled

    def and_condition(self, condition, message):
        """AND the status with a single condition; return the new status."""
        return self.and_conditions([(condition, message)])

    def or_conditions(self, conditions_and_messages, conditions_merging=all):
        """OR the status with ``conditions_merging`` of the conditions; return it.

        Messages of falsy conditions are recorded even if the resulting status
        is True.
        """
        merged = conditions_merging(self._iter_conditions(conditions_and_messages))
        self.patching_is_enabled = self.patching_is_enabled | merged
        return self.patching_is_enabled

    def write_log(self):
        """Log the dispatch decision (info) and, on fallback, the reasons (debug)."""
        scope = self.scope_name
        if self.patching_is_enabled:
            self.logger.info(f"{scope}: {get_patch_message('daal')}")
            return
        self.logger.debug(
            f"{scope}: debugging for the patch is enabled to track"
            " the usage of Intel® oneAPI Data Analytics Library (oneDAL)"
        )
        for reason in self.messages:
            self.logger.debug(f"{scope}: patching failed with cause - {reason}")
        self.logger.info(f"{scope}: {get_patch_message('sklearn')}")

    def get_status(self, logs=False):
        """Return the current status, logging the decision first if ``logs``."""
        if not logs:
            return self.patching_is_enabled
        self.write_log()
        return self.patching_is_enabled
@@ -0,0 +1,20 @@
1
+ # ==============================================================================
2
+ # Copyright 2014 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from .dbscan import DBSCAN
18
+ from .k_means import KMeans
19
+
20
+ __all__ = ["KMeans", "DBSCAN"]
@@ -0,0 +1,165 @@
1
+ # ==============================================================================
2
+ # Copyright 2014 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+
19
+ import numpy as np
20
+ from scipy import sparse as sp
21
+ from sklearn.cluster import DBSCAN as DBSCAN_original
22
+ from sklearn.utils import check_array
23
+ from sklearn.utils.validation import _check_sample_weight
24
+
25
+ import daal4py
26
+
27
+ from .._n_jobs_support import control_n_jobs
28
+ from .._utils import PatchingConditionsChain, getFPType, make2d, sklearn_check_version
29
+
30
+ if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
31
+ from sklearn.utils import check_scalar
32
+
33
+
34
def _daal_dbscan(X, eps=0.5, min_samples=5, sample_weight=None):
    """Run oneDAL DBSCAN (via daal4py) and return ``(core_ind, assignments)``.

    ``X`` and, if given, ``sample_weight`` are promoted to 2-D with ``make2d``.
    The fptype is derived from ``X`` via ``getFPType``. Both outputs are
    flattened; when oneDAL reports no core indices an empty ``np.intc`` array
    is returned for them.
    """
    weights_2d = None if sample_weight is None else make2d(sample_weight)
    data_2d = make2d(X)

    algorithm = daal4py.dbscan(
        method="defaultDense",
        fptype=getFPType(data_2d),
        epsilon=float(eps),
        minObservations=int(min_samples),
        memorySavingMode=False,
        resultsToCompute="computeCoreIndices",
    )

    result = algorithm.compute(data_2d, weights_2d)
    labels = result.assignments.ravel()
    core_indices = result.coreIndices
    if core_indices is None:
        core_ind = np.array([], dtype=np.intc)
    else:
        core_ind = core_indices.ravel()
    return core_ind, labels
56
+
57
+
58
@control_n_jobs(decorated_methods=["fit"])
class DBSCAN(DBSCAN_original):
    # daal4py-accelerated drop-in for sklearn.cluster.DBSCAN: fit() runs the
    # oneDAL kernel (_daal_dbscan) when every PatchingConditionsChain check
    # passes and otherwise defers to the stock scikit-learn implementation.
    # Documentation is taken from the original estimator (class docstrings
    # would be overwritten by the __doc__ assignment below).
    __doc__ = DBSCAN_original.__doc__

    if sklearn_check_version("1.2"):
        # Copy scikit-learn's parameter constraints; they back the
        # self._validate_params() call in fit() on sklearn >= 1.2.
        _parameter_constraints: dict = {**DBSCAN_original._parameter_constraints}

    def __init__(
        self,
        eps=0.5,
        min_samples=5,
        metric="euclidean",
        metric_params=None,
        algorithm="auto",
        leaf_size=30,
        p=None,
        n_jobs=None,
    ):
        # Hyperparameters are stored verbatim; all validation happens in fit().
        self.eps = eps
        self.min_samples = min_samples
        self.metric = metric
        self.metric_params = metric_params
        self.algorithm = algorithm
        self.leaf_size = leaf_size
        self.p = p
        self.n_jobs = n_jobs

    def fit(self, X, y=None, sample_weight=None):
        # Parameter validation; the mechanism depends on the installed
        # scikit-learn version.
        if sklearn_check_version("1.2"):
            self._validate_params()
        elif sklearn_check_version("1.1"):
            # sklearn 1.1.x only: check_scalar is imported at module level
            # exclusively for this version range.
            check_scalar(
                self.eps,
                "eps",
                target_type=numbers.Real,
                min_val=0.0,
                include_boundaries="neither",
            )
            check_scalar(
                self.min_samples,
                "min_samples",
                target_type=numbers.Integral,
                min_val=1,
                include_boundaries="left",
            )
            check_scalar(
                self.leaf_size,
                "leaf_size",
                target_type=numbers.Integral,
                min_val=1,
                include_boundaries="left",
            )
            if self.p is not None:
                check_scalar(
                    self.p,
                    "p",
                    target_type=numbers.Real,
                    min_val=0.0,
                    include_boundaries="left",
                )
            if self.n_jobs is not None:
                check_scalar(self.n_jobs, "n_jobs", target_type=numbers.Integral)
        else:
            # Older sklearn: only eps is validated here.
            if self.eps <= 0.0:
                raise ValueError(f"eps == {self.eps}, must be > 0.0.")

        if sklearn_check_version("1.0"):
            # Record/validate feature names as a fresh fit (reset=True).
            self._check_feature_names(X, reset=True)

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

        # Decide whether the oneDAL path can be used; each failed condition
        # records its message for the debug log written below.
        # NOTE(review): metric="minkowski" with p=None falls back to sklearn
        # even if sklearn would treat it as p=2 — confirm. metric_params is
        # not consulted on the oneDAL path.
        _patching_status = PatchingConditionsChain("sklearn.cluster.DBSCAN.fit")
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    self.algorithm in ["auto", "brute"],
                    f"'{self.algorithm}' algorithm is not supported. "
                    "Only 'auto' and 'brute' algorithms are supported",
                ),
                (
                    self.metric == "euclidean"
                    or (self.metric == "minkowski" and self.p == 2),
                    f"'{self.metric}' (p={self.p}) metric is not supported. "
                    "Only 'euclidean' or 'minkowski' with p=2 metrics are supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            ]
        )

        _patching_status.write_log()
        if _dal_ready:
            # Sparse X never reaches this branch (rejected by the conditions
            # above); convert to a float64/float32 array for oneDAL.
            X = check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
            core_ind, assignments = _daal_dbscan(
                X, self.eps, self.min_samples, sample_weight=sample_weight
            )
            self.core_sample_indices_ = core_ind
            self.labels_ = assignments
            # Rows of X at the core sample indices.
            self.components_ = np.take(X, core_ind, axis=0)
            self.n_features_in_ = X.shape[1]
            return self
        # Fallback: stock scikit-learn DBSCAN.
        return super().fit(X, y, sample_weight=sample_weight)

    def fit_predict(self, X, y=None, sample_weight=None):
        # Delegates to sklearn's fit_predict; presumably that calls self.fit,
        # so the accelerated path applies here too — confirm per sklearn version.
        return super().fit_predict(X, y, sample_weight)

    fit.__doc__ = DBSCAN_original.fit.__doc__
    fit_predict.__doc__ = DBSCAN_original.fit_predict.__doc__