scikit-learn-intelex 2023.2.1__py39-none-win_amd64.whl → 2024.0.1__py39-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (109):
  1. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +2 -2
  2. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +16 -12
  3. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +2 -2
  4. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +90 -56
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +3 -3
  7. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +2 -2
  8. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +4 -4
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +2 -2
  11. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +12 -6
  12. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +5 -5
  13. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +3 -3
  14. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +2 -2
  15. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +5 -4
  16. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +102 -72
  17. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +12 -4
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  20. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +31 -16
  21. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +21 -14
  22. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +10 -10
  23. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +2 -2
  24. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +173 -83
  25. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +3 -3
  26. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +2 -2
  27. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +23 -7
  28. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +4 -3
  29. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +3 -3
  30. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +2 -2
  31. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +4 -3
  32. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +5 -5
  33. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
  34. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
  35. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +8 -6
  36. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +3 -3
  37. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +2 -2
  38. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +6 -3
  39. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +9 -5
  40. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +100 -77
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  43. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +116 -58
  44. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +118 -56
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  46. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +18 -20
  47. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +3 -3
  48. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +7 -7
  49. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +104 -73
  50. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +4 -1
  51. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +128 -100
  52. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +18 -16
  53. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd}/__init__.py +24 -22
  54. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +3 -3
  55. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +2 -2
  56. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +11 -5
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  58. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +2 -2
  59. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +3 -3
  60. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +2 -2
  61. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +3 -3
  62. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +16 -14
  63. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -3
  64. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +2 -2
  65. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +3 -3
  66. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +3 -3
  67. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +11 -8
  68. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +56 -56
  69. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +110 -55
  70. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +65 -31
  71. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +136 -78
  72. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +65 -31
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  75. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +9 -8
  76. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +63 -69
  77. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +55 -53
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  79. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +8 -7
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  81. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +39 -39
  82. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -3
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  84. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +2 -2
  85. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
  86. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  87. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/_utils.py +0 -82
  88. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -18
  89. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
  90. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
  91. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -46
  92. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -228
  93. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -213
  94. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -57
  95. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -18
  96. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -28
  97. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py +0 -1261
  98. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1155
  99. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py +0 -67
  100. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
  101. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -23
  102. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -63
  103. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -159
  104. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -383
  105. scikit_learn_intelex-2023.2.1.dist-info/RECORD +0 -95
  106. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  107. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
  108. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
  109. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  import numpy as np
19
19
  from numpy.testing import assert_allclose
@@ -21,18 +21,20 @@ from sklearn.datasets import load_breast_cancer
21
21
 
22
22
 
23
23
  def test_sklearnex_import_roc_auc():
24
- from sklearnex.metrics import roc_auc_score
25
24
  from sklearnex.linear_model import LogisticRegression
25
+ from sklearnex.metrics import roc_auc_score
26
+
26
27
  X, y = load_breast_cancer(return_X_y=True)
27
- clf = LogisticRegression(solver='liblinear', random_state=0).fit(X, y)
28
+ clf = LogisticRegression(solver="liblinear", random_state=0).fit(X, y)
28
29
  res = roc_auc_score(y, clf.decision_function(X))
29
30
  assert_allclose(res, 0.99, atol=1e-2)
30
31
 
31
32
 
32
33
  def test_sklearnex_import_pairwise_distances():
33
34
  from sklearnex.metrics import pairwise_distances
35
+
34
36
  rng = np.random.RandomState(0)
35
37
  x = np.abs(rng.rand(4), dtype=np.float64)
36
38
  x = np.vstack([x, x])
37
- res = pairwise_distances(x, metric='cosine')
38
- assert_allclose(res, [[0., 0.], [0., 0.]], atol=1e-2)
39
+ res = pairwise_distances(x, metric="cosine")
40
+ assert_allclose(res, [[0.0, 0.0], [0.0, 0.0]], atol=1e-2)
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,10 +13,10 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from .split import train_test_split
19
19
 
