scikit-learn-intelex 2024.3.0__py39-none-win_amd64.whl → 2024.5.0__py39-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +39 -5
- {scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex}/basic_statistics/__init__.py +2 -1
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +317 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +54 -17
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +71 -19
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +2 -2
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +33 -2
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +73 -79
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +5 -3
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +387 -0
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +316 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +50 -9
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +200 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +40 -5
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +53 -36
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +4 -1
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +37 -122
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +10 -117
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +6 -78
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +2 -2
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +5 -73
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +6 -5
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +4 -7
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +66 -50
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +3 -49
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +66 -51
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +3 -49
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/_utils.py +34 -16
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +5 -1
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +12 -2
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +87 -58
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +1 -1
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +2 -1
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py +97 -0
- scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
- {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/METADATA +227 -230
- scikit_learn_intelex-2024.5.0.dist-info/RECORD +104 -0
- {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/WHEEL +1 -1
- scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -130
- scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -381
- scikit_learn_intelex-2024.3.0.dist-info/RECORD +0 -98
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/conftest.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/spmd}/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.3.0.data → scikit_learn_intelex-2024.5.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.3.0.dist-info → scikit_learn_intelex-2024.5.0.dist-info}/top_level.txt +0 -0
|
@@ -28,26 +28,33 @@ from onedal.tests.utils._dataframes_support import (
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
31
|
+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
|
|
31
32
|
@pytest.mark.parametrize("macro_block", [None, 1024])
|
|
32
|
-
def test_sklearnex_import_linear(dataframe, queue, macro_block):
|
|
33
|
+
def test_sklearnex_import_linear(dataframe, queue, dtype, macro_block):
|
|
33
34
|
from sklearnex.linear_model import LinearRegression
|
|
34
35
|
|
|
35
36
|
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
|
|
36
37
|
y = np.dot(X, np.array([1, 2])) + 3
|
|
38
|
+
X = X.astype(dtype=dtype)
|
|
39
|
+
y = y.astype(dtype=dtype)
|
|
37
40
|
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
|
|
38
41
|
y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
|
|
42
|
+
|
|
39
43
|
linreg = LinearRegression()
|
|
40
44
|
if daal_check_version((2024, "P", 0)) and macro_block is not None:
|
|
41
45
|
hparams = linreg.get_hyperparameters("fit")
|
|
42
46
|
hparams.cpu_macro_block = macro_block
|
|
43
47
|
hparams.gpu_macro_block = macro_block
|
|
48
|
+
|
|
44
49
|
linreg.fit(X, y)
|
|
45
|
-
|
|
46
|
-
|
|
50
|
+
|
|
51
|
+
assert hasattr(linreg, "_onedal_estimator")
|
|
47
52
|
assert "sklearnex" in linreg.__module__
|
|
48
53
|
assert linreg.n_features_in_ == 2
|
|
49
|
-
|
|
50
|
-
|
|
54
|
+
|
|
55
|
+
tol = 1e-5 if dtype == np.float32 else 1e-7
|
|
56
|
+
assert_allclose(_as_numpy(linreg.intercept_), 3.0, rtol=tol)
|
|
57
|
+
assert_allclose(_as_numpy(linreg.coef_), [1.0, 2.0], rtol=tol)
|
|
51
58
|
|
|
52
59
|
|
|
53
60
|
def test_sklearnex_import_ridge():
|
|
@@ -80,3 +87,31 @@ def test_sklearnex_import_elastic():
|
|
|
80
87
|
assert "daal4py" in elasticnet.__module__
|
|
81
88
|
assert_allclose(elasticnet.intercept_, 1.451, atol=1e-3)
|
|
82
89
|
assert_allclose(elasticnet.coef_, [18.838, 64.559], atol=1e-3)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
|
|
93
|
+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
|
|
94
|
+
def test_sklearnex_reconstruct_model(dataframe, queue, dtype):
|
|
95
|
+
from sklearnex.linear_model import LinearRegression
|
|
96
|
+
|
|
97
|
+
seed = 42
|
|
98
|
+
num_samples = 3500
|
|
99
|
+
num_features, num_targets = 14, 9
|
|
100
|
+
|
|
101
|
+
gen = np.random.default_rng(seed)
|
|
102
|
+
intercept = gen.random(size=num_targets, dtype=dtype)
|
|
103
|
+
coef = gen.random(size=(num_targets, num_features), dtype=dtype).T
|
|
104
|
+
|
|
105
|
+
X = gen.random(size=(num_samples, num_features), dtype=dtype)
|
|
106
|
+
gtr = X @ coef + intercept[np.newaxis, :]
|
|
107
|
+
|
|
108
|
+
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
|
|
109
|
+
|
|
110
|
+
linreg = LinearRegression(fit_intercept=True)
|
|
111
|
+
linreg.coef_ = coef.T
|
|
112
|
+
linreg.intercept_ = intercept
|
|
113
|
+
|
|
114
|
+
y_pred = linreg.predict(X)
|
|
115
|
+
|
|
116
|
+
tol = 1e-5 if dtype == np.float32 else 1e-7
|
|
117
|
+
assert_allclose(gtr, _as_numpy(y_pred), rtol=tol)
|
|
@@ -23,13 +23,13 @@ from sklearn.utils.validation import check_is_fitted
|
|
|
23
23
|
|
|
24
24
|
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
25
25
|
from daal4py.sklearn._utils import sklearn_check_version
|
|
26
|
+
from sklearnex._device_offload import dispatch, wrap_output_data
|
|
27
|
+
from sklearnex.neighbors.common import KNeighborsDispatchingBase
|
|
28
|
+
from sklearnex.neighbors.knn_unsupervised import NearestNeighbors
|
|
29
|
+
from sklearnex.utils import get_namespace
|
|
26
30
|
|
|
27
|
-
from .._device_offload import dispatch, wrap_output_data
|
|
28
|
-
from .common import KNeighborsDispatchingBase
|
|
29
|
-
from .knn_unsupervised import NearestNeighbors
|
|
30
31
|
|
|
31
|
-
|
|
32
|
-
@control_n_jobs(decorated_methods=["fit", "kneighbors"])
|
|
32
|
+
@control_n_jobs(decorated_methods=["fit", "_kneighbors"])
|
|
33
33
|
class LocalOutlierFactor(KNeighborsDispatchingBase, sklearn_LocalOutlierFactor):
|
|
34
34
|
__doc__ = (
|
|
35
35
|
sklearn_LocalOutlierFactor.__doc__
|
|
@@ -100,7 +100,6 @@ class LocalOutlierFactor(KNeighborsDispatchingBase, sklearn_LocalOutlierFactor):
|
|
|
100
100
|
return self
|
|
101
101
|
|
|
102
102
|
def fit(self, X, y=None):
|
|
103
|
-
self._fit_validation(X, y)
|
|
104
103
|
result = dispatch(
|
|
105
104
|
self,
|
|
106
105
|
"fit",
|
|
@@ -113,16 +112,13 @@ class LocalOutlierFactor(KNeighborsDispatchingBase, sklearn_LocalOutlierFactor):
|
|
|
113
112
|
)
|
|
114
113
|
return result
|
|
115
114
|
|
|
116
|
-
# Subtle order change to remove check_array and preserve dpnp and
|
|
117
|
-
# dpctl conformance. decision_function will return a dpnp or dpctl
|
|
118
|
-
# instance via kneighbors and an equivalent check_array exists in
|
|
119
|
-
# that call already in sklearn so no loss of functionality occurs
|
|
120
115
|
def _predict(self, X=None):
|
|
121
116
|
check_is_fitted(self)
|
|
122
117
|
|
|
123
118
|
if X is not None:
|
|
119
|
+
xp, _ = get_namespace(X)
|
|
124
120
|
output = self.decision_function(X) < 0
|
|
125
|
-
is_inlier =
|
|
121
|
+
is_inlier = xp.ones_like(output, dtype=int)
|
|
126
122
|
is_inlier[output] = -1
|
|
127
123
|
else:
|
|
128
124
|
is_inlier = np.ones(self.n_samples_fit_, dtype=int)
|
|
@@ -159,16 +155,40 @@ class LocalOutlierFactor(KNeighborsDispatchingBase, sklearn_LocalOutlierFactor):
|
|
|
159
155
|
"""
|
|
160
156
|
return self.fit(X)._predict()
|
|
161
157
|
|
|
162
|
-
|
|
158
|
+
def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
|
|
159
|
+
check_is_fitted(self)
|
|
160
|
+
if sklearn_check_version("1.0") and X is not None:
|
|
161
|
+
self._check_feature_names(X, reset=False)
|
|
162
|
+
return dispatch(
|
|
163
|
+
self,
|
|
164
|
+
"kneighbors",
|
|
165
|
+
{
|
|
166
|
+
"onedal": self.__class__._onedal_kneighbors,
|
|
167
|
+
"sklearn": sklearn_LocalOutlierFactor.kneighbors,
|
|
168
|
+
},
|
|
169
|
+
X,
|
|
170
|
+
n_neighbors=n_neighbors,
|
|
171
|
+
return_distance=return_distance,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
kneighbors = wrap_output_data(_kneighbors)
|
|
175
|
+
|
|
176
|
+
@available_if(sklearn_LocalOutlierFactor._check_novelty_score_samples)
|
|
163
177
|
@wrap_output_data
|
|
164
|
-
def
|
|
165
|
-
"""
|
|
178
|
+
def score_samples(self, X):
|
|
179
|
+
"""Opposite of the Local Outlier Factor of X.
|
|
180
|
+
|
|
181
|
+
It is the opposite as bigger is better, i.e. large values correspond
|
|
182
|
+
to inliers.
|
|
166
183
|
|
|
167
184
|
**Only available for novelty detection (when novelty is set to True).**
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
obtained
|
|
185
|
+
The argument X is supposed to contain *new data*: if X contains a
|
|
186
|
+
point from training, it considers the later in its own neighborhood.
|
|
187
|
+
Also, the samples in X are not considered in the neighborhood of any
|
|
188
|
+
point. Because of this, the scores obtained via ``score_samples`` may
|
|
189
|
+
differ from the standard LOF scores.
|
|
190
|
+
The standard LOF scores for the training data is available via the
|
|
191
|
+
``negative_outlier_factor_`` attribute.
|
|
172
192
|
|
|
173
193
|
Parameters
|
|
174
194
|
----------
|
|
@@ -178,27 +198,24 @@ class LocalOutlierFactor(KNeighborsDispatchingBase, sklearn_LocalOutlierFactor):
|
|
|
178
198
|
|
|
179
199
|
Returns
|
|
180
200
|
-------
|
|
181
|
-
|
|
182
|
-
|
|
201
|
+
opposite_lof_scores : ndarray of shape (n_samples,)
|
|
202
|
+
The opposite of the Local Outlier Factor of each input samples.
|
|
203
|
+
The lower, the more abnormal.
|
|
183
204
|
"""
|
|
184
|
-
return self._predict(X)
|
|
185
|
-
|
|
186
|
-
@wrap_output_data
|
|
187
|
-
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
|
|
188
205
|
check_is_fitted(self)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
self,
|
|
193
|
-
"kneighbors",
|
|
194
|
-
{
|
|
195
|
-
"onedal": self.__class__._onedal_kneighbors,
|
|
196
|
-
"sklearn": sklearn_LocalOutlierFactor.kneighbors,
|
|
197
|
-
},
|
|
198
|
-
X,
|
|
199
|
-
n_neighbors=n_neighbors,
|
|
200
|
-
return_distance=return_distance,
|
|
206
|
+
|
|
207
|
+
distances_X, neighbors_indices_X = self._kneighbors(
|
|
208
|
+
X, n_neighbors=self.n_neighbors_
|
|
201
209
|
)
|
|
202
210
|
|
|
211
|
+
X_lrd = self._local_reachability_density(
|
|
212
|
+
distances_X,
|
|
213
|
+
neighbors_indices_X,
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
lrd_ratios_array = self._lrd[neighbors_indices_X] / X_lrd[:, np.newaxis]
|
|
217
|
+
|
|
218
|
+
return -np.mean(lrd_ratios_array, axis=1)
|
|
219
|
+
|
|
203
220
|
fit.__doc__ = sklearn_LocalOutlierFactor.fit.__doc__
|
|
204
221
|
kneighbors.__doc__ = sklearn_LocalOutlierFactor.kneighbors.__doc__
|
|
@@ -137,6 +137,9 @@ class KNeighborsDispatchingBase:
|
|
|
137
137
|
self.n_features_in_ = X.data.shape[1]
|
|
138
138
|
|
|
139
139
|
def _onedal_supported(self, device, method_name, *data):
|
|
140
|
+
if method_name == "fit":
|
|
141
|
+
self._fit_validation(data[0], data[1])
|
|
142
|
+
|
|
140
143
|
class_name = self.__class__.__name__
|
|
141
144
|
is_classifier = "Classifier" in class_name
|
|
142
145
|
is_regressor = "Regressor" in class_name
|
|
@@ -249,7 +252,7 @@ class KNeighborsDispatchingBase:
|
|
|
249
252
|
class_count >= 2, "One-class case is not supported."
|
|
250
253
|
)
|
|
251
254
|
return patching_status
|
|
252
|
-
if method_name in ["predict", "predict_proba", "kneighbors"]:
|
|
255
|
+
if method_name in ["predict", "predict_proba", "kneighbors", "score"]:
|
|
253
256
|
patching_status.and_condition(
|
|
254
257
|
hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."
|
|
255
258
|
)
|
|
@@ -14,129 +14,30 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
# ===============================================================================
|
|
16
16
|
|
|
17
|
-
from
|
|
18
|
-
from daal4py.sklearn._utils import sklearn_check_version
|
|
19
|
-
|
|
20
|
-
if not sklearn_check_version("1.2"):
|
|
21
|
-
from sklearn.neighbors._base import _check_weights
|
|
22
|
-
|
|
17
|
+
from sklearn.metrics import accuracy_score
|
|
23
18
|
from sklearn.neighbors._classification import (
|
|
24
19
|
KNeighborsClassifier as sklearn_KNeighborsClassifier,
|
|
25
20
|
)
|
|
26
21
|
from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestNeighbors
|
|
27
22
|
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
|
|
28
23
|
|
|
24
|
+
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
25
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
29
26
|
from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier
|
|
30
27
|
|
|
31
28
|
from .._device_offload import dispatch, wrap_output_data
|
|
32
29
|
from .common import KNeighborsDispatchingBase
|
|
33
30
|
|
|
34
|
-
if sklearn_check_version("0.24"):
|
|
35
|
-
|
|
36
|
-
class KNeighborsClassifier_(sklearn_KNeighborsClassifier):
|
|
37
|
-
if sklearn_check_version("1.2"):
|
|
38
|
-
_parameter_constraints: dict = {
|
|
39
|
-
**sklearn_KNeighborsClassifier._parameter_constraints
|
|
40
|
-
}
|
|
41
31
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
*,
|
|
47
|
-
weights="uniform",
|
|
48
|
-
algorithm="auto",
|
|
49
|
-
leaf_size=30,
|
|
50
|
-
p=2,
|
|
51
|
-
metric="minkowski",
|
|
52
|
-
metric_params=None,
|
|
53
|
-
n_jobs=None,
|
|
54
|
-
**kwargs,
|
|
55
|
-
):
|
|
56
|
-
super().__init__(
|
|
57
|
-
n_neighbors=n_neighbors,
|
|
58
|
-
algorithm=algorithm,
|
|
59
|
-
leaf_size=leaf_size,
|
|
60
|
-
metric=metric,
|
|
61
|
-
p=p,
|
|
62
|
-
metric_params=metric_params,
|
|
63
|
-
n_jobs=n_jobs,
|
|
64
|
-
**kwargs,
|
|
65
|
-
)
|
|
66
|
-
self.weights = (
|
|
67
|
-
weights if sklearn_check_version("1.0") else _check_weights(weights)
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
elif sklearn_check_version("0.22"):
|
|
71
|
-
from sklearn.neighbors._base import (
|
|
72
|
-
SupervisedIntegerMixin as BaseSupervisedIntegerMixin,
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
class KNeighborsClassifier_(sklearn_KNeighborsClassifier, BaseSupervisedIntegerMixin):
|
|
76
|
-
@_deprecate_positional_args
|
|
77
|
-
def __init__(
|
|
78
|
-
self,
|
|
79
|
-
n_neighbors=5,
|
|
80
|
-
*,
|
|
81
|
-
weights="uniform",
|
|
82
|
-
algorithm="auto",
|
|
83
|
-
leaf_size=30,
|
|
84
|
-
p=2,
|
|
85
|
-
metric="minkowski",
|
|
86
|
-
metric_params=None,
|
|
87
|
-
n_jobs=None,
|
|
88
|
-
**kwargs,
|
|
89
|
-
):
|
|
90
|
-
super().__init__(
|
|
91
|
-
n_neighbors=n_neighbors,
|
|
92
|
-
algorithm=algorithm,
|
|
93
|
-
leaf_size=leaf_size,
|
|
94
|
-
metric=metric,
|
|
95
|
-
p=p,
|
|
96
|
-
metric_params=metric_params,
|
|
97
|
-
n_jobs=n_jobs,
|
|
98
|
-
**kwargs,
|
|
99
|
-
)
|
|
100
|
-
self.weights = _check_weights(weights)
|
|
101
|
-
|
|
102
|
-
else:
|
|
103
|
-
from sklearn.neighbors.base import (
|
|
104
|
-
SupervisedIntegerMixin as BaseSupervisedIntegerMixin,
|
|
105
|
-
)
|
|
106
|
-
|
|
107
|
-
class KNeighborsClassifier_(sklearn_KNeighborsClassifier, BaseSupervisedIntegerMixin):
|
|
108
|
-
@_deprecate_positional_args
|
|
109
|
-
def __init__(
|
|
110
|
-
self,
|
|
111
|
-
n_neighbors=5,
|
|
112
|
-
*,
|
|
113
|
-
weights="uniform",
|
|
114
|
-
algorithm="auto",
|
|
115
|
-
leaf_size=30,
|
|
116
|
-
p=2,
|
|
117
|
-
metric="minkowski",
|
|
118
|
-
metric_params=None,
|
|
119
|
-
n_jobs=None,
|
|
120
|
-
**kwargs,
|
|
121
|
-
):
|
|
122
|
-
super().__init__(
|
|
123
|
-
n_neighbors=n_neighbors,
|
|
124
|
-
algorithm=algorithm,
|
|
125
|
-
leaf_size=leaf_size,
|
|
126
|
-
metric=metric,
|
|
127
|
-
p=p,
|
|
128
|
-
metric_params=metric_params,
|
|
129
|
-
n_jobs=n_jobs,
|
|
130
|
-
**kwargs,
|
|
131
|
-
)
|
|
132
|
-
self.weights = _check_weights(weights)
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "kneighbors"])
|
|
136
|
-
class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
32
|
+
@control_n_jobs(
|
|
33
|
+
decorated_methods=["fit", "predict", "predict_proba", "kneighbors", "score"]
|
|
34
|
+
)
|
|
35
|
+
class KNeighborsClassifier(sklearn_KNeighborsClassifier, KNeighborsDispatchingBase):
|
|
137
36
|
__doc__ = sklearn_KNeighborsClassifier.__doc__
|
|
138
37
|
if sklearn_check_version("1.2"):
|
|
139
|
-
_parameter_constraints: dict = {
|
|
38
|
+
_parameter_constraints: dict = {
|
|
39
|
+
**sklearn_KNeighborsClassifier._parameter_constraints
|
|
40
|
+
}
|
|
140
41
|
|
|
141
42
|
if sklearn_check_version("1.0"):
|
|
142
43
|
|
|
@@ -192,7 +93,6 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
192
93
|
)
|
|
193
94
|
|
|
194
95
|
def fit(self, X, y):
|
|
195
|
-
self._fit_validation(X, y)
|
|
196
96
|
dispatch(
|
|
197
97
|
self,
|
|
198
98
|
"fit",
|
|
@@ -235,6 +135,23 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
235
135
|
X,
|
|
236
136
|
)
|
|
237
137
|
|
|
138
|
+
@wrap_output_data
|
|
139
|
+
def score(self, X, y, sample_weight=None):
|
|
140
|
+
check_is_fitted(self)
|
|
141
|
+
if sklearn_check_version("1.0"):
|
|
142
|
+
self._check_feature_names(X, reset=False)
|
|
143
|
+
return dispatch(
|
|
144
|
+
self,
|
|
145
|
+
"score",
|
|
146
|
+
{
|
|
147
|
+
"onedal": self.__class__._onedal_score,
|
|
148
|
+
"sklearn": sklearn_KNeighborsClassifier.score,
|
|
149
|
+
},
|
|
150
|
+
X,
|
|
151
|
+
y,
|
|
152
|
+
sample_weight=sample_weight,
|
|
153
|
+
)
|
|
154
|
+
|
|
238
155
|
@wrap_output_data
|
|
239
156
|
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
|
|
240
157
|
check_is_fitted(self)
|
|
@@ -263,18 +180,10 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
263
180
|
or getattr(self, "_tree", 0) is None
|
|
264
181
|
and self._fit_method == "kd_tree"
|
|
265
182
|
):
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
if sklearn_check_version("0.22"):
|
|
271
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
272
|
-
self, X, radius, return_distance, sort_results
|
|
273
|
-
)
|
|
274
|
-
else:
|
|
275
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
276
|
-
self, X, radius, return_distance
|
|
277
|
-
)
|
|
183
|
+
sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self, "_y", None))
|
|
184
|
+
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
185
|
+
self, X, radius, return_distance, sort_results
|
|
186
|
+
)
|
|
278
187
|
|
|
279
188
|
return result
|
|
280
189
|
|
|
@@ -313,6 +222,11 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
313
222
|
X, n_neighbors, return_distance, queue=queue
|
|
314
223
|
)
|
|
315
224
|
|
|
225
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
226
|
+
return accuracy_score(
|
|
227
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
228
|
+
)
|
|
229
|
+
|
|
316
230
|
def _save_attributes(self):
|
|
317
231
|
self.classes_ = self._onedal_estimator.classes_
|
|
318
232
|
self.n_features_in_ = self._onedal_estimator.n_features_in_
|
|
@@ -326,5 +240,6 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
326
240
|
fit.__doc__ = sklearn_KNeighborsClassifier.fit.__doc__
|
|
327
241
|
predict.__doc__ = sklearn_KNeighborsClassifier.predict.__doc__
|
|
328
242
|
predict_proba.__doc__ = sklearn_KNeighborsClassifier.predict_proba.__doc__
|
|
243
|
+
score.__doc__ = sklearn_KNeighborsClassifier.score.__doc__
|
|
329
244
|
kneighbors.__doc__ = sklearn_KNeighborsClassifier.kneighbors.__doc__
|
|
330
245
|
radius_neighbors.__doc__ = sklearn_NearestNeighbors.radius_neighbors.__doc__
|
|
@@ -14,125 +14,27 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
# ==============================================================================
|
|
16
16
|
|
|
17
|
-
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
18
|
-
from daal4py.sklearn._utils import sklearn_check_version
|
|
19
|
-
|
|
20
|
-
if not sklearn_check_version("1.2"):
|
|
21
|
-
from sklearn.neighbors._base import _check_weights
|
|
22
|
-
|
|
23
17
|
from sklearn.neighbors._regression import (
|
|
24
18
|
KNeighborsRegressor as sklearn_KNeighborsRegressor,
|
|
25
19
|
)
|
|
26
20
|
from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestNeighbors
|
|
27
21
|
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
|
|
28
22
|
|
|
23
|
+
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
24
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
29
25
|
from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor
|
|
30
26
|
|
|
31
27
|
from .._device_offload import dispatch, wrap_output_data
|
|
32
28
|
from .common import KNeighborsDispatchingBase
|
|
33
29
|
|
|
34
|
-
if sklearn_check_version("0.24"):
|
|
35
|
-
|
|
36
|
-
class KNeighborsRegressor_(sklearn_KNeighborsRegressor):
|
|
37
|
-
if sklearn_check_version("1.2"):
|
|
38
|
-
_parameter_constraints: dict = {
|
|
39
|
-
**sklearn_KNeighborsRegressor._parameter_constraints
|
|
40
|
-
}
|
|
41
|
-
|
|
42
|
-
@_deprecate_positional_args
|
|
43
|
-
def __init__(
|
|
44
|
-
self,
|
|
45
|
-
n_neighbors=5,
|
|
46
|
-
*,
|
|
47
|
-
weights="uniform",
|
|
48
|
-
algorithm="auto",
|
|
49
|
-
leaf_size=30,
|
|
50
|
-
p=2,
|
|
51
|
-
metric="minkowski",
|
|
52
|
-
metric_params=None,
|
|
53
|
-
n_jobs=None,
|
|
54
|
-
**kwargs,
|
|
55
|
-
):
|
|
56
|
-
super().__init__(
|
|
57
|
-
n_neighbors=n_neighbors,
|
|
58
|
-
algorithm=algorithm,
|
|
59
|
-
leaf_size=leaf_size,
|
|
60
|
-
metric=metric,
|
|
61
|
-
p=p,
|
|
62
|
-
metric_params=metric_params,
|
|
63
|
-
n_jobs=n_jobs,
|
|
64
|
-
**kwargs,
|
|
65
|
-
)
|
|
66
|
-
self.weights = (
|
|
67
|
-
weights if sklearn_check_version("1.0") else _check_weights(weights)
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
elif sklearn_check_version("0.22"):
|
|
71
|
-
from sklearn.neighbors._base import SupervisedFloatMixin as BaseSupervisedFloatMixin
|
|
72
|
-
|
|
73
|
-
class KNeighborsRegressor_(sklearn_KNeighborsRegressor, BaseSupervisedFloatMixin):
|
|
74
|
-
@_deprecate_positional_args
|
|
75
|
-
def __init__(
|
|
76
|
-
self,
|
|
77
|
-
n_neighbors=5,
|
|
78
|
-
*,
|
|
79
|
-
weights="uniform",
|
|
80
|
-
algorithm="auto",
|
|
81
|
-
leaf_size=30,
|
|
82
|
-
p=2,
|
|
83
|
-
metric="minkowski",
|
|
84
|
-
metric_params=None,
|
|
85
|
-
n_jobs=None,
|
|
86
|
-
**kwargs,
|
|
87
|
-
):
|
|
88
|
-
super().__init__(
|
|
89
|
-
n_neighbors=n_neighbors,
|
|
90
|
-
algorithm=algorithm,
|
|
91
|
-
leaf_size=leaf_size,
|
|
92
|
-
metric=metric,
|
|
93
|
-
p=p,
|
|
94
|
-
metric_params=metric_params,
|
|
95
|
-
n_jobs=n_jobs,
|
|
96
|
-
**kwargs,
|
|
97
|
-
)
|
|
98
|
-
self.weights = _check_weights(weights)
|
|
99
|
-
|
|
100
|
-
else:
|
|
101
|
-
from sklearn.neighbors.base import SupervisedFloatMixin as BaseSupervisedFloatMixin
|
|
102
|
-
|
|
103
|
-
class KNeighborsRegressor_(sklearn_KNeighborsRegressor, BaseSupervisedFloatMixin):
|
|
104
|
-
@_deprecate_positional_args
|
|
105
|
-
def __init__(
|
|
106
|
-
self,
|
|
107
|
-
n_neighbors=5,
|
|
108
|
-
*,
|
|
109
|
-
weights="uniform",
|
|
110
|
-
algorithm="auto",
|
|
111
|
-
leaf_size=30,
|
|
112
|
-
p=2,
|
|
113
|
-
metric="minkowski",
|
|
114
|
-
metric_params=None,
|
|
115
|
-
n_jobs=None,
|
|
116
|
-
**kwargs,
|
|
117
|
-
):
|
|
118
|
-
super().__init__(
|
|
119
|
-
n_neighbors=n_neighbors,
|
|
120
|
-
algorithm=algorithm,
|
|
121
|
-
leaf_size=leaf_size,
|
|
122
|
-
metric=metric,
|
|
123
|
-
p=p,
|
|
124
|
-
metric_params=metric_params,
|
|
125
|
-
n_jobs=n_jobs,
|
|
126
|
-
**kwargs,
|
|
127
|
-
)
|
|
128
|
-
self.weights = _check_weights(weights)
|
|
129
|
-
|
|
130
30
|
|
|
131
31
|
@control_n_jobs(decorated_methods=["fit", "predict", "kneighbors"])
|
|
132
|
-
class KNeighborsRegressor(
|
|
32
|
+
class KNeighborsRegressor(sklearn_KNeighborsRegressor, KNeighborsDispatchingBase):
|
|
133
33
|
__doc__ = sklearn_KNeighborsRegressor.__doc__
|
|
134
34
|
if sklearn_check_version("1.2"):
|
|
135
|
-
_parameter_constraints: dict = {
|
|
35
|
+
_parameter_constraints: dict = {
|
|
36
|
+
**sklearn_KNeighborsRegressor._parameter_constraints
|
|
37
|
+
}
|
|
136
38
|
|
|
137
39
|
if sklearn_check_version("1.0"):
|
|
138
40
|
|
|
@@ -188,7 +90,6 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
|
|
|
188
90
|
)
|
|
189
91
|
|
|
190
92
|
def fit(self, X, y):
|
|
191
|
-
self._fit_validation(X, y)
|
|
192
93
|
dispatch(
|
|
193
94
|
self,
|
|
194
95
|
"fit",
|
|
@@ -244,18 +145,10 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
|
|
|
244
145
|
or getattr(self, "_tree", 0) is None
|
|
245
146
|
and self._fit_method == "kd_tree"
|
|
246
147
|
):
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
if sklearn_check_version("0.22"):
|
|
252
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
253
|
-
self, X, radius, return_distance, sort_results
|
|
254
|
-
)
|
|
255
|
-
else:
|
|
256
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
257
|
-
self, X, radius, return_distance
|
|
258
|
-
)
|
|
148
|
+
sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self, "_y", None))
|
|
149
|
+
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
150
|
+
self, X, radius, return_distance, sort_results
|
|
151
|
+
)
|
|
259
152
|
|
|
260
153
|
return result
|
|
261
154
|
|