scikit-learn-intelex 2025.1.0__py310-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the package registry page for more details.

Files changed (280)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-310-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-310-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +222 -0
  62. onedal/_onedal_py_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-310-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +564 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +125 -0
  83. onedal/common/tests/test_policy.py +76 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +154 -0
  91. onedal/datatypes/tests/common.py +126 -0
  92. onedal/datatypes/tests/test_data.py +414 -0
  93. onedal/decomposition/__init__.py +20 -0
  94. onedal/decomposition/incremental_pca.py +204 -0
  95. onedal/decomposition/pca.py +186 -0
  96. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  97. onedal/ensemble/__init__.py +29 -0
  98. onedal/ensemble/forest.py +727 -0
  99. onedal/ensemble/tests/test_random_forest.py +97 -0
  100. onedal/linear_model/__init__.py +27 -0
  101. onedal/linear_model/incremental_linear_model.py +258 -0
  102. onedal/linear_model/linear_model.py +329 -0
  103. onedal/linear_model/logistic_regression.py +249 -0
  104. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  105. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  106. onedal/linear_model/tests/test_linear_regression.py +250 -0
  107. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  108. onedal/linear_model/tests/test_ridge.py +95 -0
  109. onedal/neighbors/__init__.py +19 -0
  110. onedal/neighbors/neighbors.py +767 -0
  111. onedal/neighbors/tests/test_knn_classification.py +49 -0
  112. onedal/primitives/__init__.py +27 -0
  113. onedal/primitives/get_tree.py +25 -0
  114. onedal/primitives/kernel_functions.py +153 -0
  115. onedal/primitives/tests/test_kernel_functions.py +159 -0
  116. onedal/spmd/__init__.py +25 -0
  117. onedal/spmd/_base.py +30 -0
  118. onedal/spmd/basic_statistics/__init__.py +20 -0
  119. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  120. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  121. onedal/spmd/cluster/__init__.py +28 -0
  122. onedal/spmd/cluster/dbscan.py +23 -0
  123. onedal/spmd/cluster/kmeans.py +56 -0
  124. onedal/spmd/covariance/__init__.py +20 -0
  125. onedal/spmd/covariance/covariance.py +26 -0
  126. onedal/spmd/covariance/incremental_covariance.py +82 -0
  127. onedal/spmd/decomposition/__init__.py +20 -0
  128. onedal/spmd/decomposition/incremental_pca.py +117 -0
  129. onedal/spmd/decomposition/pca.py +26 -0
  130. onedal/spmd/ensemble/__init__.py +19 -0
  131. onedal/spmd/ensemble/forest.py +28 -0
  132. onedal/spmd/linear_model/__init__.py +21 -0
  133. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  134. onedal/spmd/linear_model/linear_model.py +30 -0
  135. onedal/spmd/linear_model/logistic_regression.py +38 -0
  136. onedal/spmd/neighbors/__init__.py +19 -0
  137. onedal/spmd/neighbors/neighbors.py +75 -0
  138. onedal/svm/__init__.py +19 -0
  139. onedal/svm/svm.py +556 -0
  140. onedal/svm/tests/test_csr_svm.py +351 -0
  141. onedal/svm/tests/test_nusvc.py +204 -0
  142. onedal/svm/tests/test_nusvr.py +210 -0
  143. onedal/svm/tests/test_svc.py +176 -0
  144. onedal/svm/tests/test_svr.py +243 -0
  145. onedal/tests/test_common.py +57 -0
  146. onedal/tests/utils/_dataframes_support.py +162 -0
  147. onedal/tests/utils/_device_selection.py +102 -0
  148. onedal/utils/__init__.py +49 -0
  149. onedal/utils/_array_api.py +81 -0
  150. onedal/utils/_dpep_helpers.py +56 -0
  151. onedal/utils/validation.py +440 -0
  152. scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
  153. scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
  154. scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
  155. scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
  156. scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
  157. sklearnex/__init__.py +66 -0
  158. sklearnex/__main__.py +58 -0
  159. sklearnex/_config.py +116 -0
  160. sklearnex/_device_offload.py +126 -0
  161. sklearnex/_utils.py +132 -0
  162. sklearnex/basic_statistics/__init__.py +20 -0
  163. sklearnex/basic_statistics/basic_statistics.py +230 -0
  164. sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  165. sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  166. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  167. sklearnex/cluster/__init__.py +20 -0
  168. sklearnex/cluster/dbscan.py +197 -0
  169. sklearnex/cluster/k_means.py +395 -0
  170. sklearnex/cluster/tests/test_dbscan.py +38 -0
  171. sklearnex/cluster/tests/test_kmeans.py +159 -0
  172. sklearnex/conftest.py +82 -0
  173. sklearnex/covariance/__init__.py +19 -0
  174. sklearnex/covariance/incremental_covariance.py +398 -0
  175. sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  176. sklearnex/decomposition/__init__.py +19 -0
  177. sklearnex/decomposition/pca.py +425 -0
  178. sklearnex/decomposition/tests/test_pca.py +58 -0
  179. sklearnex/dispatcher.py +543 -0
  180. sklearnex/doc/third-party-programs.txt +424 -0
  181. sklearnex/ensemble/__init__.py +29 -0
  182. sklearnex/ensemble/_forest.py +2029 -0
  183. sklearnex/ensemble/tests/test_forest.py +135 -0
  184. sklearnex/glob/__main__.py +72 -0
  185. sklearnex/glob/dispatcher.py +101 -0
  186. sklearnex/linear_model/__init__.py +32 -0
  187. sklearnex/linear_model/coordinate_descent.py +30 -0
  188. sklearnex/linear_model/incremental_linear.py +482 -0
  189. sklearnex/linear_model/incremental_ridge.py +425 -0
  190. sklearnex/linear_model/linear.py +341 -0
  191. sklearnex/linear_model/logistic_regression.py +413 -0
  192. sklearnex/linear_model/ridge.py +24 -0
  193. sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  194. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  195. sklearnex/linear_model/tests/test_linear.py +167 -0
  196. sklearnex/linear_model/tests/test_logreg.py +134 -0
  197. sklearnex/manifold/__init__.py +19 -0
  198. sklearnex/manifold/t_sne.py +21 -0
  199. sklearnex/manifold/tests/test_tsne.py +26 -0
  200. sklearnex/metrics/__init__.py +23 -0
  201. sklearnex/metrics/pairwise.py +22 -0
  202. sklearnex/metrics/ranking.py +20 -0
  203. sklearnex/metrics/tests/test_metrics.py +39 -0
  204. sklearnex/model_selection/__init__.py +21 -0
  205. sklearnex/model_selection/split.py +22 -0
  206. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  207. sklearnex/neighbors/__init__.py +27 -0
  208. sklearnex/neighbors/_lof.py +236 -0
  209. sklearnex/neighbors/common.py +310 -0
  210. sklearnex/neighbors/knn_classification.py +231 -0
  211. sklearnex/neighbors/knn_regression.py +207 -0
  212. sklearnex/neighbors/knn_unsupervised.py +178 -0
  213. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  214. sklearnex/preview/__init__.py +17 -0
  215. sklearnex/preview/covariance/__init__.py +19 -0
  216. sklearnex/preview/covariance/covariance.py +138 -0
  217. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  218. sklearnex/preview/decomposition/__init__.py +19 -0
  219. sklearnex/preview/decomposition/incremental_pca.py +233 -0
  220. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  221. sklearnex/preview/linear_model/__init__.py +19 -0
  222. sklearnex/preview/linear_model/ridge.py +424 -0
  223. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  224. sklearnex/spmd/__init__.py +25 -0
  225. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  226. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  227. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  228. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  229. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  230. sklearnex/spmd/cluster/__init__.py +30 -0
  231. sklearnex/spmd/cluster/dbscan.py +50 -0
  232. sklearnex/spmd/cluster/kmeans.py +21 -0
  233. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  234. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  235. sklearnex/spmd/covariance/__init__.py +20 -0
  236. sklearnex/spmd/covariance/covariance.py +21 -0
  237. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  238. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  239. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  240. sklearnex/spmd/decomposition/__init__.py +20 -0
  241. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  242. sklearnex/spmd/decomposition/pca.py +21 -0
  243. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  244. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  245. sklearnex/spmd/ensemble/__init__.py +19 -0
  246. sklearnex/spmd/ensemble/forest.py +71 -0
  247. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  248. sklearnex/spmd/linear_model/__init__.py +21 -0
  249. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  250. sklearnex/spmd/linear_model/linear_model.py +21 -0
  251. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  252. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  253. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  254. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  255. sklearnex/spmd/neighbors/__init__.py +19 -0
  256. sklearnex/spmd/neighbors/neighbors.py +25 -0
  257. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  258. sklearnex/svm/__init__.py +29 -0
  259. sklearnex/svm/_common.py +339 -0
  260. sklearnex/svm/nusvc.py +371 -0
  261. sklearnex/svm/nusvr.py +170 -0
  262. sklearnex/svm/svc.py +399 -0
  263. sklearnex/svm/svr.py +167 -0
  264. sklearnex/svm/tests/test_svm.py +93 -0
  265. sklearnex/tests/test_common.py +390 -0
  266. sklearnex/tests/test_config.py +123 -0
  267. sklearnex/tests/test_memory_usage.py +379 -0
  268. sklearnex/tests/test_monkeypatch.py +276 -0
  269. sklearnex/tests/test_n_jobs_support.py +108 -0
  270. sklearnex/tests/test_parallel.py +48 -0
  271. sklearnex/tests/test_patching.py +385 -0
  272. sklearnex/tests/test_run_to_run_stability.py +321 -0
  273. sklearnex/tests/utils/__init__.py +44 -0
  274. sklearnex/tests/utils/base.py +371 -0
  275. sklearnex/tests/utils/spmd.py +198 -0
  276. sklearnex/utils/__init__.py +19 -0
  277. sklearnex/utils/_array_api.py +82 -0
  278. sklearnex/utils/parallel.py +59 -0
  279. sklearnex/utils/tests/test_finite.py +89 -0
  280. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,503 @@