20
20
  __all__ = [
21
- 'train_test_split',
21
+ "train_test_split",
22
22
  ]
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,6 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from daal4py.sklearn.model_selection import train_test_split
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,18 +13,21 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  import numpy as np
19
19
  from numpy.testing import assert_allclose
20
20
 
21
21
 
22
+ # TODO:
23
+ # add pytest params for checking different dataframe inputs/outputs.
22
24
  def test_sklearnex_import_train_test_split():
23
25
  from sklearnex.model_selection import train_test_split
26
+
24
27
  X = np.arange(100).reshape((10, 10))
25
28
  y = np.arange(10)
26
29
 
27
- split = train_test_split(X, y, test_size=None, train_size=.5)
30
+ split = train_test_split(X, y, test_size=None, train_size=0.5)
28
31
  X_train, X_test, y_train, y_test = split
29
32
  assert len(y_test) == len(y_train)
30
33
 
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,12 +13,16 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from .knn_classification import KNeighborsClassifier
19
- from .knn_unsupervised import NearestNeighbors
20
19
  from .knn_regression import KNeighborsRegressor
20
+ from .knn_unsupervised import NearestNeighbors
21
21
  from .lof import LocalOutlierFactor
22
22
 
23
- __all__ = ['KNeighborsClassifier', 'KNeighborsRegressor', 'LocalOutlierFactor',
24
- 'NearestNeighbors']
23
+ __all__ = [
24
+ "KNeighborsClassifier",
25
+ "KNeighborsRegressor",
26
+ "LocalOutlierFactor",
27
+ "NearestNeighbors",
28
+ ]
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ==============================================================================
3
3
  # Copyright 2023 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,20 +13,22 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ==============================================================================
17
17
 
18
- from daal4py.sklearn._utils import PatchingConditionsChain, sklearn_check_version
19
- from onedal.datatypes import _check_array, _num_features, _num_samples
18
+ import warnings
20
19
 
21
20
  import numpy as np
22
21
  from scipy import sparse as sp
23
- import warnings
24
-
22
+ from sklearn.neighbors._ball_tree import BallTree
25
23
  from sklearn.neighbors._base import VALID_METRICS
26
24
  from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
27
- from sklearn.neighbors._ball_tree import BallTree
28
25
  from sklearn.neighbors._kd_tree import KDTree
29
26
 
27
+ from daal4py.sklearn._utils import sklearn_check_version
28
+ from onedal.utils import _check_array, _num_features, _num_samples
29
+
30
+ from .._utils import PatchingConditionsChain
31
+
30
32
 
31
33
  class KNeighborsDispatchingBase:
32
34
  def _fit_validation(self, X, y=None):
@@ -34,11 +36,15 @@ class KNeighborsDispatchingBase:
34
36
  self._validate_params()
35
37
  if sklearn_check_version("1.0"):
36
38
  self._check_feature_names(X, reset=True)
37
- if self.metric_params is not None and 'p' in self.metric_params:
39
+ if self.metric_params is not None and "p" in self.metric_params:
38
40
  if self.p is not None:
39
- warnings.warn("Parameter p is found in metric_params. "
40
- "The corresponding parameter from __init__ "
41
- "is ignored.", SyntaxWarning, stacklevel=2)
41
+ warnings.warn(
42
+ "Parameter p is found in metric_params. "
43
+ "The corresponding parameter from __init__ "
44
+ "is ignored.",
45
+ SyntaxWarning,
46
+ stacklevel=2,
47
+ )
42
48
  self.effective_metric_params_ = self.metric_params.copy()
43
49
  effective_p = self.metric_params["p"]
44
50
  else:
@@ -59,31 +65,35 @@ class KNeighborsDispatchingBase:
59
65
 
60
66
  if not isinstance(X, (KDTree, BallTree, sklearn_NeighborsBase)):
