scikit-learn-intelex 2023.2.1__py311-none-win_amd64.whl → 2024.0.1__py311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (109)
  1. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +2 -2
  2. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +16 -12
  3. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +2 -2
  4. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +90 -56
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +3 -3
  7. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +2 -2
  8. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +4 -4
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +2 -2
  11. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +12 -6
  12. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +5 -5
  13. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +3 -3
  14. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +2 -2
  15. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +5 -4
  16. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +102 -72
  17. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +12 -4
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  20. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +31 -16
  21. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +21 -14
  22. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +10 -10
  23. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +2 -2
  24. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +173 -83
  25. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +3 -3
  26. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +2 -2
  27. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +23 -7
  28. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +4 -3
  29. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +3 -3
  30. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +2 -2
  31. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +4 -3
  32. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +5 -5
  33. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
  34. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
  35. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +8 -6
  36. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +3 -3
  37. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +2 -2
  38. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +6 -3
  39. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +9 -5
  40. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +100 -77
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  43. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +116 -58
  44. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +118 -56
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  46. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +18 -20
  47. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +3 -3
  48. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +7 -7
  49. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +104 -73
  50. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +4 -1
  51. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +128 -100
  52. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +18 -16
  53. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd}/__init__.py +24 -22
  54. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +3 -3
  55. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +2 -2
  56. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +11 -5
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  58. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +2 -2
  59. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +3 -3
  60. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +2 -2
  61. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +3 -3
  62. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +16 -14
  63. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -3
  64. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +2 -2
  65. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +3 -3
  66. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +3 -3
  67. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +11 -8
  68. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +56 -56
  69. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +110 -55
  70. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +65 -31
  71. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +136 -78
  72. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +65 -31
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  75. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +9 -8
  76. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +63 -69
  77. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +55 -53
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  79. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +8 -7
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  81. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +39 -39
  82. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -3
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  84. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +2 -2
  85. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
  86. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  87. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/_utils.py +0 -82
  88. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -18
  89. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
  90. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
  91. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -46
  92. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -228
  93. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -213
  94. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -57
  95. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -18
  96. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -28
  97. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py +0 -1261
  98. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1155
  99. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py +0 -67
  100. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
  101. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -23
  102. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -63
  103. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -159
  104. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -383
  105. scikit_learn_intelex-2023.2.1.dist-info/RECORD +0 -95
  106. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  107. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
  108. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
  109. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,331 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ import warnings
19
+
20
+ from sklearn.neighbors._ball_tree import BallTree
21
+ from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
22
+ from sklearn.neighbors._kd_tree import KDTree
23
+
24
+ from daal4py.sklearn._utils import sklearn_check_version
25
+
26
+ if not sklearn_check_version("1.2"):
27
+ from sklearn.neighbors._base import _check_weights
28
+
29
+ import numpy as np
30
+ from sklearn.neighbors._base import VALID_METRICS
31
+ from sklearn.neighbors._classification import (
32
+ KNeighborsClassifier as sklearn_KNeighborsClassifier,
33
+ )
34
+ from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestNeighbors
35
+ from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
36
+
37
+ from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier
38
+ from onedal.utils import _check_array, _num_features, _num_samples
39
+
40
+ from .._device_offload import dispatch, wrap_output_data
41
+ from .common import KNeighborsDispatchingBase
42
+
43
# Intermediate base class ``KNeighborsClassifier_``.
#
# It wraps the stock scikit-learn classifier and normalises how ``weights`` is
# stored across scikit-learn versions. Comments are used instead of docstrings
# so that ``inspect.getdoc`` keeps falling back to the inherited scikit-learn
# documentation.
if sklearn_check_version("0.24"):

    class KNeighborsClassifier_(sklearn_KNeighborsClassifier):
        # scikit-learn >= 1.2: start from a shallow copy of the stock
        # parameter constraints.
        if sklearn_check_version("1.2"):
            _parameter_constraints: dict = dict(
                sklearn_KNeighborsClassifier._parameter_constraints
            )

        @_deprecate_positional_args
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
            **kwargs,
        ):
            # ``weights`` is not forwarded to the parent constructor; it is
            # assigned explicitly right after the parent has been initialised.
            super().__init__(
                n_neighbors=n_neighbors,
                algorithm=algorithm,
                leaf_size=leaf_size,
                p=p,
                metric=metric,
                metric_params=metric_params,
                n_jobs=n_jobs,
                **kwargs,
            )
            if sklearn_check_version("1.0"):
                # From 1.0 on the raw value is stored as given.
                self.weights = weights
            else:
                # Older releases run the value through ``_check_weights``.
                self.weights = _check_weights(weights)