1
+ # ==============================================================================
2
+ # Copyright 2020 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ # daal4py KNN scikit-learn-compatible base classes
18
+
19
+ import logging
20
+ import numbers
21
+ import warnings
22
+
23
+ import numpy as np
24
+ from scipy import sparse as sp
25
+ from sklearn.base import is_classifier, is_regressor
26
+ from sklearn.neighbors import VALID_METRICS
27
+ from sklearn.neighbors._ball_tree import BallTree
28
+ from sklearn.neighbors._base import KNeighborsMixin as BaseKNeighborsMixin
29
+ from sklearn.neighbors._base import NeighborsBase as BaseNeighborsBase
30
+ from sklearn.neighbors._base import RadiusNeighborsMixin as BaseRadiusNeighborsMixin
31
+ from sklearn.neighbors._kd_tree import KDTree
32
+ from sklearn.utils.multiclass import check_classification_targets
33
+ from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
34
+
35
+ import daal4py as d4p
36
+
37
+ from .._utils import (
38
+ PatchingConditionsChain,
39
+ get_patch_message,
40
+ getFPType,
41
+ sklearn_check_version,
42
+ )
43
+
44
+ if not sklearn_check_version("1.2"):
45
+ from sklearn.neighbors._base import _check_weights
46
+
47
+
48
def training_algorithm(method, fptype, params):
    """Create a daal4py KNN classification *training* algorithm object.

    ``method == "brute"`` selects ``d4p.bf_knn_classification_training``;
    every other value selects ``d4p.kdtree_knn_classification_training``.

    Note: ``params`` is modified in place -- its ``"fptype"`` entry is set to
    ``fptype`` -- and is then expanded as keyword arguments.
    """
    algorithm_factory = (
        d4p.bf_knn_classification_training
        if method == "brute"
        else d4p.kdtree_knn_classification_training
    )
    params["fptype"] = fptype
    return algorithm_factory(**params)