61
67
  self._fit_X = _check_array(
62
- X, dtype=[np.float64, np.float32], accept_sparse=True)
68
+ X, dtype=[np.float64, np.float32], accept_sparse=True
69
+ )
63
70
  self.n_samples_fit_ = _num_samples(self._fit_X)
64
71
  self.n_features_in_ = _num_features(self._fit_X)
65
72
 
66
73
  if self.algorithm == "auto":
67
74
  # A tree approach is better for small number of neighbors or small
68
75
  # number of features, with KDTree generally faster when available
69
- is_n_neighbors_valid_for_brute = self.n_neighbors is not None and \
70
- self.n_neighbors >= self._fit_X.shape[0] // 2
76
+ is_n_neighbors_valid_for_brute = (
77
+ self.n_neighbors is not None
78
+ and self.n_neighbors >= self._fit_X.shape[0] // 2
79
+ )
71
80
  if self._fit_X.shape[1] > 15 or is_n_neighbors_valid_for_brute:
72
81
  self._fit_method = "brute"
73
82
  else:
74
83
  if self.effective_metric_ in VALID_METRICS["kd_tree"]:
75
84
  self._fit_method = "kd_tree"
76
- elif callable(self.effective_metric_) or \
77
- self.effective_metric_ in \
78
- VALID_METRICS["ball_tree"]:
85
+ elif (
86
+ callable(self.effective_metric_)
87
+ or self.effective_metric_ in VALID_METRICS["ball_tree"]
88
+ ):
79
89
  self._fit_method = "ball_tree"
80
90
  else:
81
91
  self._fit_method = "brute"
82
92
  else:
83
93
  self._fit_method = self.algorithm
84
94
 
85
- if hasattr(self, '_onedal_estimator'):
86
- delattr(self, '_onedal_estimator')
95
+ if hasattr(self, "_onedal_estimator"):
96
+ delattr(self, "_onedal_estimator")
87
97
  # To cover test case when we pass patched
88
98
  # estimator as an input for other estimator
89
99
  if isinstance(X, sklearn_NeighborsBase):
@@ -92,8 +102,8 @@ class KNeighborsDispatchingBase:
92
102
  self._fit_method = X._fit_method
93
103
  self.n_samples_fit_ = X.n_samples_fit_
94
104
  self.n_features_in_ = X.n_features_in_
95
- if hasattr(X, '_onedal_estimator'):
96
- self.effective_metric_params_.pop('p')
105
+ if hasattr(X, "_onedal_estimator"):
106
+ self.effective_metric_params_.pop("p")
97
107
  if self._fit_method == "ball_tree":
98
108
  X._tree = BallTree(
99
109
  X._fit_X,
@@ -116,58 +126,63 @@ class KNeighborsDispatchingBase:
116
126
  elif isinstance(X, BallTree):
117
127
  self._fit_X = X.data
118
128
  self._tree = X
119
- self._fit_method = 'ball_tree'
129
+ self._fit_method = "ball_tree"
120
130
  self.n_samples_fit_ = X.data.shape[0]
121
131
  self.n_features_in_ = X.data.shape[1]
122
132
 
123
133
  elif isinstance(X, KDTree):
124
134
  self._fit_X = X.data
125
135
  self._tree = X
126
- self._fit_method = 'kd_tree'
136
+ self._fit_method = "kd_tree"
127
137
  self.n_samples_fit_ = X.data.shape[0]
128
138
  self.n_features_in_ = X.data.shape[1]
129
139
 
130
140
  def _onedal_supported(self, device, method_name, *data):
131
141
  class_name = self.__class__.__name__
132
- is_classifier = 'Classifier' in class_name
133
- is_regressor = 'Regressor' in class_name
142
+ is_classifier = "Classifier" in class_name
143
+ is_regressor = "Regressor" in class_name
134
144
  is_unsupervised = not (is_classifier or is_regressor)
135
145
  patching_status = PatchingConditionsChain(
136
- f'sklearn.neighbors.{class_name}.{method_name}')
146
+ f"sklearn.neighbors.{class_name}.{method_name}"
147
+ )
137
148
 