else:
    # Before 0.24 the supervised mixin is added as an explicit base; only its
    # import location depends on the scikit-learn version.
    if sklearn_check_version("0.22"):
        from sklearn.neighbors._base import (
            SupervisedIntegerMixin as BaseSupervisedIntegerMixin,
        )
    else:
        from sklearn.neighbors.base import (
            SupervisedIntegerMixin as BaseSupervisedIntegerMixin,
        )

    class KNeighborsClassifier_(sklearn_KNeighborsClassifier, BaseSupervisedIntegerMixin):
        @_deprecate_positional_args
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
            **kwargs,
        ):
            # ``weights`` is not forwarded to the parent constructor; the
            # checked value is assigned once the parent has been initialised.
            super().__init__(
                n_neighbors=n_neighbors,
                algorithm=algorithm,
                leaf_size=leaf_size,
                p=p,
                metric=metric,
                metric_params=metric_params,
                n_jobs=n_jobs,
                **kwargs,
            )
            self.weights = _check_weights(weights)
142
+
143
+
144
class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
    # k-nearest-neighbours classifier whose public methods are routed through
    # ``dispatch`` with two candidate implementations: the oneDAL-backed
    # ``_onedal_*`` methods of this class and the stock scikit-learn methods.
    # Which branch actually runs is decided inside ``dispatch``.
    #
    # Public methods are documented with comments rather than docstrings so
    # that ``inspect.getdoc``/``help`` keep resolving to the inherited
    # scikit-learn documentation.

    if sklearn_check_version("1.2"):
        # Shallow copy so that later changes do not touch the base class dict.
        _parameter_constraints: dict = {**KNeighborsClassifier_._parameter_constraints}

    if sklearn_check_version("1.0"):

        # scikit-learn >= 1.0: fixed keyword-only signature, no ``**kwargs``.
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
        ):
            super().__init__(
                n_neighbors=n_neighbors,
                weights=weights,
                algorithm=algorithm,
                leaf_size=leaf_size,
                metric=metric,
                p=p,
                metric_params=metric_params,
                n_jobs=n_jobs,
            )

    else:

        # scikit-learn < 1.0: extra keyword arguments are passed through to
        # the base class constructor.
        @_deprecate_positional_args
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
            **kwargs,
        ):
            super().__init__(
                n_neighbors=n_neighbors,
                weights=weights,
                algorithm=algorithm,
                leaf_size=leaf_size,
                metric=metric,
                p=p,
                metric_params=metric_params,
                n_jobs=n_jobs,
                **kwargs,
            )

    # Validate the inputs via ``_fit_validation``, then fit through
    # ``dispatch``. Returns ``self``.
    def fit(self, X, y):
        self._fit_validation(X, y)
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_KNeighborsClassifier.fit,
            },
            X,
            y,
        )
        return self

    # Predict for ``X``; requires a fitted estimator. Returns whatever the
    # dispatched implementation returns.
    @wrap_output_data
    def predict(self, X):
        check_is_fitted(self)
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_KNeighborsClassifier.predict,
            },
            X,
        )

    # Class-probability prediction for ``X``; requires a fitted estimator.
    @wrap_output_data
    def predict_proba(self, X):
        check_is_fitted(self)
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        return dispatch(
            self,
            "predict_proba",
            {
                "onedal": self.__class__._onedal_predict_proba,
                "sklearn": sklearn_KNeighborsClassifier.predict_proba,
            },
            X,
        )

    # Neighbour query; requires a fitted estimator. ``X``, ``n_neighbors``
    # and ``return_distance`` are forwarded positionally to the dispatched
    # implementation.
    @wrap_output_data
    def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
        check_is_fitted(self)
        # ``X=None`` means "query the training data". There is nothing to
        # compare feature names against in that case, and running the check
        # would emit a spurious "X does not have valid feature names" warning
        # for estimators that were fitted with feature names.
        if sklearn_check_version("1.0") and X is not None:
            self._check_feature_names(X, reset=False)
        return dispatch(
            self,
            "kneighbors",
            {
                "onedal": self.__class__._onedal_kneighbors,
                "sklearn": sklearn_KNeighborsClassifier.kneighbors,
            },
            X,
            n_neighbors,
            return_distance,
        )

    # Radius query. This is always answered by scikit-learn's
    # ``NearestNeighbors.radius_neighbors``; it is not sent through
    # ``dispatch``.
    @wrap_output_data
    def radius_neighbors(
        self, X=None, radius=None, return_distance=True, sort_results=False
    ):
        _onedal_estimator = getattr(self, "_onedal_estimator", None)

        # Re-fit with scikit-learn on the stored training data when a oneDAL
        # estimator is attached, or when ``_tree`` is None although the fit
        # method is "kd_tree". The default ``0`` makes a missing ``_tree``
        # attribute count as "not None". Parentheses only spell out the
        # existing ``or``/``and`` precedence.
        # NOTE(review): presumably the oneDAL fit does not build the
        # structures the scikit-learn radius query needs - confirm.
        if _onedal_estimator is not None or (
            getattr(self, "_tree", 0) is None and self._fit_method == "kd_tree"
        ):
            if sklearn_check_version("0.24"):
                sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self, "_y", None))
            else:
                sklearn_NearestNeighbors.fit(self, self._fit_X)
        if sklearn_check_version("0.22"):
            result = sklearn_NearestNeighbors.radius_neighbors(
                self, X, radius, return_distance, sort_results
            )
        else:
            # ``sort_results`` is not passed on for scikit-learn < 0.22.
            result = sklearn_NearestNeighbors.radius_neighbors(
                self, X, radius, return_distance
            )

        return result

    def _onedal_fit(self, X, y, queue=None):
        """Fit a oneDAL classifier and copy its fitted state onto ``self``.

        Reads ``effective_metric_`` and ``effective_metric_params_["p"]``,
        so both must already be set on ``self`` when this is called.
        ``queue`` is forwarded unchanged to the oneDAL estimator's ``fit``.
        """
        onedal_params = {
            "n_neighbors": self.n_neighbors,
            "weights": self.weights,
            "algorithm": self.algorithm,
            "metric": self.effective_metric_,
            "p": self.effective_metric_params_["p"],
        }

        # Fall back to False when the estimator tags carry no "requires_y".
        try:
            requires_y = self._get_tags()["requires_y"]
        except KeyError:
            requires_y = False

        self._onedal_estimator = onedal_KNeighborsClassifier(**onedal_params)
        self._onedal_estimator.requires_y = requires_y
        self._onedal_estimator.effective_metric_ = self.effective_metric_
        self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
        self._onedal_estimator.fit(X, y, queue=queue)

        self._save_attributes()

    def _onedal_predict(self, X, queue=None):
        """Return the oneDAL estimator's ``predict`` result for ``X``."""
        return self._onedal_estimator.predict(X, queue=queue)

    def _onedal_predict_proba(self, X, queue=None):
        """Return the oneDAL estimator's ``predict_proba`` result for ``X``."""
        return self._onedal_estimator.predict_proba(X, queue=queue)

    def _onedal_kneighbors(
        self, X=None, n_neighbors=None, return_distance=True, queue=None
    ):
        """Return the oneDAL estimator's ``kneighbors`` result."""
        return self._onedal_estimator.kneighbors(
            X, n_neighbors, return_distance, queue=queue
        )

    def _save_attributes(self):
        """Copy the fitted attributes from the oneDAL estimator onto ``self``."""
        self.classes_ = self._onedal_estimator.classes_
        self.n_features_in_ = self._onedal_estimator.n_features_in_
        self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_
        self._fit_X = self._onedal_estimator._fit_X
        self._y = self._onedal_estimator._y
        self._fit_method = self._onedal_estimator._fit_method
        self.outputs_2d_ = self._onedal_estimator.outputs_2d_
        self._tree = self._onedal_estimator._tree