57
+
58
+
59
def prediction_algorithm(method, fptype, params):
    """Create a daal4py KNN classification *prediction* algorithm object.

    ``method == "brute"`` selects ``d4p.bf_knn_classification_prediction``;
    every other value selects ``d4p.kdtree_knn_classification_prediction``.

    Note: ``params`` is modified in place -- its ``"fptype"`` entry is set to
    ``fptype`` -- and is then expanded as keyword arguments.
    """
    if method != "brute":
        factory = d4p.kdtree_knn_classification_prediction
    else:
        factory = d4p.bf_knn_classification_prediction
    params["fptype"] = fptype
    return factory(**params)
67
+
68
+
69
def parse_auto_method(estimator, method, n_samples, n_features):
    """Resolve the neighbor-search method daal4py should use.

    daal4py is only driven with brute-force or k-d tree algorithms here, so
    ``"auto"`` and ``"ball_tree"`` are mapped onto one of those two.

    Parameters
    ----------
    estimator : estimator object
        Read for ``n_neighbors``, ``metric`` and ``effective_metric_``.
    method : str
        Requested algorithm. ``"auto"`` and ``"ball_tree"`` are resolved;
        any other value is returned unchanged.
    n_samples : int
        Number of training samples the index is built on.
    n_features : int
        Number of features of the data.

    Returns
    -------
    str
        ``"brute"``, ``"kd_tree"``, or ``method`` unchanged.
    """
    if method not in ("auto", "ball_tree"):
        return method

    # Brute force is chosen when k is at least half of the training set.
    # The n_samples argument is used here; it was previously ignored in
    # favour of estimator.n_samples_fit_.
    too_many_neighbors = (
        estimator.n_neighbors is not None and estimator.n_neighbors >= n_samples // 2
    )
    if estimator.metric == "precomputed" or n_features > 11 or too_many_neighbors:
        return "brute"
    if estimator.effective_metric_ in VALID_METRICS["kd_tree"]:
        return "kd_tree"
    return "brute"