138
149
  if not patching_status.and_condition(
139
150
  not isinstance(data[0], (KDTree, BallTree, sklearn_NeighborsBase)),
140
- f'Input type {type(data[0])} is not supported.'
151
+ f"Input type {type(data[0])} is not supported.",
141
152
  ):
142
- return patching_status.get_status(logs=True)
153
+ return patching_status
143
154
 
144
- if self._fit_method in ['auto', 'ball_tree']:
145
- condition = self.n_neighbors is not None and \
146
- self.n_neighbors >= self.n_samples_fit_ // 2
155
+ if self._fit_method in ["auto", "ball_tree"]:
156
+ condition = (
157
+ self.n_neighbors is not None
158
+ and self.n_neighbors >= self.n_samples_fit_ // 2
159
+ )
147
160
  if self.n_features_in_ > 15 or condition:
148
- result_method = 'brute'
161
+ result_method = "brute"
149
162
  else:
150
- if self.effective_metric_ in ['euclidean']:
151
- result_method = 'kd_tree'
163
+ if self.effective_metric_ in ["euclidean"]:
164
+ result_method = "kd_tree"
152
165
  else:
153
- result_method = 'brute'
166
+ result_method = "brute"
154
167
  else:
155
168
  result_method = self._fit_method
156
169
 
157
- p_less_than_one = "p" in self.effective_metric_params_.keys() and \
158
- self.effective_metric_params_["p"] < 1
170
+ p_less_than_one = (
171
+ "p" in self.effective_metric_params_.keys()
172
+ and self.effective_metric_params_["p"] < 1
173
+ )
159
174
  if not patching_status.and_condition(
160
175
  not p_less_than_one, '"p" metric parameter is less than 1'
161
176
  ):
162
- return patching_status.get_status(logs=True)
177
+ return patching_status
163
178
 
164
179
  if not patching_status.and_condition(
165
- not sp.isspmatrix(data[0]), 'Sparse input is not supported.'
180
+ not sp.isspmatrix(data[0]), "Sparse input is not supported."
166
181
  ):
167
- return patching_status.get_status(logs=True)
182
+ return patching_status
168
183
 
169
184
  if not is_unsupervised:
170
- is_valid_weights = self.weights in ['uniform', "distance"]
185
+ is_valid_weights = self.weights in ["uniform", "distance"]
171
186
  if is_classifier:
172
187
  class_count = 1
173
188
  is_single_output = False
@@ -177,65 +192,73 @@ class KNeighborsDispatchingBase:
177
192
  y = np.asarray(data[1])
178
193
  if is_classifier:
179
194
  class_count = len(np.unique(y))
180
- if hasattr(self, '_onedal_estimator'):
195
+ if hasattr(self, "_onedal_estimator"):
181
196
  y = self._onedal_estimator._y
182
- if y is not None and hasattr(y, 'ndim') and hasattr(y, 'shape'):
197
+ if y is not None and hasattr(y, "ndim") and hasattr(y, "shape"):
183
198
  is_single_output = y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1
184
199
 
185
200
  # TODO: add native support for these metric names
186
- metrics_map = {
187
- 'manhattan': ['l1', 'cityblock'],
188
- 'euclidean': ['l2']
189
- }
201
+ metrics_map = {"manhattan": ["l1", "cityblock"], "euclidean": ["l2"]}
190
202
  for origin, aliases in metrics_map.items():
191
203
  if self.effective_metric_ in aliases:
192
204
  self.effective_metric_ = origin
193
205
  break
194
- if self.effective_metric_ == 'manhattan':
195
- self.effective_metric_params_['p'] = 1
196
- elif self.effective_metric_ == 'euclidean':
197
- self.effective_metric_params_['p'] = 2
206
+ if self.effective_metric_ == "manhattan":
207
+ self.effective_metric_params_["p"] = 1
208
+ elif self.effective_metric_ == "euclidean":
209
+ self.effective_metric_params_["p"] = 2
198
210
 
