scikit-learn-intelex 2024.0.1__py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (90)
  1. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/__init__.py +61 -0
  2. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/__main__.py +59 -0
  3. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_config.py +110 -0
  4. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_device_offload.py +223 -0
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  7. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +17 -0
  8. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +21 -0
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +18 -0
  11. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +37 -0
  12. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +31 -0
  13. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +20 -0
  14. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +18 -0
  15. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +28 -0
  16. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/dispatcher.py +329 -0
  17. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +424 -0
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +30 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  20. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  21. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/glob/__main__.py +73 -0
  22. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +88 -0
  23. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +30 -0
  24. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +18 -0
  25. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +373 -0
  26. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +18 -0
  27. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +18 -0
  28. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +77 -0
  29. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +29 -0
  30. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +20 -0
  31. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +18 -0
  32. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +27 -0
  33. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +24 -0
  34. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +18 -0
  35. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +18 -0
  36. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +40 -0
  37. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +22 -0
  38. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/split.py +18 -0
  39. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +35 -0
  40. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +28 -0
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/common.py +264 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  43. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  44. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +220 -0
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +437 -0
  46. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  47. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +18 -0
  48. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +20 -0
  49. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +84 -0
  50. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +370 -0
  51. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +20 -0
  52. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +376 -0
  53. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +38 -0
  54. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +24 -0
  55. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +19 -0
  56. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
  58. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  59. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
  60. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +19 -0
  61. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +21 -0
  62. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
  63. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +79 -0
  64. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +19 -0
  65. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +21 -0
  66. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +19 -0
  67. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +25 -0
  68. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/__init__.py +30 -0
  69. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/_common.py +188 -0
  70. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +272 -0
  71. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +163 -0
  72. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/svc.py +301 -0
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/svr.py +164 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  75. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  76. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_config.py +39 -0
  77. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +225 -0
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +210 -0
  79. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +122 -0
  81. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  82. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +118 -0
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
  84. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  85. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/validation.py +18 -0
  86. scikit_learn_intelex-2024.0.1.dist-info/LICENSE.txt +202 -0
  87. scikit_learn_intelex-2024.0.1.dist-info/METADATA +230 -0
  88. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  89. scikit_learn_intelex-2024.0.1.dist-info/WHEEL +5 -0
  90. scikit_learn_intelex-2024.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,40 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ import numpy as np
19
+ from numpy.testing import assert_allclose
20
+ from sklearn.datasets import load_breast_cancer
21
+
22
+
def test_sklearnex_import_roc_auc():
    """Score a sklearnex LogisticRegression with the sklearnex ROC AUC metric.

    Both names are imported from sklearnex inside the test, so the test
    exercises the sklearnex import paths; the training-set AUC on the
    breast-cancer data must be within 1e-2 of 0.99.
    """
    from sklearnex.linear_model import LogisticRegression
    from sklearnex.metrics import roc_auc_score

    features, labels = load_breast_cancer(return_X_y=True)
    model = LogisticRegression(solver="liblinear", random_state=0)
    model.fit(features, labels)
    decision_scores = model.decision_function(features)
    auc = roc_auc_score(labels, decision_scores)
    assert_allclose(auc, 0.99, atol=1e-2)
31
+
32
+
def test_sklearnex_import_pairwise_distances():
    """Cosine distances between two identical rows must all be ~0.

    ``pairwise_distances`` is imported from sklearnex inside the test; the
    input is one random non-negative row stacked on top of itself.
    """
    from sklearnex.metrics import pairwise_distances

    random_state = np.random.RandomState(0)
    row = np.abs(random_state.rand(4), dtype=np.float64)
    duplicated = np.vstack([row, row])

    distances = pairwise_distances(duplicated, metric="cosine")
    expected = [[0.0, 0.0], [0.0, 0.0]]
    assert_allclose(distances, expected, atol=1e-2)
@@ -0,0 +1,22 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ from .split import train_test_split
19
+
# Names exported by ``from sklearnex.model_selection import *``; the
# implementation is imported from the ``split`` submodule.
__all__ = [
    "train_test_split",
]
@@ -0,0 +1,18 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ from daal4py.sklearn.model_selection import train_test_split
@@ -0,0 +1,35 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ import numpy as np
19
+ from numpy.testing import assert_allclose
20
+
21
+
# TODO:
# add pytest params for checking different dataframe inputs/outputs.
def test_sklearnex_import_train_test_split():
    """Split a 10x10 grid in half and check rows stay paired with labels.

    Row ``i`` of ``X`` starts with ``10 * i`` and carries label ``i``, so in
    both the train and the test split the first feature column must equal
    ten times the labels.
    """
    from sklearnex.model_selection import train_test_split

    n_rows = 10
    X = np.arange(n_rows * n_rows).reshape((n_rows, n_rows))
    y = np.arange(n_rows)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=None, train_size=0.5
    )
    # train_size=0.5 of 10 rows -> equally sized halves
    assert len(y_test) == len(y_train)

    for features, labels in ((X_train, y_train), (X_test, y_test)):
        assert_allclose(features[:, 0], labels * 10)