86
+
87
+
88
def daal4py_fit(estimator, X, fptype):
    """Train a daal4py KNN model on ``X`` and attach it to ``estimator``.

    Attributes written on ``estimator``: ``_fit_X`` (the training data),
    ``effective_metric_`` (always ``"euclidean"``), ``_tree`` (``None``; no
    scikit-learn index is built here), ``_fit_method`` (the search method
    resolved by ``parse_auto_method``) and ``_daal_model``.

    Labels (``estimator._y`` as a column) are passed to daal4py only when
    ``_y`` is set; in that case class labels are requested as results.
    """
    estimator._fit_X = X
    estimator._fit_method = estimator.algorithm
    estimator.effective_metric_ = "euclidean"
    estimator._tree = None

    vote_weights = (
        "voteUniform"
        if getattr(estimator, "weights", "uniform") == "uniform"
        else "voteDistance"
    )
    target = getattr(estimator, "_y", None)
    has_labels = target is not None

    params = dict(
        method="defaultDense",
        k=estimator.n_neighbors,
        voteWeights=vote_weights,
        resultsToCompute="computeIndicesOfNeighbors|computeDistances",
        resultsToEvaluate="computeClassLabels" if has_labels else "none",
    )
    if hasattr(estimator, "classes_"):
        params["nClasses"] = len(estimator.classes_)

    labels = target.reshape(-1, 1) if has_labels else None

    resolved_method = parse_auto_method(
        estimator, estimator.algorithm, estimator.n_samples_fit_, estimator.n_features_in_
    )
    estimator._fit_method = resolved_method
    trainer = training_algorithm(resolved_method, fptype, params)
    estimator._daal_model = trainer.compute(X, labels).model
118
+
119
+
120
def daal4py_kneighbors(estimator, X=None, n_neighbors=None, return_distance=True):
    """Find the nearest neighbors of query points using the daal4py model.

    Parameters
    ----------
    estimator : fitted estimator
        Must provide ``_daal_model``, ``_fit_X``, ``_fit_method``,
        ``n_samples_fit_`` and ``n_neighbors``.
    X : array-like or None
        Query points. If None, the training data ``estimator._fit_X`` is
        queried and each sample is removed from its own neighbor list.
    n_neighbors : int or None
        Number of neighbors; defaults to ``estimator.n_neighbors``.
    return_distance : bool
        If True, return ``(distances, indices)``; otherwise only ``indices``.

    Returns
    -------
    ``(distances, indices)`` or ``indices``; indices are cast with
    ``astype(int)``.

    Raises
    ------
    ValueError
        On a feature-count mismatch, ``n_neighbors <= 0``, or when more
        neighbors are requested than there are training samples.
    TypeError
        If ``n_neighbors`` is not an integer.
    """
    n_features = getattr(estimator, "n_features_in_", None)
    shape = getattr(X, "shape", None)
    if n_features and shape and len(shape) > 1 and shape[1] != n_features:
        raise ValueError(
            f"X has {X.shape[1]} features, "
            f"but kneighbors is expecting {n_features} features as input"
        )

    check_is_fitted(estimator)

    if n_neighbors is None:
        n_neighbors = estimator.n_neighbors
    elif n_neighbors <= 0:
        raise ValueError("Expected n_neighbors > 0. Got %d" % n_neighbors)
    elif not isinstance(n_neighbors, numbers.Integral):
        raise TypeError(
            "n_neighbors does not take %s value, "
            "enter integer value" % type(n_neighbors)
        )

    query_is_train = X is None
    if query_is_train:
        X = estimator._fit_X
        # Include an extra neighbor to account for the sample itself being
        # returned, which is removed later
        n_neighbors += 1
    else:
        X = check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])

    n_samples_fit = estimator.n_samples_fit_
    if n_neighbors > n_samples_fit:
        raise ValueError(
            "Expected n_neighbors <= n_samples, "
            "but n_samples = %d, n_neighbors = %d" % (n_samples_fit, n_neighbors)
        )

    try:
        fptype = getFPType(X)
    except ValueError:
        fptype = None

    weights = getattr(estimator, "weights", "uniform")
    params = {
        "method": "defaultDense",
        "k": n_neighbors,
        "voteWeights": "voteUniform" if weights == "uniform" else "voteDistance",
        "resultsToCompute": "computeIndicesOfNeighbors|computeDistances",
        "resultsToEvaluate": (
            "none" if getattr(estimator, "_y", None) is None else "computeClassLabels"
        ),
    }
    if hasattr(estimator, "classes_"):
        params["nClasses"] = len(estimator.classes_)

    method = parse_auto_method(
        estimator, estimator._fit_method, estimator.n_samples_fit_, n_features
    )

    predict_alg = prediction_algorithm(method, fptype, params)
    prediction_result = predict_alg.compute(X, estimator._daal_model)

    distances = prediction_result.distances
    indices = prediction_result.indices

    if method == "kd_tree":
        # Only the k-d tree results are re-ordered by ascending distance here.
        # This is the vectorized form of a per-row argsort + reindex.
        order = np.argsort(distances, axis=1)
        distances = np.take_along_axis(distances, order, axis=1)
        indices = np.take_along_axis(indices, order, axis=1)

    indices = indices.astype(int)

    if not query_is_train:
        return (distances, indices) if return_distance else indices

    return _drop_self_from_neighbors(
        distances, indices, X.shape[0], n_neighbors, return_distance
    )