@@ -0,0 +1,307 @@
1
+ #!/usr/bin/env python
2
+ # ==============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ==============================================================================
17
+
18
+ import warnings
19
+
20
+ from sklearn.neighbors._ball_tree import BallTree
21
+ from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
22
+ from sklearn.neighbors._kd_tree import KDTree
23
+
24
+ from daal4py.sklearn._utils import sklearn_check_version
25
+
26
+ if not sklearn_check_version("1.2"):
27
+ from sklearn.neighbors._base import _check_weights
28
+
29
+ import numpy as np
30
+ from sklearn.neighbors._base import VALID_METRICS
31
+ from sklearn.neighbors._regression import (
32
+ KNeighborsRegressor as sklearn_KNeighborsRegressor,
33
+ )
34
+ from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestNeighbors
35
+ from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
36
+
37
+ from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor
38
+ from onedal.utils import _check_array, _num_features, _num_samples
39
+
40
+ from .._device_offload import dispatch, wrap_output_data
41
+ from .common import KNeighborsDispatchingBase
42
+
43
# Intermediate base class ``KNeighborsRegressor_``.
#
# It wraps the stock scikit-learn regressor and normalises how ``weights`` is
# stored across scikit-learn versions. Comments are used instead of docstrings
# so that ``inspect.getdoc`` keeps falling back to the inherited scikit-learn
# documentation.
if sklearn_check_version("0.24"):

    class KNeighborsRegressor_(sklearn_KNeighborsRegressor):
        # scikit-learn >= 1.2: start from a shallow copy of the stock
        # parameter constraints.
        if sklearn_check_version("1.2"):
            _parameter_constraints: dict = dict(
                sklearn_KNeighborsRegressor._parameter_constraints
            )

        @_deprecate_positional_args
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
            **kwargs,
        ):
            # ``weights`` is not forwarded to the parent constructor; it is
            # assigned explicitly right after the parent has been initialised.
            super().__init__(
                n_neighbors=n_neighbors,
                algorithm=algorithm,
                leaf_size=leaf_size,
                p=p,
                metric=metric,
                metric_params=metric_params,
                n_jobs=n_jobs,
                **kwargs,
            )
            if sklearn_check_version("1.0"):
                # From 1.0 on the raw value is stored as given.
                self.weights = weights
            else:
                # Older releases run the value through ``_check_weights``.
                self.weights = _check_weights(weights)