199
211
  onedal_brute_metrics = [
200
- 'manhattan', 'minkowski', 'euclidean', 'chebyshev', 'cosine']
201
- onedal_kdtree_metrics = ['euclidean']
202
- is_valid_for_brute = result_method == 'brute' and \
203
- self.effective_metric_ in onedal_brute_metrics
204
- is_valid_for_kd_tree = result_method == 'kd_tree' and \
205
- self.effective_metric_ in onedal_kdtree_metrics
206
- if result_method == 'kd_tree':
212
+ "manhattan",
213
+ "minkowski",
214
+ "euclidean",
215
+ "chebyshev",
216
+ "cosine",
217
+ ]
218
+ onedal_kdtree_metrics = ["euclidean"]
219
+ is_valid_for_brute = (
220
+ result_method == "brute" and self.effective_metric_ in onedal_brute_metrics
221
+ )
222
+ is_valid_for_kd_tree = (
223
+ result_method == "kd_tree" and self.effective_metric_ in onedal_kdtree_metrics
224
+ )
225
+ if result_method == "kd_tree":
207
226
  if not patching_status.and_condition(
208
- device != 'gpu', '"kd_tree" method is not supported on GPU.'
227
+ device != "gpu", '"kd_tree" method is not supported on GPU.'
209
228
  ):
210
- return patching_status.get_status(logs=True)
229
+ return patching_status
211
230
 
212
231
  if not patching_status.and_condition(
213
232
  is_valid_for_kd_tree or is_valid_for_brute,
214
- f'{result_method} with {self.effective_metric_} metric is not supported.'
233
+ f"{result_method} with {self.effective_metric_} metric is not supported.",
215
234
  ):
216
- return patching_status.get_status(logs=True)
235
+ return patching_status
217
236
  if not is_unsupervised:
218
- if not patching_status.and_conditions([
219
- (is_single_output, 'Only single output is supported.'),
220
- (is_valid_weights,
221
- f'"{type(self.weights)}" weights type is not supported.')
222
- ]):
223
- return patching_status.get_status(logs=True)
224
- if method_name == 'fit':
237
+ if not patching_status.and_conditions(
238
+ [
239
+ (is_single_output, "Only single output is supported."),
240
+ (
241
+ is_valid_weights,
242
+ f'"{type(self.weights)}" weights type is not supported.',
243
+ ),
244
+ ]
245
+ ):
246
+ return patching_status
247
+ if method_name == "fit":
225
248
  if is_classifier:
226
249
  patching_status.and_condition(
227
- class_count >= 2, 'One-class case is not supported.'
250
+ class_count >= 2, "One-class case is not supported."
228
251
  )
229
- return patching_status.get_status(logs=True)
230
- if method_name in ['predict', 'predict_proba', 'kneighbors']:
252
+ return patching_status
253
+ if method_name in ["predict", "predict_proba", "kneighbors"]:
231
254
  patching_status.and_condition(
232
- hasattr(self, '_onedal_estimator'), 'oneDAL model was not trained.'
255
+ hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."
233
256
  )
234
- return patching_status.get_status(logs=True)
235
- raise RuntimeError(f'Unknown method {method_name} in {class_name}')
257
+ return patching_status
258
+ raise RuntimeError(f"Unknown method {method_name} in {class_name}")
236
259
 
237
260
  def _onedal_gpu_supported(self, method_name, *data):
238
- return self._onedal_supported('gpu', method_name, *data)
261
+ return self._onedal_supported("gpu", method_name, *data)
239
262
 
240
263
  def _onedal_cpu_supported(self, method_name, *data):
241
- return self._onedal_supported('cpu', method_name, *data)
264
+ return self._onedal_supported("cpu", method_name, *data)