def _drop_self_from_neighbors(distances, indices, n_queries, n_neighbors, return_distance):
    """Remove each training sample from its own neighbor list.

    Used when the query set is the training set. ``n_neighbors`` already
    includes one extra neighbor, and row ``i`` is expected to contain
    sample ``i`` itself, which is dropped.
    """
    sample_mask = indices != np.arange(n_queries)[:, None]

    # Corner case: When the number of duplicates are more
    # than the number of neighbors, the first NN will not
    # be the sample, but a duplicate.
    # In that case mask the first duplicate.
    dup_gr_nbrs = np.all(sample_mask, axis=1)
    sample_mask[:, 0][dup_gr_nbrs] = False

    new_shape = (n_queries, n_neighbors - 1)
    neigh_ind = np.reshape(indices[sample_mask], new_shape)
    if return_distance:
        neigh_dist = np.reshape(distances[sample_mask], new_shape)
        return neigh_dist, neigh_ind
    return neigh_ind
236
+
237
+
238
def validate_data(
    estimator, X, y=None, reset=True, validate_separately=False, **check_params
):
    """Validate ``X`` (and ``y``) for ``estimator``, scikit-learn style.

    Parameters
    ----------
    estimator : estimator object
        Used for its ``requires_y`` tag and for ``_check_n_features``.
    X, y : array-like
        Data and (optional) targets.
    reset : bool
        Forwarded to ``estimator._check_n_features``.
    validate_separately : False or (dict, dict)
        When given, X and y are checked with separate ``check_array`` calls
        using the two keyword dictionaries.
    **check_params
        Forwarded to ``check_array`` / ``check_X_y``.

    Returns
    -------
    tuple
        ``(X, y)`` after validation; ``y`` is returned as None when omitted.

    Raises
    ------
    ValueError
        If ``y`` is None but the estimator's tags require it.
    """
    if y is not None:
        if not validate_separately:
            X, y = check_X_y(X, y, **check_params)
        else:
            # Some estimators validate X and y separately; calling
            # check_array() on each is not equivalent to check_X_y().
            params_for_X, params_for_y = validate_separately
            X = check_array(X, **params_for_X)
            y = check_array(y, **params_for_y)
    else:
        try:
            needs_target = estimator._get_tags()["requires_y"]
        except KeyError:
            needs_target = False

        if needs_target:
            raise ValueError(
                f"This {estimator.__class__.__name__} estimator "
                f"requires y to be passed, but the target y is None."
            )
        X = check_array(X, **check_params)

    if check_params.get("ensure_2d", True):
        estimator._check_n_features(X, reset=reset)

    return X, y
