scikit-learn-intelex 2024.2.0__py39-none-win_amd64.whl → 2025.1.0__py39-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/_daal4py.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/doc/third-party-programs.txt +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/model_builders.py +377 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +248 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +597 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn}/decomposition/__init__.py +2 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +524 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1397 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn}/linear_model/__init__.py +29 -28
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +272 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +325 -0
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +3 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +405 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +236 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/_models_info.py +13 -22
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/test_patching.py +10 -56
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/utils/_launch_algorithms.py +4 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +503 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +139 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +74 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +734 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd/covariance → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils}/__init__.py +5 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +75 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +693 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/__init__.py +83 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_config.py +54 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_device_offload.py +222 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +110 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +564 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +115 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_base.py +38 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_policy.py +59 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_spmd_policy.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/tests/test_policy.py +76 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +146 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +122 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +154 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +126 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +414 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +186 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +727 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +258 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +249 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +250 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +767 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +25 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/svm.py +556 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +351 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +176 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/test_common.py +57 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +102 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/__init__.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +81 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_dpep_helpers.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/validation.py +440 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__init__.py +10 -7
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_config.py +22 -16
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +126 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_utils.py +27 -4
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +230 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +16 -7
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +395 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +398 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +111 -17
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +425 -0
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +25 -9
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +222 -42
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +249 -182
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +39 -21
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +482 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +425 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +341 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +182 -102
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +7 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +45 -4
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +4 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +5 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +97 -28
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +53 -6
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +48 -149
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +43 -144
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +50 -93
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +6 -9
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +24 -18
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +233 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model}/__init__.py +19 -19
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +14 -18
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +339 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +168 -73
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +71 -66
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +166 -72
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +64 -63
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +390 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +123 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +379 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +276 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +108 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +385 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +321 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +44 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +371 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/METADATA +231 -230
- scikit_learn_intelex-2025.1.0.dist-info/RECORD +257 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/WHEEL +1 -1
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -223
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -130
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -381
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -17
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -82
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -371
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -374
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -188
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -222
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -240
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -93
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
- scikit_learn_intelex-2024.2.0.dist-info/RECORD +0 -101
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/top_level.txt +0 -0
|
@@ -19,14 +19,16 @@ import warnings
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
from scipy import sparse as sp
|
|
21
21
|
from sklearn.neighbors._ball_tree import BallTree
|
|
22
|
-
from sklearn.neighbors._base import VALID_METRICS
|
|
23
|
-
from sklearn.neighbors._base import NeighborsBase as
|
|
22
|
+
from sklearn.neighbors._base import VALID_METRICS, KNeighborsMixin
|
|
23
|
+
from sklearn.neighbors._base import NeighborsBase as _sklearn_NeighborsBase
|
|
24
24
|
from sklearn.neighbors._kd_tree import KDTree
|
|
25
|
+
from sklearn.utils.validation import check_is_fitted
|
|
25
26
|
|
|
26
27
|
from daal4py.sklearn._utils import sklearn_check_version
|
|
27
28
|
from onedal.utils import _check_array, _num_features, _num_samples
|
|
28
29
|
|
|
29
30
|
from .._utils import PatchingConditionsChain
|
|
31
|
+
from ..utils._array_api import get_namespace
|
|
30
32
|
|
|
31
33
|
|
|
32
34
|
class KNeighborsDispatchingBase:
|
|
@@ -62,7 +64,7 @@ class KNeighborsDispatchingBase:
|
|
|
62
64
|
elif p == np.inf:
|
|
63
65
|
self.effective_metric_ = "chebyshev"
|
|
64
66
|
|
|
65
|
-
if not isinstance(X, (KDTree, BallTree,
|
|
67
|
+
if not isinstance(X, (KDTree, BallTree, _sklearn_NeighborsBase)):
|
|
66
68
|
self._fit_X = _check_array(
|
|
67
69
|
X, dtype=[np.float64, np.float32], accept_sparse=True
|
|
68
70
|
)
|
|
@@ -95,7 +97,7 @@ class KNeighborsDispatchingBase:
|
|
|
95
97
|
delattr(self, "_onedal_estimator")
|
|
96
98
|
# To cover test case when we pass patched
|
|
97
99
|
# estimator as an input for other estimator
|
|
98
|
-
if isinstance(X,
|
|
100
|
+
if isinstance(X, _sklearn_NeighborsBase):
|
|
99
101
|
self._fit_X = X._fit_X
|
|
100
102
|
self._tree = X._tree
|
|
101
103
|
self._fit_method = X._fit_method
|
|
@@ -137,6 +139,9 @@ class KNeighborsDispatchingBase:
|
|
|
137
139
|
self.n_features_in_ = X.data.shape[1]
|
|
138
140
|
|
|
139
141
|
def _onedal_supported(self, device, method_name, *data):
|
|
142
|
+
if method_name == "fit":
|
|
143
|
+
self._fit_validation(data[0], data[1])
|
|
144
|
+
|
|
140
145
|
class_name = self.__class__.__name__
|
|
141
146
|
is_classifier = "Classifier" in class_name
|
|
142
147
|
is_regressor = "Regressor" in class_name
|
|
@@ -144,9 +149,13 @@ class KNeighborsDispatchingBase:
|
|
|
144
149
|
patching_status = PatchingConditionsChain(
|
|
145
150
|
f"sklearn.neighbors.{class_name}.{method_name}"
|
|
146
151
|
)
|
|
152
|
+
if not patching_status.and_condition(
|
|
153
|
+
"radius" not in method_name, "RadiusNeighbors not implemented in sklearnex"
|
|
154
|
+
):
|
|
155
|
+
return patching_status
|
|
147
156
|
|
|
148
157
|
if not patching_status.and_condition(
|
|
149
|
-
not isinstance(data[0], (KDTree, BallTree,
|
|
158
|
+
not isinstance(data[0], (KDTree, BallTree, _sklearn_NeighborsBase)),
|
|
150
159
|
f"Input type {type(data[0])} is not supported.",
|
|
151
160
|
):
|
|
152
161
|
return patching_status
|
|
@@ -249,7 +258,7 @@ class KNeighborsDispatchingBase:
|
|
|
249
258
|
class_count >= 2, "One-class case is not supported."
|
|
250
259
|
)
|
|
251
260
|
return patching_status
|
|
252
|
-
if method_name in ["predict", "predict_proba", "kneighbors"]:
|
|
261
|
+
if method_name in ["predict", "predict_proba", "kneighbors", "score"]:
|
|
253
262
|
patching_status.and_condition(
|
|
254
263
|
hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."
|
|
255
264
|
)
|
|
@@ -261,3 +270,41 @@ class KNeighborsDispatchingBase:
|
|
|
261
270
|
|
|
262
271
|
def _onedal_cpu_supported(self, method_name, *data):
|
|
263
272
|
return self._onedal_supported("cpu", method_name, *data)
|
|
273
|
+
|
|
274
|
+
def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
|
|
275
|
+
check_is_fitted(self)
|
|
276
|
+
if n_neighbors is None:
|
|
277
|
+
n_neighbors = self.n_neighbors
|
|
278
|
+
|
|
279
|
+
# check the input only in self.kneighbors
|
|
280
|
+
|
|
281
|
+
# construct CSR matrix representation of the k-NN graph
|
|
282
|
+
if mode == "connectivity":
|
|
283
|
+
A_ind = self.kneighbors(X, n_neighbors, return_distance=False)
|
|
284
|
+
xp, _ = get_namespace(A_ind)
|
|
285
|
+
n_queries = A_ind.shape[0]
|
|
286
|
+
A_data = xp.ones(n_queries * n_neighbors)
|
|
287
|
+
|
|
288
|
+
elif mode == "distance":
|
|
289
|
+
A_data, A_ind = self.kneighbors(X, n_neighbors, return_distance=True)
|
|
290
|
+
xp, _ = get_namespace(A_ind)
|
|
291
|
+
A_data = xp.reshape(A_data, (-1,))
|
|
292
|
+
|
|
293
|
+
else:
|
|
294
|
+
raise ValueError(
|
|
295
|
+
'Unsupported mode, must be one of "connectivity", '
|
|
296
|
+
f'or "distance" but got "{mode}" instead'
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
n_queries = A_ind.shape[0]
|
|
300
|
+
n_samples_fit = self.n_samples_fit_
|
|
301
|
+
n_nonzero = n_queries * n_neighbors
|
|
302
|
+
A_indptr = xp.arange(0, n_nonzero + 1, n_neighbors)
|
|
303
|
+
|
|
304
|
+
kneighbors_graph = sp.csr_matrix(
|
|
305
|
+
(A_data, xp.reshape(A_ind, (-1,)), A_indptr), shape=(n_queries, n_samples_fit)
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
return kneighbors_graph
|
|
309
|
+
|
|
310
|
+
kneighbors_graph.__doc__ = KNeighborsMixin.kneighbors_graph.__doc__
|
|
@@ -14,137 +14,35 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
# ===============================================================================
|
|
16
16
|
|
|
17
|
-
import
|
|
18
|
-
|
|
19
|
-
from sklearn.neighbors._ball_tree import BallTree
|
|
20
|
-
from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
|
|
21
|
-
from sklearn.neighbors._kd_tree import KDTree
|
|
22
|
-
|
|
23
|
-
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
24
|
-
from daal4py.sklearn._utils import sklearn_check_version
|
|
25
|
-
|
|
26
|
-
if not sklearn_check_version("1.2"):
|
|
27
|
-
from sklearn.neighbors._base import _check_weights
|
|
28
|
-
|
|
29
|
-
import numpy as np
|
|
30
|
-
from sklearn.neighbors._base import VALID_METRICS
|
|
17
|
+
from sklearn.metrics import accuracy_score
|
|
31
18
|
from sklearn.neighbors._classification import (
|
|
32
|
-
KNeighborsClassifier as
|
|
19
|
+
KNeighborsClassifier as _sklearn_KNeighborsClassifier,
|
|
33
20
|
)
|
|
34
|
-
from sklearn.neighbors._unsupervised import NearestNeighbors as
|
|
21
|
+
from sklearn.neighbors._unsupervised import NearestNeighbors as _sklearn_NearestNeighbors
|
|
35
22
|
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
|
|
36
23
|
|
|
24
|
+
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
25
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
37
26
|
from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier
|
|
38
|
-
from onedal.utils import _check_array, _num_features, _num_samples
|
|
39
27
|
|
|
40
28
|
from .._device_offload import dispatch, wrap_output_data
|
|
41
29
|
from .common import KNeighborsDispatchingBase
|
|
42
30
|
|
|
43
|
-
if sklearn_check_version("
|
|
44
|
-
|
|
45
|
-
class KNeighborsClassifier_(sklearn_KNeighborsClassifier):
|
|
46
|
-
if sklearn_check_version("1.2"):
|
|
47
|
-
_parameter_constraints: dict = {
|
|
48
|
-
**sklearn_KNeighborsClassifier._parameter_constraints
|
|
49
|
-
}
|
|
50
|
-
|
|
51
|
-
@_deprecate_positional_args
|
|
52
|
-
def __init__(
|
|
53
|
-
self,
|
|
54
|
-
n_neighbors=5,
|
|
55
|
-
*,
|
|
56
|
-
weights="uniform",
|
|
57
|
-
algorithm="auto",
|
|
58
|
-
leaf_size=30,
|
|
59
|
-
p=2,
|
|
60
|
-
metric="minkowski",
|
|
61
|
-
metric_params=None,
|
|
62
|
-
n_jobs=None,
|
|
63
|
-
**kwargs,
|
|
64
|
-
):
|
|
65
|
-
super().__init__(
|
|
66
|
-
n_neighbors=n_neighbors,
|
|
67
|
-
algorithm=algorithm,
|
|
68
|
-
leaf_size=leaf_size,
|
|
69
|
-
metric=metric,
|
|
70
|
-
p=p,
|
|
71
|
-
metric_params=metric_params,
|
|
72
|
-
n_jobs=n_jobs,
|
|
73
|
-
**kwargs,
|
|
74
|
-
)
|
|
75
|
-
self.weights = (
|
|
76
|
-
weights if sklearn_check_version("1.0") else _check_weights(weights)
|
|
77
|
-
)
|
|
78
|
-
|
|
79
|
-
elif sklearn_check_version("0.22"):
|
|
80
|
-
from sklearn.neighbors._base import (
|
|
81
|
-
SupervisedIntegerMixin as BaseSupervisedIntegerMixin,
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
class KNeighborsClassifier_(sklearn_KNeighborsClassifier, BaseSupervisedIntegerMixin):
|
|
85
|
-
@_deprecate_positional_args
|
|
86
|
-
def __init__(
|
|
87
|
-
self,
|
|
88
|
-
n_neighbors=5,
|
|
89
|
-
*,
|
|
90
|
-
weights="uniform",
|
|
91
|
-
algorithm="auto",
|
|
92
|
-
leaf_size=30,
|
|
93
|
-
p=2,
|
|
94
|
-
metric="minkowski",
|
|
95
|
-
metric_params=None,
|
|
96
|
-
n_jobs=None,
|
|
97
|
-
**kwargs,
|
|
98
|
-
):
|
|
99
|
-
super().__init__(
|
|
100
|
-
n_neighbors=n_neighbors,
|
|
101
|
-
algorithm=algorithm,
|
|
102
|
-
leaf_size=leaf_size,
|
|
103
|
-
metric=metric,
|
|
104
|
-
p=p,
|
|
105
|
-
metric_params=metric_params,
|
|
106
|
-
n_jobs=n_jobs,
|
|
107
|
-
**kwargs,
|
|
108
|
-
)
|
|
109
|
-
self.weights = _check_weights(weights)
|
|
110
|
-
|
|
31
|
+
if sklearn_check_version("1.6"):
|
|
32
|
+
from sklearn.utils.validation import validate_data
|
|
111
33
|
else:
|
|
112
|
-
|
|
113
|
-
SupervisedIntegerMixin as BaseSupervisedIntegerMixin,
|
|
114
|
-
)
|
|
115
|
-
|
|
116
|
-
class KNeighborsClassifier_(sklearn_KNeighborsClassifier, BaseSupervisedIntegerMixin):
|
|
117
|
-
@_deprecate_positional_args
|
|
118
|
-
def __init__(
|
|
119
|
-
self,
|
|
120
|
-
n_neighbors=5,
|
|
121
|
-
*,
|
|
122
|
-
weights="uniform",
|
|
123
|
-
algorithm="auto",
|
|
124
|
-
leaf_size=30,
|
|
125
|
-
p=2,
|
|
126
|
-
metric="minkowski",
|
|
127
|
-
metric_params=None,
|
|
128
|
-
n_jobs=None,
|
|
129
|
-
**kwargs,
|
|
130
|
-
):
|
|
131
|
-
super().__init__(
|
|
132
|
-
n_neighbors=n_neighbors,
|
|
133
|
-
algorithm=algorithm,
|
|
134
|
-
leaf_size=leaf_size,
|
|
135
|
-
metric=metric,
|
|
136
|
-
p=p,
|
|
137
|
-
metric_params=metric_params,
|
|
138
|
-
n_jobs=n_jobs,
|
|
139
|
-
**kwargs,
|
|
140
|
-
)
|
|
141
|
-
self.weights = _check_weights(weights)
|
|
34
|
+
validate_data = _sklearn_KNeighborsClassifier._validate_data
|
|
142
35
|
|
|
143
36
|
|
|
144
|
-
@control_n_jobs(
|
|
145
|
-
|
|
37
|
+
@control_n_jobs(
|
|
38
|
+
decorated_methods=["fit", "predict", "predict_proba", "kneighbors", "score"]
|
|
39
|
+
)
|
|
40
|
+
class KNeighborsClassifier(KNeighborsDispatchingBase, _sklearn_KNeighborsClassifier):
|
|
41
|
+
__doc__ = _sklearn_KNeighborsClassifier.__doc__
|
|
146
42
|
if sklearn_check_version("1.2"):
|
|
147
|
-
_parameter_constraints: dict = {
|
|
43
|
+
_parameter_constraints: dict = {
|
|
44
|
+
**_sklearn_KNeighborsClassifier._parameter_constraints
|
|
45
|
+
}
|
|
148
46
|
|
|
149
47
|
if sklearn_check_version("1.0"):
|
|
150
48
|
|
|
@@ -200,13 +98,12 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
200
98
|
)
|
|
201
99
|
|
|
202
100
|
def fit(self, X, y):
|
|
203
|
-
self._fit_validation(X, y)
|
|
204
101
|
dispatch(
|
|
205
102
|
self,
|
|
206
103
|
"fit",
|
|
207
104
|
{
|
|
208
105
|
"onedal": self.__class__._onedal_fit,
|
|
209
|
-
"sklearn":
|
|
106
|
+
"sklearn": _sklearn_KNeighborsClassifier.fit,
|
|
210
107
|
},
|
|
211
108
|
X,
|
|
212
109
|
y,
|
|
@@ -223,7 +120,7 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
223
120
|
"predict",
|
|
224
121
|
{
|
|
225
122
|
"onedal": self.__class__._onedal_predict,
|
|
226
|
-
"sklearn":
|
|
123
|
+
"sklearn": _sklearn_KNeighborsClassifier.predict,
|
|
227
124
|
},
|
|
228
125
|
X,
|
|
229
126
|
)
|
|
@@ -238,9 +135,26 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
238
135
|
"predict_proba",
|
|
239
136
|
{
|
|
240
137
|
"onedal": self.__class__._onedal_predict_proba,
|
|
241
|
-
"sklearn":
|
|
138
|
+
"sklearn": _sklearn_KNeighborsClassifier.predict_proba,
|
|
139
|
+
},
|
|
140
|
+
X,
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
@wrap_output_data
|
|
144
|
+
def score(self, X, y, sample_weight=None):
|
|
145
|
+
check_is_fitted(self)
|
|
146
|
+
if sklearn_check_version("1.0"):
|
|
147
|
+
self._check_feature_names(X, reset=False)
|
|
148
|
+
return dispatch(
|
|
149
|
+
self,
|
|
150
|
+
"score",
|
|
151
|
+
{
|
|
152
|
+
"onedal": self.__class__._onedal_score,
|
|
153
|
+
"sklearn": _sklearn_KNeighborsClassifier.score,
|
|
242
154
|
},
|
|
243
155
|
X,
|
|
156
|
+
y,
|
|
157
|
+
sample_weight=sample_weight,
|
|
244
158
|
)
|
|
245
159
|
|
|
246
160
|
@wrap_output_data
|
|
@@ -253,39 +167,13 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
253
167
|
"kneighbors",
|
|
254
168
|
{
|
|
255
169
|
"onedal": self.__class__._onedal_kneighbors,
|
|
256
|
-
"sklearn":
|
|
170
|
+
"sklearn": _sklearn_KNeighborsClassifier.kneighbors,
|
|
257
171
|
},
|
|
258
172
|
X,
|
|
259
173
|
n_neighbors=n_neighbors,
|
|
260
174
|
return_distance=return_distance,
|
|
261
175
|
)
|
|
262
176
|
|
|
263
|
-
@wrap_output_data
|
|
264
|
-
def radius_neighbors(
|
|
265
|
-
self, X=None, radius=None, return_distance=True, sort_results=False
|
|
266
|
-
):
|
|
267
|
-
_onedal_estimator = getattr(self, "_onedal_estimator", None)
|
|
268
|
-
|
|
269
|
-
if (
|
|
270
|
-
_onedal_estimator is not None
|
|
271
|
-
or getattr(self, "_tree", 0) is None
|
|
272
|
-
and self._fit_method == "kd_tree"
|
|
273
|
-
):
|
|
274
|
-
if sklearn_check_version("0.24"):
|
|
275
|
-
sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self, "_y", None))
|
|
276
|
-
else:
|
|
277
|
-
sklearn_NearestNeighbors.fit(self, self._fit_X)
|
|
278
|
-
if sklearn_check_version("0.22"):
|
|
279
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
280
|
-
self, X, radius, return_distance, sort_results
|
|
281
|
-
)
|
|
282
|
-
else:
|
|
283
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
284
|
-
self, X, radius, return_distance
|
|
285
|
-
)
|
|
286
|
-
|
|
287
|
-
return result
|
|
288
|
-
|
|
289
177
|
def _onedal_fit(self, X, y, queue=None):
|
|
290
178
|
onedal_params = {
|
|
291
179
|
"n_neighbors": self.n_neighbors,
|
|
@@ -321,6 +209,11 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
321
209
|
X, n_neighbors, return_distance, queue=queue
|
|
322
210
|
)
|
|
323
211
|
|
|
212
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
213
|
+
return accuracy_score(
|
|
214
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
215
|
+
)
|
|
216
|
+
|
|
324
217
|
def _save_attributes(self):
|
|
325
218
|
self.classes_ = self._onedal_estimator.classes_
|
|
326
219
|
self.n_features_in_ = self._onedal_estimator.n_features_in_
|
|
@@ -330,3 +223,9 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
|
|
|
330
223
|
self._fit_method = self._onedal_estimator._fit_method
|
|
331
224
|
self.outputs_2d_ = self._onedal_estimator.outputs_2d_
|
|
332
225
|
self._tree = self._onedal_estimator._tree
|
|
226
|
+
|
|
227
|
+
fit.__doc__ = _sklearn_KNeighborsClassifier.fit.__doc__
|
|
228
|
+
predict.__doc__ = _sklearn_KNeighborsClassifier.predict.__doc__
|
|
229
|
+
predict_proba.__doc__ = _sklearn_KNeighborsClassifier.predict_proba.__doc__
|
|
230
|
+
score.__doc__ = _sklearn_KNeighborsClassifier.score.__doc__
|
|
231
|
+
kneighbors.__doc__ = _sklearn_KNeighborsClassifier.kneighbors.__doc__
|
|
@@ -14,133 +14,32 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
# ==============================================================================
|
|
16
16
|
|
|
17
|
-
import
|
|
18
|
-
|
|
19
|
-
from sklearn.neighbors._ball_tree import BallTree
|
|
20
|
-
from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
|
|
21
|
-
from sklearn.neighbors._kd_tree import KDTree
|
|
22
|
-
|
|
23
|
-
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
24
|
-
from daal4py.sklearn._utils import sklearn_check_version
|
|
25
|
-
|
|
26
|
-
if not sklearn_check_version("1.2"):
|
|
27
|
-
from sklearn.neighbors._base import _check_weights
|
|
28
|
-
|
|
29
|
-
import numpy as np
|
|
30
|
-
from sklearn.neighbors._base import VALID_METRICS
|
|
17
|
+
from sklearn.metrics import r2_score
|
|
31
18
|
from sklearn.neighbors._regression import (
|
|
32
|
-
KNeighborsRegressor as
|
|
19
|
+
KNeighborsRegressor as _sklearn_KNeighborsRegressor,
|
|
33
20
|
)
|
|
34
|
-
from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestNeighbors
|
|
35
21
|
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
|
|
36
22
|
|
|
23
|
+
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
24
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
37
25
|
from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor
|
|
38
|
-
from onedal.utils import _check_array, _num_features, _num_samples
|
|
39
26
|
|
|
40
27
|
from .._device_offload import dispatch, wrap_output_data
|
|
41
28
|
from .common import KNeighborsDispatchingBase
|
|
42
29
|
|
|
43
|
-
if sklearn_check_version("
|
|
44
|
-
|
|
45
|
-
class KNeighborsRegressor_(sklearn_KNeighborsRegressor):
|
|
46
|
-
if sklearn_check_version("1.2"):
|
|
47
|
-
_parameter_constraints: dict = {
|
|
48
|
-
**sklearn_KNeighborsRegressor._parameter_constraints
|
|
49
|
-
}
|
|
50
|
-
|
|
51
|
-
@_deprecate_positional_args
|
|
52
|
-
def __init__(
|
|
53
|
-
self,
|
|
54
|
-
n_neighbors=5,
|
|
55
|
-
*,
|
|
56
|
-
weights="uniform",
|
|
57
|
-
algorithm="auto",
|
|
58
|
-
leaf_size=30,
|
|
59
|
-
p=2,
|
|
60
|
-
metric="minkowski",
|
|
61
|
-
metric_params=None,
|
|
62
|
-
n_jobs=None,
|
|
63
|
-
**kwargs,
|
|
64
|
-
):
|
|
65
|
-
super().__init__(
|
|
66
|
-
n_neighbors=n_neighbors,
|
|
67
|
-
algorithm=algorithm,
|
|
68
|
-
leaf_size=leaf_size,
|
|
69
|
-
metric=metric,
|
|
70
|
-
p=p,
|
|
71
|
-
metric_params=metric_params,
|
|
72
|
-
n_jobs=n_jobs,
|
|
73
|
-
**kwargs,
|
|
74
|
-
)
|
|
75
|
-
self.weights = (
|
|
76
|
-
weights if sklearn_check_version("1.0") else _check_weights(weights)
|
|
77
|
-
)
|
|
78
|
-
|
|
79
|
-
elif sklearn_check_version("0.22"):
|
|
80
|
-
from sklearn.neighbors._base import SupervisedFloatMixin as BaseSupervisedFloatMixin
|
|
81
|
-
|
|
82
|
-
class KNeighborsRegressor_(sklearn_KNeighborsRegressor, BaseSupervisedFloatMixin):
|
|
83
|
-
@_deprecate_positional_args
|
|
84
|
-
def __init__(
|
|
85
|
-
self,
|
|
86
|
-
n_neighbors=5,
|
|
87
|
-
*,
|
|
88
|
-
weights="uniform",
|
|
89
|
-
algorithm="auto",
|
|
90
|
-
leaf_size=30,
|
|
91
|
-
p=2,
|
|
92
|
-
metric="minkowski",
|
|
93
|
-
metric_params=None,
|
|
94
|
-
n_jobs=None,
|
|
95
|
-
**kwargs,
|
|
96
|
-
):
|
|
97
|
-
super().__init__(
|
|
98
|
-
n_neighbors=n_neighbors,
|
|
99
|
-
algorithm=algorithm,
|
|
100
|
-
leaf_size=leaf_size,
|
|
101
|
-
metric=metric,
|
|
102
|
-
p=p,
|
|
103
|
-
metric_params=metric_params,
|
|
104
|
-
n_jobs=n_jobs,
|
|
105
|
-
**kwargs,
|
|
106
|
-
)
|
|
107
|
-
self.weights = _check_weights(weights)
|
|
108
|
-
|
|
30
|
+
if sklearn_check_version("1.6"):
|
|
31
|
+
from sklearn.utils.validation import validate_data
|
|
109
32
|
else:
|
|
110
|
-
|
|
33
|
+
validate_data = _sklearn_KNeighborsRegressor._validate_data
|
|
111
34
|
|
|
112
|
-
class KNeighborsRegressor_(sklearn_KNeighborsRegressor, BaseSupervisedFloatMixin):
|
|
113
|
-
@_deprecate_positional_args
|
|
114
|
-
def __init__(
|
|
115
|
-
self,
|
|
116
|
-
n_neighbors=5,
|
|
117
|
-
*,
|
|
118
|
-
weights="uniform",
|
|
119
|
-
algorithm="auto",
|
|
120
|
-
leaf_size=30,
|
|
121
|
-
p=2,
|
|
122
|
-
metric="minkowski",
|
|
123
|
-
metric_params=None,
|
|
124
|
-
n_jobs=None,
|
|
125
|
-
**kwargs,
|
|
126
|
-
):
|
|
127
|
-
super().__init__(
|
|
128
|
-
n_neighbors=n_neighbors,
|
|
129
|
-
algorithm=algorithm,
|
|
130
|
-
leaf_size=leaf_size,
|
|
131
|
-
metric=metric,
|
|
132
|
-
p=p,
|
|
133
|
-
metric_params=metric_params,
|
|
134
|
-
n_jobs=n_jobs,
|
|
135
|
-
**kwargs,
|
|
136
|
-
)
|
|
137
|
-
self.weights = _check_weights(weights)
|
|
138
35
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
36
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "kneighbors", "score"])
|
|
37
|
+
class KNeighborsRegressor(KNeighborsDispatchingBase, _sklearn_KNeighborsRegressor):
|
|
38
|
+
__doc__ = _sklearn_KNeighborsRegressor.__doc__
|
|
142
39
|
if sklearn_check_version("1.2"):
|
|
143
|
-
_parameter_constraints: dict = {
|
|
40
|
+
_parameter_constraints: dict = {
|
|
41
|
+
**_sklearn_KNeighborsRegressor._parameter_constraints
|
|
42
|
+
}
|
|
144
43
|
|
|
145
44
|
if sklearn_check_version("1.0"):
|
|
146
45
|
|
|
@@ -196,13 +95,12 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
|
|
|
196
95
|
)
|
|
197
96
|
|
|
198
97
|
def fit(self, X, y):
|
|
199
|
-
self._fit_validation(X, y)
|
|
200
98
|
dispatch(
|
|
201
99
|
self,
|
|
202
100
|
"fit",
|
|
203
101
|
{
|
|
204
102
|
"onedal": self.__class__._onedal_fit,
|
|
205
|
-
"sklearn":
|
|
103
|
+
"sklearn": _sklearn_KNeighborsRegressor.fit,
|
|
206
104
|
},
|
|
207
105
|
X,
|
|
208
106
|
y,
|
|
@@ -219,11 +117,28 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
|
|
|
219
117
|
"predict",
|
|
220
118
|
{
|
|
221
119
|
"onedal": self.__class__._onedal_predict,
|
|
222
|
-
"sklearn":
|
|
120
|
+
"sklearn": _sklearn_KNeighborsRegressor.predict,
|
|
223
121
|
},
|
|
224
122
|
X,
|
|
225
123
|
)
|
|
226
124
|
|
|
125
|
+
@wrap_output_data
|
|
126
|
+
def score(self, X, y, sample_weight=None):
|
|
127
|
+
check_is_fitted(self)
|
|
128
|
+
if sklearn_check_version("1.0"):
|
|
129
|
+
self._check_feature_names(X, reset=False)
|
|
130
|
+
return dispatch(
|
|
131
|
+
self,
|
|
132
|
+
"score",
|
|
133
|
+
{
|
|
134
|
+
"onedal": self.__class__._onedal_score,
|
|
135
|
+
"sklearn": _sklearn_KNeighborsRegressor.score,
|
|
136
|
+
},
|
|
137
|
+
X,
|
|
138
|
+
y,
|
|
139
|
+
sample_weight=sample_weight,
|
|
140
|
+
)
|
|
141
|
+
|
|
227
142
|
@wrap_output_data
|
|
228
143
|
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
|
|
229
144
|
check_is_fitted(self)
|
|
@@ -234,39 +149,13 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
|
|
|
234
149
|
"kneighbors",
|
|
235
150
|
{
|
|
236
151
|
"onedal": self.__class__._onedal_kneighbors,
|
|
237
|
-
"sklearn":
|
|
152
|
+
"sklearn": _sklearn_KNeighborsRegressor.kneighbors,
|
|
238
153
|
},
|
|
239
154
|
X,
|
|
240
155
|
n_neighbors=n_neighbors,
|
|
241
156
|
return_distance=return_distance,
|
|
242
157
|
)
|
|
243
158
|
|
|
244
|
-
@wrap_output_data
|
|
245
|
-
def radius_neighbors(
|
|
246
|
-
self, X=None, radius=None, return_distance=True, sort_results=False
|
|
247
|
-
):
|
|
248
|
-
_onedal_estimator = getattr(self, "_onedal_estimator", None)
|
|
249
|
-
|
|
250
|
-
if (
|
|
251
|
-
_onedal_estimator is not None
|
|
252
|
-
or getattr(self, "_tree", 0) is None
|
|
253
|
-
and self._fit_method == "kd_tree"
|
|
254
|
-
):
|
|
255
|
-
if sklearn_check_version("0.24"):
|
|
256
|
-
sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self, "_y", None))
|
|
257
|
-
else:
|
|
258
|
-
sklearn_NearestNeighbors.fit(self, self._fit_X)
|
|
259
|
-
if sklearn_check_version("0.22"):
|
|
260
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
261
|
-
self, X, radius, return_distance, sort_results
|
|
262
|
-
)
|
|
263
|
-
else:
|
|
264
|
-
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
265
|
-
self, X, radius, return_distance
|
|
266
|
-
)
|
|
267
|
-
|
|
268
|
-
return result
|
|
269
|
-
|
|
270
159
|
def _onedal_fit(self, X, y, queue=None):
|
|
271
160
|
onedal_params = {
|
|
272
161
|
"n_neighbors": self.n_neighbors,
|
|
@@ -299,6 +188,11 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
|
|
|
299
188
|
X, n_neighbors, return_distance, queue=queue
|
|
300
189
|
)
|
|
301
190
|
|
|
191
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
192
|
+
return r2_score(
|
|
193
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
194
|
+
)
|
|
195
|
+
|
|
302
196
|
def _save_attributes(self):
|
|
303
197
|
self.n_features_in_ = self._onedal_estimator.n_features_in_
|
|
304
198
|
self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_
|
|
@@ -306,3 +200,8 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
|
|
|
306
200
|
self._y = self._onedal_estimator._y
|
|
307
201
|
self._fit_method = self._onedal_estimator._fit_method
|
|
308
202
|
self._tree = self._onedal_estimator._tree
|
|
203
|
+
|
|
204
|
+
fit.__doc__ = _sklearn_KNeighborsRegressor.__doc__
|
|
205
|
+
predict.__doc__ = _sklearn_KNeighborsRegressor.predict.__doc__
|
|
206
|
+
kneighbors.__doc__ = _sklearn_KNeighborsRegressor.kneighbors.__doc__
|
|
207
|
+
score.__doc__ = _sklearn_KNeighborsRegressor.score.__doc__
|