@@ -0,0 +1,28 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ from .knn_classification import KNeighborsClassifier
19
+ from .knn_regression import KNeighborsRegressor
20
+ from .knn_unsupervised import NearestNeighbors
21
+ from .lof import LocalOutlierFactor
22
+
# Names exported by ``from sklearnex.neighbors import *``; each one is
# imported above from its own submodule (knn_classification,
# knn_regression, knn_unsupervised, lof).
__all__ = [
    "KNeighborsClassifier",
    "KNeighborsRegressor",
    "LocalOutlierFactor",
    "NearestNeighbors",
]
@@ -0,0 +1,264 @@
1
+ #!/usr/bin/env python
2
+ # ==============================================================================
3
+ # Copyright 2023 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ==============================================================================
17
+
18
+ import warnings
19
+
20
+ import numpy as np
21
+ from scipy import sparse as sp
22
+ from sklearn.neighbors._ball_tree import BallTree
23
+ from sklearn.neighbors._base import VALID_METRICS
24
+ from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
25
+ from sklearn.neighbors._kd_tree import KDTree
26
+
27
+ from daal4py.sklearn._utils import sklearn_check_version
28
+ from onedal.utils import _check_array, _num_features, _num_samples
29
+
30
+ from .._utils import PatchingConditionsChain
31
+
32
+
class KNeighborsDispatchingBase:
    """Shared fit validation and oneDAL dispatch checks for neighbors estimators.

    Subclasses are expected to provide the scikit-learn neighbors
    hyperparameters read here (``n_neighbors``, ``algorithm``, ``metric``,
    ``metric_params``, ``p``, ``leaf_size`` and, for supervised estimators,
    ``weights``), plus the scikit-learn ``BaseEstimator`` helpers
    ``_validate_params`` / ``_check_feature_names``.
    """

    def _fit_validation(self, X, y=None):
        """Validate fit inputs and resolve the effective metric and fit method.

        Sets ``effective_metric_``, ``effective_metric_params_``, ``_fit_X``,
        ``_fit_method``, ``n_samples_fit_`` and ``n_features_in_`` (plus
        ``_tree`` when ``X`` is a tree or a fitted neighbors estimator) and
        removes any previously stored ``_onedal_estimator``.

        Parameters
        ----------
        X : array-like, sparse matrix, KDTree, BallTree or NeighborsBase
            Training data, or an already-built tree / fitted neighbors
            estimator whose data is reused.
        y : ignored
            Accepted for signature compatibility; not read here.

        Raises
        ------
        ValueError
            If ``X`` is a neighbors estimator holding a ``_onedal_estimator``
            and its fit method is not ``ball_tree``, ``kd_tree`` or ``brute``.
        """
        if sklearn_check_version("1.2"):
            self._validate_params()
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=True)

        # An explicit "p" in metric_params takes precedence over the ``p``
        # constructor argument (with a warning when both are given).
        if self.metric_params is not None and "p" in self.metric_params:
            if self.p is not None:
                warnings.warn(
                    "Parameter p is found in metric_params. "
                    "The corresponding parameter from __init__ "
                    "is ignored.",
                    SyntaxWarning,
                    stacklevel=2,
                )
            self.effective_metric_params_ = self.metric_params.copy()
            effective_p = self.metric_params["p"]
        else:
            self.effective_metric_params_ = {}
            effective_p = self.p

        self.effective_metric_params_["p"] = effective_p
        self.effective_metric_ = self.metric
        # For minkowski distance, use more efficient methods where available
        if self.metric == "minkowski":
            p = self.effective_metric_params_["p"]
            if p == 1:
                self.effective_metric_ = "manhattan"
            elif p == 2:
                self.effective_metric_ = "euclidean"
            elif p == np.inf:
                self.effective_metric_ = "chebyshev"

        if not isinstance(X, (KDTree, BallTree, sklearn_NeighborsBase)):
            self._fit_X = _check_array(
                X, dtype=[np.float64, np.float32], accept_sparse=True
            )
            self.n_samples_fit_ = _num_samples(self._fit_X)
            self.n_features_in_ = _num_features(self._fit_X)

            if self.algorithm == "auto":
                # A tree approach is better for small number of neighbors or small
                # number of features, with KDTree generally faster when available
                is_n_neighbors_valid_for_brute = (
                    self.n_neighbors is not None
                    and self.n_neighbors >= self._fit_X.shape[0] // 2
                )
                if self._fit_X.shape[1] > 15 or is_n_neighbors_valid_for_brute:
                    self._fit_method = "brute"
                else:
                    if self.effective_metric_ in VALID_METRICS["kd_tree"]:
                        self._fit_method = "kd_tree"
                    elif (
                        callable(self.effective_metric_)
                        or self.effective_metric_ in VALID_METRICS["ball_tree"]
                    ):
                        self._fit_method = "ball_tree"
                    else:
                        self._fit_method = "brute"
            else:
                self._fit_method = self.algorithm

        # Refitting invalidates any previously trained oneDAL model.
        if hasattr(self, "_onedal_estimator"):
            delattr(self, "_onedal_estimator")
        # To cover test case when we pass patched
        # estimator as an input for other estimator
        if isinstance(X, sklearn_NeighborsBase):
            self._fit_X = X._fit_X
            self._tree = X._tree
            self._fit_method = X._fit_method
            self.n_samples_fit_ = X.n_samples_fit_
            self.n_features_in_ = X.n_features_in_
            if hasattr(X, "_onedal_estimator"):
                # "p" is removed before the params are forwarded as keyword
                # arguments to the tree constructors below.
                self.effective_metric_params_.pop("p")
                # NOTE(review): the tree is attached to X, not self -
                # presumably so a later scikit-learn fit that copies
                # X._tree picks it up; confirm against the callers.
                if self._fit_method == "ball_tree":
                    X._tree = BallTree(
                        X._fit_X,
                        self.leaf_size,
                        metric=self.effective_metric_,
                        **self.effective_metric_params_,
                    )
                elif self._fit_method == "kd_tree":
                    X._tree = KDTree(
                        X._fit_X,
                        self.leaf_size,
                        metric=self.effective_metric_,
                        **self.effective_metric_params_,
                    )
                elif self._fit_method == "brute":
                    X._tree = None
                else:
                    raise ValueError("algorithm = '%s' not recognized" % self.algorithm)

        elif isinstance(X, BallTree):
            self._fit_X = X.data
            self._tree = X
            self._fit_method = "ball_tree"
            self.n_samples_fit_ = X.data.shape[0]
            self.n_features_in_ = X.data.shape[1]

        elif isinstance(X, KDTree):
            self._fit_X = X.data
            self._tree = X
            self._fit_method = "kd_tree"
            self.n_samples_fit_ = X.data.shape[0]
            self.n_features_in_ = X.data.shape[1]

    def _onedal_supported(self, device, method_name, *data):
        """Decide whether ``method_name`` can be dispatched to oneDAL on ``device``.

        Conditions are recorded on a ``PatchingConditionsChain`` named
        ``sklearn.neighbors.<ClassName>.<method_name>``; checking stops at the
        first failing condition and the chain is returned.

        Requirements checked, in order: ``data[0]`` is not a KDTree, BallTree
        or neighbors estimator; the effective ``p`` (if present) is not below
        1; ``data[0]`` is not sparse; ``kd_tree`` is not requested on GPU; the
        (method, metric) pair is supported by oneDAL; supervised estimators
        have a single-output target and ``weights`` of ``"uniform"`` or
        ``"distance"``; ``fit`` on a classifier sees at least two classes;
        ``predict`` / ``predict_proba`` / ``kneighbors`` require a trained
        ``_onedal_estimator``.

        Side effect: once the metric checks are reached, the aliases
        ``l1``/``cityblock``/``l2`` in ``effective_metric_`` are rewritten to
        ``manhattan``/``euclidean`` and ``effective_metric_params_["p"]`` is
        set to 1 or 2 for those metrics.

        Parameters
        ----------
        device : str
            ``"cpu"`` or ``"gpu"`` (only ``"gpu"`` is special-cased here).
        method_name : str
            Name of the estimator method being dispatched.
        *data
            Positional method arguments; ``data[0]`` is the input ``X`` and,
            when given, ``data[1]`` is used as the target ``y``.

        Returns
        -------
        PatchingConditionsChain

        Raises
        ------
        RuntimeError
            If ``method_name`` is not one of ``fit``, ``predict``,
            ``predict_proba`` or ``kneighbors``.
        """
        class_name = self.__class__.__name__
        # The estimator kind is inferred from the class name.
        is_classifier = "Classifier" in class_name
        is_regressor = "Regressor" in class_name
        is_unsupervised = not (is_classifier or is_regressor)
        patching_status = PatchingConditionsChain(
            f"sklearn.neighbors.{class_name}.{method_name}"
        )

        if not patching_status.and_condition(
            not isinstance(data[0], (KDTree, BallTree, sklearn_NeighborsBase)),
            f"Input type {type(data[0])} is not supported.",
        ):
            return patching_status

        # Method oneDAL would actually run: "auto"/"ball_tree" map to brute
        # force or, for euclidean on small problems, to kd_tree.
        if self._fit_method in ["auto", "ball_tree"]:
            condition = (
                self.n_neighbors is not None
                and self.n_neighbors >= self.n_samples_fit_ // 2
            )
            if self.n_features_in_ > 15 or condition:
                result_method = "brute"
            else:
                if self.effective_metric_ in ["euclidean"]:
                    result_method = "kd_tree"
                else:
                    result_method = "brute"
        else:
            result_method = self._fit_method

        # "p" may be absent, e.g. after _fit_validation popped it.
        p_less_than_one = (
            "p" in self.effective_metric_params_
            and self.effective_metric_params_["p"] < 1
        )
        if not patching_status.and_condition(
            not p_less_than_one, '"p" metric parameter is less than 1'
        ):
            return patching_status

        # sp.issparse (unlike sp.isspmatrix) also recognizes SciPy sparse
        # arrays such as csr_array, which would otherwise slip through to
        # oneDAL even though sparse input is unsupported.
        if not patching_status.and_condition(
            not sp.issparse(data[0]), "Sparse input is not supported."
        ):
            return patching_status

        if not is_unsupervised:
            is_valid_weights = self.weights in ["uniform", "distance"]
            if is_classifier:
                class_count = 1
            is_single_output = False
            y = None
            # To check multioutput, might be overhead
            if len(data) > 1:
                y = np.asarray(data[1])
                if is_classifier:
                    class_count = len(np.unique(y))
            # A trained oneDAL model's stored targets take precedence.
            if hasattr(self, "_onedal_estimator"):
                y = self._onedal_estimator._y
            if y is not None and hasattr(y, "ndim") and hasattr(y, "shape"):
                is_single_output = y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1

        # TODO: add native support for these metric names
        metrics_map = {"manhattan": ["l1", "cityblock"], "euclidean": ["l2"]}
        for origin, aliases in metrics_map.items():
            if self.effective_metric_ in aliases:
                self.effective_metric_ = origin
                break
        if self.effective_metric_ == "manhattan":
            self.effective_metric_params_["p"] = 1
        elif self.effective_metric_ == "euclidean":
            self.effective_metric_params_["p"] = 2

        onedal_brute_metrics = [
            "manhattan",
            "minkowski",
            "euclidean",
            "chebyshev",
            "cosine",
        ]
        onedal_kdtree_metrics = ["euclidean"]
        is_valid_for_brute = (
            result_method == "brute" and self.effective_metric_ in onedal_brute_metrics
        )
        is_valid_for_kd_tree = (
            result_method == "kd_tree" and self.effective_metric_ in onedal_kdtree_metrics
        )
        if result_method == "kd_tree":
            if not patching_status.and_condition(
                device != "gpu", '"kd_tree" method is not supported on GPU.'
            ):
                return patching_status

        if not patching_status.and_condition(
            is_valid_for_kd_tree or is_valid_for_brute,
            f"{result_method} with {self.effective_metric_} metric is not supported.",
        ):
            return patching_status
        if not is_unsupervised:
            if not patching_status.and_conditions(
                [
                    (is_single_output, "Only single output is supported."),
                    (
                        is_valid_weights,
                        f'"{type(self.weights)}" weights type is not supported.',
                    ),
                ]
            ):
                return patching_status
        if method_name == "fit":
            if is_classifier:
                patching_status.and_condition(
                    class_count >= 2, "One-class case is not supported."
                )
            return patching_status
        if method_name in ["predict", "predict_proba", "kneighbors"]:
            patching_status.and_condition(
                hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."
            )
            return patching_status
        raise RuntimeError(f"Unknown method {method_name} in {class_name}")

    def _onedal_gpu_supported(self, method_name, *data):
        """Return the oneDAL support decision for ``method_name`` on GPU."""
        return self._onedal_supported("gpu", method_name, *data)

    def _onedal_cpu_supported(self, method_name, *data):
        """Return the oneDAL support decision for ``method_name`` on CPU."""
        return self._onedal_supported("cpu", method_name, *data)