271
+
272
+
273
class NeighborsBase(BaseNeighborsBase):
    """scikit-learn ``NeighborsBase`` whose ``_fit`` prefers daal4py.

    ``_fit`` trains a daal4py KNN model (via ``daal4py_fit``) when all of the
    patching conditions it checks hold. Otherwise it falls back to
    scikit-learn's ``NeighborsBase._fit``.
    """

    def __init__(
        self,
        n_neighbors=None,
        radius=None,
        algorithm="auto",
        leaf_size=30,
        metric="minkowski",
        p=2,
        metric_params=None,
        n_jobs=None,
    ):
        # All parameters are forwarded unchanged to scikit-learn's NeighborsBase.
        super().__init__(
            n_neighbors=n_neighbors,
            radius=radius,
            algorithm=algorithm,
            leaf_size=leaf_size,
            metric=metric,
            p=p,
            metric_params=metric_params,
            n_jobs=n_jobs,
        )

    def _fit(self, X, y=None):
        """Fit the neighbors index, using daal4py when supported.

        Parameters
        ----------
        X : array-like, KDTree, BallTree or NeighborsBase instance
            Training data. Tree / neighbors instances are never handled by
            daal4py; they are passed on to the scikit-learn fit.
        y : array-like or None
            Targets; validated when given or when the ``requires_y`` tag is set.

        Returns
        -------
        ``self`` when daal4py training succeeds, otherwise whatever
        scikit-learn's ``NeighborsBase._fit`` returns.

        Raises
        ------
        ValueError
            If ``n_neighbors <= 0`` (or from input validation).
        TypeError
            If ``n_neighbors`` is not an integer.
        """
        if self.metric_params is not None and "p" in self.metric_params:
            if self.p is not None:
                warnings.warn(
                    "Parameter p is found in metric_params. "
                    "The corresponding parameter from __init__ "
                    "is ignored.",
                    SyntaxWarning,
                    stacklevel=2,
                )

        # Only scikit-learn 1.0 <= version < 1.2 needs the legacy weights check.
        if (
            hasattr(self, "weights")
            and sklearn_check_version("1.0")
            and not sklearn_check_version("1.2")
        ):
            self.weights = _check_weights(self.weights)

        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=True)

        # Pre-built trees / neighbors estimators are accepted as X by the
        # scikit-learn fit, but are excluded from the daal4py path below.
        X_incorrect_type = isinstance(
            X, (KDTree, BallTree, NeighborsBase, BaseNeighborsBase)
        )
        single_output = True
        # Cleared up front so a model from an earlier fit is not reused when
        # this fit falls back to scikit-learn (kneighbors checks _daal_model).
        self._daal_model = None
        shape = None
        correct_n_classes = True

        try:
            requires_y = self._get_tags()["requires_y"]
        except KeyError:
            requires_y = False

        if y is not None or requires_y:
            if not X_incorrect_type or y is None:
                X, y = validate_data(
                    self,
                    X,
                    y,
                    accept_sparse="csr",
                    multi_output=True,
                    dtype=[np.float64, np.float32],
                )
                single_output = False if y.ndim > 1 and y.shape[1] > 1 else True

            # Validated shape of y, restored on self._y for regressors at the end.
            shape = y.shape

            if is_classifier(self) or is_regressor(self):
                if y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1:
                    self.outputs_2d_ = False
                    y = y.reshape((-1, 1))
                else:
                    self.outputs_2d_ = True

                if is_classifier(self):
                    check_classification_targets(y)
                    # Encode each output column's labels as integer indices
                    # into the corresponding entry of classes_.
                    self.classes_ = []
                    self._y = np.empty(y.shape, dtype=int)
                    for k in range(self._y.shape[1]):
                        classes, self._y[:, k] = np.unique(y[:, k], return_inverse=True)
                        self.classes_.append(classes)

                    if not self.outputs_2d_:
                        self.classes_ = self.classes_[0]
                        self._y = self._y.ravel()

                    n_classes = len(self.classes_)
                    if n_classes < 2:
                        correct_n_classes = False
                else:
                    self._y = y
        else:
            if not X_incorrect_type:
                X, _ = validate_data(
                    self, X, accept_sparse="csr", dtype=[np.float64, np.float32]
                )

        if not X_incorrect_type:
            self.n_samples_fit_ = X.shape[0]
            self.n_features_in_ = X.shape[1]

        try:
            fptype = getFPType(X)
        except ValueError:
            fptype = None

        weights = getattr(self, "weights", "uniform")

        # Fallback: scikit-learn's own _fit, bypassing this override.
        def stock_fit(self, X, y):
            result = super(NeighborsBase, self)._fit(X, y)
            return result

        if self.n_neighbors is not None:
            if self.n_neighbors <= 0:
                raise ValueError("Expected n_neighbors > 0. Got %d" % self.n_neighbors)
            if not isinstance(self.n_neighbors, numbers.Integral):
                raise TypeError(
                    "n_neighbors does not take %s value, "
                    "enter integer value" % type(self.n_neighbors)
                )

        # daal4py is used only if every (condition, reason) pair below holds;
        # write_log() presumably reports which conditions failed.
        _patching_status = PatchingConditionsChain(
            "sklearn.neighbors.KNeighborsMixin.kneighbors"
        )
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    self.metric == "minkowski"
                    and self.p == 2
                    or self.metric == "euclidean",
                    f"'{self.metric}' (p={self.p}) metric is not supported. "
                    "Only 'euclidean' or 'minkowski' with p=2 metrics are supported.",
                ),
                (not X_incorrect_type, "X is not Tree or Neighbors instance or array."),
                (
                    weights in ["uniform", "distance"],
                    f"'{weights}' weights is not supported. "
                    "Only 'uniform' and 'distance' weights are supported.",
                ),
                (
                    self.algorithm in ["brute", "kd_tree", "auto", "ball_tree"],
                    f"'{self.algorithm}' algorithm is not supported. "
                    "Only 'brute', 'kd_tree', 'auto' and 'ball_tree' "
                    "algorithms are supported.",
                ),
                (single_output, "Multiple outputs are not supported."),
                (fptype is not None, "Unable to get dtype."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (correct_n_classes, "Number of classes < 2."),
            ]
        )
        _patching_status.write_log()
        if _dal_ready:
            try:
                daal4py_fit(self, X, fptype)
                result = self
            except RuntimeError:
                # daal4py failed at runtime: log it and use scikit-learn instead.
                logging.info(
                    "sklearn.neighbors.KNeighborsMixin."
                    "kneighbors: " + get_patch_message("sklearn_after_daal")
                )
                result = stock_fit(self, X, y)
        else:
            result = stock_fit(self, X, y)

        # Regressors: store targets in the shape recorded before any
        # column reshape above.
        if y is not None and is_regressor(self):
            self._y = y if shape is None else y.reshape(shape)

        return result