else:
    # Before 0.24 the supervised mixin is added as an explicit base; only its
    # import location depends on the scikit-learn version.
    if sklearn_check_version("0.22"):
        from sklearn.neighbors._base import SupervisedFloatMixin as BaseSupervisedFloatMixin
    else:
        from sklearn.neighbors.base import SupervisedFloatMixin as BaseSupervisedFloatMixin

    class KNeighborsRegressor_(sklearn_KNeighborsRegressor, BaseSupervisedFloatMixin):
        @_deprecate_positional_args
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
            **kwargs,
        ):
            # ``weights`` is not forwarded to the parent constructor; the
            # checked value is assigned once the parent has been initialised.
            super().__init__(
                n_neighbors=n_neighbors,
                algorithm=algorithm,
                leaf_size=leaf_size,
                p=p,
                metric=metric,
                metric_params=metric_params,
                n_jobs=n_jobs,
                **kwargs,
            )
            self.weights = _check_weights(weights)
138
+
139
+
140
class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
    # k-nearest-neighbours regressor whose public methods are routed through
    # ``dispatch`` with two candidate implementations: the oneDAL-backed
    # ``_onedal_*`` methods of this class and the stock scikit-learn methods.
    # Which branch actually runs is decided inside ``dispatch``.
    #
    # Public methods are documented with comments rather than docstrings so
    # that ``inspect.getdoc``/``help`` keep resolving to the inherited
    # scikit-learn documentation.

    if sklearn_check_version("1.2"):
        # Shallow copy so that later changes do not touch the base class dict.
        _parameter_constraints: dict = {**KNeighborsRegressor_._parameter_constraints}

    if sklearn_check_version("1.0"):

        # scikit-learn >= 1.0: fixed keyword-only signature, no ``**kwargs``.
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
        ):
            super().__init__(
                n_neighbors=n_neighbors,
                weights=weights,
                algorithm=algorithm,
                leaf_size=leaf_size,
                metric=metric,
                p=p,
                metric_params=metric_params,
                n_jobs=n_jobs,
            )

    else:

        # scikit-learn < 1.0: extra keyword arguments are passed through to
        # the base class constructor.
        @_deprecate_positional_args
        def __init__(
            self,
            n_neighbors=5,
            *,
            weights="uniform",
            algorithm="auto",
            leaf_size=30,
            p=2,
            metric="minkowski",
            metric_params=None,
            n_jobs=None,
            **kwargs,
        ):
            super().__init__(
                n_neighbors=n_neighbors,
                weights=weights,
                algorithm=algorithm,
                leaf_size=leaf_size,
                metric=metric,
                p=p,
                metric_params=metric_params,
                n_jobs=n_jobs,
                **kwargs,
            )

    # Validate the inputs via ``_fit_validation``, then fit through
    # ``dispatch``. Returns ``self``.
    def fit(self, X, y):
        self._fit_validation(X, y)
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_KNeighborsRegressor.fit,
            },
            X,
            y,
        )
        return self

    # Predict for ``X``; requires a fitted estimator. Returns whatever the
    # dispatched implementation returns.
    @wrap_output_data
    def predict(self, X):
        check_is_fitted(self)
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_KNeighborsRegressor.predict,
            },
            X,
        )

    # Neighbour query; requires a fitted estimator. ``X``, ``n_neighbors``
    # and ``return_distance`` are forwarded positionally to the dispatched
    # implementation.
    @wrap_output_data
    def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
        check_is_fitted(self)
        # ``X=None`` means "query the training data". There is nothing to
        # compare feature names against in that case, and running the check
        # would emit a spurious "X does not have valid feature names" warning
        # for estimators that were fitted with feature names.
        if sklearn_check_version("1.0") and X is not None:
            self._check_feature_names(X, reset=False)
        return dispatch(
            self,
            "kneighbors",
            {
                "onedal": self.__class__._onedal_kneighbors,
                "sklearn": sklearn_KNeighborsRegressor.kneighbors,
            },
            X,
            n_neighbors,
            return_distance,
        )

    # Radius query. This is always answered by scikit-learn's
    # ``NearestNeighbors.radius_neighbors``; it is not sent through
    # ``dispatch``.
    @wrap_output_data
    def radius_neighbors(
        self, X=None, radius=None, return_distance=True, sort_results=False
    ):
        _onedal_estimator = getattr(self, "_onedal_estimator", None)

        # Re-fit with scikit-learn on the stored training data when a oneDAL
        # estimator is attached, or when ``_tree`` is None although the fit
        # method is "kd_tree". The default ``0`` makes a missing ``_tree``
        # attribute count as "not None". Parentheses only spell out the
        # existing ``or``/``and`` precedence.
        # NOTE(review): presumably the oneDAL fit does not build the
        # structures the scikit-learn radius query needs - confirm.
        if _onedal_estimator is not None or (
            getattr(self, "_tree", 0) is None and self._fit_method == "kd_tree"
        ):
            if sklearn_check_version("0.24"):
                sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self, "_y", None))
            else:
                sklearn_NearestNeighbors.fit(self, self._fit_X)
        if sklearn_check_version("0.22"):
            result = sklearn_NearestNeighbors.radius_neighbors(
                self, X, radius, return_distance, sort_results
            )
        else:
            # ``sort_results`` is not passed on for scikit-learn < 0.22.
            result = sklearn_NearestNeighbors.radius_neighbors(
                self, X, radius, return_distance
            )

        return result

    def _onedal_fit(self, X, y, queue=None):
        """Fit a oneDAL regressor and copy its fitted state onto ``self``.

        Reads ``effective_metric_`` and ``effective_metric_params_["p"]``,
        so both must already be set on ``self`` when this is called.
        ``queue`` is forwarded unchanged to the oneDAL estimator's ``fit``.
        """
        onedal_params = {
            "n_neighbors": self.n_neighbors,
            "weights": self.weights,
            "algorithm": self.algorithm,
            "metric": self.effective_metric_,
            "p": self.effective_metric_params_["p"],
        }

        # Fall back to False when the estimator tags carry no "requires_y".
        try:
            requires_y = self._get_tags()["requires_y"]
        except KeyError:
            requires_y = False

        self._onedal_estimator = onedal_KNeighborsRegressor(**onedal_params)
        self._onedal_estimator.requires_y = requires_y
        self._onedal_estimator.effective_metric_ = self.effective_metric_
        self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
        self._onedal_estimator.fit(X, y, queue=queue)

        self._save_attributes()

    def _onedal_predict(self, X, queue=None):
        """Return the oneDAL estimator's ``predict`` result for ``X``."""
        return self._onedal_estimator.predict(X, queue=queue)

    def _onedal_kneighbors(
        self, X=None, n_neighbors=None, return_distance=True, queue=None
    ):
        """Return the oneDAL estimator's ``kneighbors`` result."""
        return self._onedal_estimator.kneighbors(
            X, n_neighbors, return_distance, queue=queue
        )

    def _save_attributes(self):
        """Copy the fitted attributes from the oneDAL estimator onto ``self``."""
        self.n_features_in_ = self._onedal_estimator.n_features_in_
        self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_
        self._fit_X = self._onedal_estimator._fit_X
        self._y = self._onedal_estimator._y
        self._fit_method = self._onedal_estimator._fit_method
        self._tree = self._onedal_estimator._tree