446
+
447
+
448
class KNeighborsMixin(BaseKNeighborsMixin):
    """``kneighbors`` that uses the daal4py model when it can."""

    def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
        """Return the k nearest neighbors of ``X`` (or of the training data).

        The daal4py path is taken when a daal4py model exists, the dtype of
        the queried data is recognised, and ``X`` is not sparse. Otherwise
        scikit-learn's ``kneighbors`` is used. The scikit-learn index is
        rebuilt first if daal4py trained the model, or if a k-d tree fit
        left no tree behind.
        """
        trained_model = getattr(self, "_daal_model", None)
        if X is not None and self.metric != "precomputed":
            X = check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])

        queried = self._fit_X if X is None else X
        try:
            fptype = getFPType(queried)
        except ValueError:
            fptype = None

        patching = PatchingConditionsChain(
            "sklearn.neighbors.KNeighborsMixin.kneighbors"
        )
        use_daal = patching.and_conditions(
            [
                (trained_model is not None, "oneDAL model was not trained."),
                (fptype is not None, "Unable to get dtype."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            ]
        )
        patching.write_log()

        if use_daal:
            return daal4py_kneighbors(self, X, n_neighbors, return_distance)

        # daal4py_fit leaves _tree as None, so scikit-learn's structures must
        # be built before delegating to the stock implementation.
        needs_stock_index = trained_model is not None or (
            getattr(self, "_tree", 0) is None and self._fit_method == "kd_tree"
        )
        if needs_stock_index:
            BaseNeighborsBase._fit(self, self._fit_X, getattr(self, "_y", None))
        return super().kneighbors(X, n_neighbors, return_distance)
485
+
486
+
487
class RadiusNeighborsMixin(BaseRadiusNeighborsMixin):
    """``radius_neighbors`` that ensures scikit-learn's index exists first."""

    def radius_neighbors(
        self, X=None, radius=None, return_distance=True, sort_results=False
    ):
        """Delegate to scikit-learn's ``radius_neighbors``.

        There is no daal4py radius search here. If daal4py trained the model,
        or a k-d tree fit left no tree behind, scikit-learn's
        ``NeighborsBase._fit`` is run on the stored training data first.
        """
        needs_rebuild = getattr(self, "_daal_model", None) is not None or (
            getattr(self, "_tree", 0) is None and self._fit_method == "kd_tree"
        )
        if needs_rebuild:
            BaseNeighborsBase._fit(self, self._fit_X, getattr(self, "_y", None))
        return BaseRadiusNeighborsMixin.radius_neighbors(
            self, X, radius, return_distance, sort_results
        )
@@ -0,0 +1,139 @@
1
+ # ==============================================================================
2
+ # Copyright 2020 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ # daal4py KNN classification scikit-learn-compatible classes
18
+
19
+ import numpy as np
20
+ from scipy import sparse as sp
21
+ from sklearn.base import ClassifierMixin as BaseClassifierMixin
22
+ from sklearn.neighbors._classification import (
23
+ KNeighborsClassifier as BaseKNeighborsClassifier,
24
+ )
25
+ from sklearn.utils.validation import check_array
26
+
27
+ from .._utils import PatchingConditionsChain, getFPType, sklearn_check_version
28
+ from ._base import KNeighborsMixin, NeighborsBase, parse_auto_method, prediction_algorithm
29
+
30
+ if not sklearn_check_version("1.2"):
31
+ from sklearn.neighbors._base import _check_weights
32
+
33
+ from sklearn.utils.validation import _deprecate_positional_args
34
+
35
+
36
def daal4py_classifier_predict(estimator, X, base_predict):
    """Predict class labels with oneDAL when possible, else with ``base_predict``.

    Parameters
    ----------
    estimator : KNeighborsClassifier
        Fitted estimator. Reads ``_daal_model`` (``None``/absent when no oneDAL
        model was trained), ``n_features_in_``, ``classes_``, ``n_neighbors``,
        ``weights``, ``algorithm`` and ``n_samples_fit_``.
    X : array-like or sparse matrix
        Query samples. Validated by ``check_array`` (CSR allowed) with dtype
        float64 or float32.
    base_predict : callable
        Fallback, called as ``base_predict(estimator, X)`` whenever the oneDAL
        path cannot be used.

    Returns
    -------
    ndarray
        Predicted labels, taken from ``estimator.classes_``.

    Raises
    ------
    ValueError
        If the number of columns of ``X`` differs from ``n_features_in_``.
    """
    if sklearn_check_version("1.0"):
        estimator._check_feature_names(X, reset=False)
    X = check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
    daal_model = getattr(estimator, "_daal_model", None)
    n_features = getattr(estimator, "n_features_in_", None)
    shape = getattr(X, "shape", None)
    if n_features and shape and len(shape) > 1 and shape[1] != n_features:
        raise ValueError(
            f"X has {X.shape[1]} features, "
            f"but KNNClassifier is expecting "
            f"{n_features} features as input"
        )

    try:
        fptype = getFPType(X)
    except ValueError:
        # dtype could not be mapped; this disables the oneDAL path below.
        fptype = None

    # The oneDAL parameters below only offer "voteUniform" and "voteDistance".
    # scikit-learn treats weights=None as uniform. Callables (e.g. set via
    # set_params after fit) or unknown values must go through base_predict so
    # they are honoured (or rejected) instead of silently becoming distance
    # voting, which would also make predict disagree with predict_proba.
    weights = getattr(estimator, "weights", "uniform")
    weights_supported = weights is None or (
        isinstance(weights, str) and weights in ("uniform", "distance")
    )

    _patching_status = PatchingConditionsChain(
        "sklearn.neighbors.KNeighborsClassifier.predict"
    )
    _dal_ready = _patching_status.and_conditions(
        [
            (daal_model is not None, "oneDAL model was not trained."),
            (fptype is not None, "Unable to get dtype."),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            (
                weights_supported,
                "Callable or unrecognized weights are not supported.",
            ),
        ]
    )
    _patching_status.write_log()

    if not _dal_ready:
        return base_predict(estimator, X)

    params = {
        "method": "defaultDense",
        "k": estimator.n_neighbors,
        "nClasses": len(estimator.classes_),
        # weights is None, "uniform" or "distance" here; None means uniform.
        "voteWeights": "voteDistance" if weights == "distance" else "voteUniform",
        "resultsToEvaluate": "computeClassLabels",
        "resultsToCompute": "",
    }

    method = parse_auto_method(
        estimator, estimator.algorithm, estimator.n_samples_fit_, n_features
    )
    predict_alg = prediction_algorithm(method, fptype, params)
    prediction_result = predict_alg.compute(X, daal_model)
    # oneDAL returns class indices; map them back to the original labels.
    return estimator.classes_.take(
        np.asarray(prediction_result.prediction.ravel(), dtype=np.intp)
    )
93
+
94
+
95
class KNeighborsClassifier(KNeighborsMixin, BaseClassifierMixin, NeighborsBase):
    # The public docstring is borrowed verbatim from scikit-learn's estimator.
    __doc__ = BaseKNeighborsClassifier.__doc__

    @_deprecate_positional_args
    def __init__(
        self,
        n_neighbors=5,
        *,
        weights="uniform",
        algorithm="auto",
        leaf_size=30,
        p=2,
        metric="minkowski",
        metric_params=None,
        n_jobs=None,
        **kwargs,
    ):
        super().__init__(
            n_neighbors=n_neighbors,
            algorithm=algorithm,
            leaf_size=leaf_size,
            p=p,
            metric=metric,
            metric_params=metric_params,
            n_jobs=n_jobs,
            **kwargs,
        )
        # With scikit-learn >= 1.0 `weights` is stored unchanged; older
        # releases pass it through `_check_weights` first.
        if sklearn_check_version("1.0"):
            self.weights = weights
        else:
            self.weights = _check_weights(weights)

    def fit(self, X, y):
        # Training is handled entirely by the shared NeighborsBase._fit.
        return NeighborsBase._fit(self, X, y)

    def predict(self, X):
        # oneDAL-accelerated prediction with scikit-learn's predict as fallback.
        return daal4py_classifier_predict(
            self, X, base_predict=BaseKNeighborsClassifier.predict
        )

    def predict_proba(self, X):
        # Always scikit-learn's implementation; feature names are checked
        # explicitly on scikit-learn >= 1.0.
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        return BaseKNeighborsClassifier.predict_proba(self, X)

    # Method docstrings mirror the corresponding scikit-learn methods.
    fit.__doc__ = BaseKNeighborsClassifier.fit.__doc__
    predict.__doc__ = BaseKNeighborsClassifier.predict.__doc__
    predict_proba.__doc__ = BaseKNeighborsClassifier.predict_proba.__doc__