scikit-learn-intelex 2024.1.0__py311-none-win_amd64.whl → 2025.1.0__py311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex has been flagged as potentially problematic; see the package's registry page for details.
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/_daal4py.cp311-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/doc/third-party-programs.txt +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/model_builders.py +377 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp311-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +248 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +597 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition}/__init__.py +3 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +524 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1397 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn}/linear_model/__init__.py +29 -29
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +272 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +325 -0
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +4 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +405 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +236 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/_models_info.py +13 -22
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/test_patching.py +10 -42
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/utils/_launch_algorithms.py +4 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +503 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +139 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +74 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +734 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +75 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +693 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/__init__.py +83 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_config.py +54 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_device_offload.py +222 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp311-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp311-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +110 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +564 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +115 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_base.py +38 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_policy.py +59 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_spmd_policy.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/tests/test_policy.py +76 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +146 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +122 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +154 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +126 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +414 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +186 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +727 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +258 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +249 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +250 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +767 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +25 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/svm.py +556 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +351 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +176 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/test_common.py +57 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +102 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/__init__.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +81 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_dpep_helpers.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/validation.py +440 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__init__.py +10 -7
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_config.py +22 -16
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +126 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_utils.py +27 -4
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +230 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +19 -10
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +395 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +398 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +425 -0
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +25 -9
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +241 -60
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +250 -188
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +39 -21
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +482 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +425 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +341 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex}/linear_model/logistic_regression.py +194 -133
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +7 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +134 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +4 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +5 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +1 -1
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +236 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +53 -6
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +51 -155
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +46 -149
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +55 -100
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +16 -18
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +1 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +138 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +233 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model}/__init__.py +19 -19
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +1 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +11 -12
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -1
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +14 -18
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +339 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +172 -78
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +74 -70
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +170 -77
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +66 -66
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +390 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +123 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +379 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +276 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +108 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +385 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +321 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +44 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +371 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/METADATA +231 -230
- scikit_learn_intelex-2025.1.0.dist-info/RECORD +257 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/WHEEL +1 -1
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -223
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -388
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -17
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -82
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -28
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -436
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -376
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -98
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -376
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -188
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -225
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -227
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
- scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,564 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import warnings
|
|
19
|
+
from abc import ABC
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from daal4py.sklearn._utils import daal_check_version, get_dtype
|
|
24
|
+
from onedal import _backend
|
|
25
|
+
from onedal.basic_statistics import BasicStatistics
|
|
26
|
+
|
|
27
|
+
if daal_check_version((2023, "P", 200)):
|
|
28
|
+
from .kmeans_init import KMeansInit
|
|
29
|
+
|
|
30
|
+
from sklearn.cluster._kmeans import _kmeans_plusplus
|
|
31
|
+
from sklearn.exceptions import ConvergenceWarning
|
|
32
|
+
from sklearn.metrics.pairwise import euclidean_distances
|
|
33
|
+
from sklearn.utils import check_random_state
|
|
34
|
+
|
|
35
|
+
from ..common._base import BaseEstimator as onedal_BaseEstimator
|
|
36
|
+
from ..common._mixin import ClusterMixin, TransformerMixin
|
|
37
|
+
from ..datatypes import _convert_to_supported, from_table, to_table
|
|
38
|
+
from ..utils import _check_array, _is_arraylike_not_scalar, _is_csr
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class _BaseKMeans(onedal_BaseEstimator, TransformerMixin, ClusterMixin, ABC):
    """Shared oneDAL-backed implementation of Lloyd K-Means.

    Concrete subclasses (e.g. ``KMeans``) expose the public ``fit`` /
    ``predict`` / ``score`` methods and hand in the oneDAL backend module.
    This class validates input, initializes centroids (through oneDAL or
    scikit-learn), runs the ``n_init`` restarts and converts the oneDAL
    results into scikit-learn style fitted attributes.

    Note: ``_check_params_vs_input`` reads ``self.algorithm``, which this base
    class does not set; subclasses are expected to define it.
    """

    def __init__(
        self,
        n_clusters,
        *,
        init,
        n_init,
        max_iter,
        tol,
        verbose,
        random_state,
        n_local_trials=None,
    ):
        self.n_clusters = n_clusters
        self.init = init
        self.max_iter = max_iter
        self.tol = tol
        self.n_init = n_init
        self.verbose = verbose
        self.random_state = random_state
        self.n_local_trials = n_local_trials

    def _validate_center_shape(self, X, centers):
        """Check if centers is compatible with X and n_clusters.

        Raises
        ------
        ValueError
            If ``centers`` does not have ``n_clusters`` rows or its column
            count differs from the number of features of ``X``.
        """
        if centers.shape[0] != self.n_clusters:
            raise ValueError(
                f"The shape of the initial centers {centers.shape} does not "
                f"match the number of clusters {self.n_clusters}."
            )
        if centers.shape[1] != X.shape[1]:
            raise ValueError(
                f"The shape of the initial centers {centers.shape} does not "
                f"match the number of features of the data {X.shape[1]}."
            )

    def _get_kmeans_init(self, cluster_count, seed, algorithm):
        """Create the oneDAL KMeansInit estimator (overridable, e.g. for SPMD)."""
        return KMeansInit(cluster_count=cluster_count, seed=seed, algorithm=algorithm)

    # Get appropriate backend (required for SPMD)
    def _get_basic_statistics_backend(self, result_options):
        """Create the BasicStatistics estimator used by ``_tolerance``."""
        return BasicStatistics(result_options)

    def _tolerance(self, X_table, rtol, is_csr, policy, dtype):
        """Compute absolute tolerance from the relative tolerance.

        The absolute tolerance is ``rtol`` times the mean of the variances
        returned by oneDAL BasicStatistics for ``X_table`` (presumably one
        variance per feature, mirroring scikit-learn's
        ``np.mean(np.var(X, axis=0)) * tol`` -- confirm against the backend).
        A zero ``rtol`` short-circuits without touching the data.
        """
        if rtol == 0.0:
            return rtol
        # BasicStatistics expects a weights table; an empty table means "no weights".
        dummy = to_table(None)

        bs = self._get_basic_statistics_backend("variance")

        res = bs._compute_raw(X_table, dummy, policy, dtype, is_csr)
        mean_var = from_table(res["variance"]).mean()

        return mean_var * rtol

    def _check_params_vs_input(
        self, X_table, is_csr, policy, default_n_init=10, dtype=np.float32
    ):
        """Validate hyper-parameters against the input and derive run settings.

        Sets ``self._tol`` (absolute tolerance, see ``_tolerance``) and
        ``self._n_init`` (the resolved number of initializations).

        Raises
        ------
        ValueError
            If the data has fewer samples than ``n_clusters``.
        """
        # n_clusters
        if X_table.shape[0] < self.n_clusters:
            raise ValueError(
                f"n_samples={X_table.shape[0]} should be >= n_clusters={self.n_clusters}."
            )

        # tol
        self._tol = self._tolerance(X_table, self.tol, is_csr, policy, dtype)

        # n-init
        # TODO(1.4): Remove
        self._n_init = self.n_init
        if self._n_init == "warn":
            warnings.warn(
                (
                    "The default value of `n_init` will change from "
                    f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`"
                    " explicitly to suppress the warning"
                ),
                FutureWarning,
                stacklevel=2,
            )
            self._n_init = default_n_init
        if self._n_init == "auto":
            if isinstance(self.init, str) and self.init == "k-means++":
                self._n_init = 1
            elif isinstance(self.init, str) and self.init == "random":
                self._n_init = default_n_init
            elif callable(self.init):
                self._n_init = default_n_init
            else:  # array-like
                self._n_init = 1

        # Explicit centers are deterministic, so repeating the run is pointless.
        if _is_arraylike_not_scalar(self.init) and self._n_init != 1:
            warnings.warn(
                (
                    "Explicit initial center position passed: performing only"
                    f" one init in {self.__class__.__name__} instead of "
                    f"n_init={self._n_init}."
                ),
                RuntimeWarning,
                stacklevel=2,
            )
            self._n_init = 1
        assert self.algorithm == "lloyd"

    def _get_onedal_params(self, is_csr=False, dtype=np.float32, result_options=None):
        """Build the oneDAL parameter dict for training and inference.

        The threshold is the absolute tolerance computed by the last
        ``_check_params_vs_input`` call when available, otherwise ``self.tol``.
        """
        thr = self._tol if hasattr(self, "_tol") else self.tol
        return {
            "fptype": "float" if dtype == np.float32 else "double",
            "method": "lloyd_csr" if is_csr else "by_default",
            "seed": -1,
            "max_iteration_count": self.max_iter,
            "cluster_count": self.n_clusters,
            "accuracy_threshold": thr,
            "result_options": "" if result_options is None else result_options,
        }

    def _init_centroids_onedal(
        self,
        X_table,
        init,
        random_seed,
        policy,
        is_csr,
        dtype=np.float32,
        n_centroids=None,
    ):
        """Produce the initial centroid table using oneDAL KMeansInit.

        ``init`` may be ``"k-means++"``, ``"random"`` or an array-like of
        explicit centers (dense or CSR; CSR is densified). ``n_centroids``
        overrides ``self.n_clusters`` when given.

        Raises
        ------
        TypeError
            If ``init`` is none of the supported kinds.
        """
        n_clusters = self.n_clusters if n_centroids is None else n_centroids
        # Use host policy for KMeans init, only for csr data
        # as oneDAL KMeansInit for CSR data is not implemented on GPU
        if is_csr:
            init_policy = self._get_policy(None, None)
            logging.getLogger("sklearnex").info("Running Sparse KMeansInit on CPU")
        else:
            init_policy = policy

        if isinstance(init, str) and init == "k-means++":
            if not is_csr:
                alg = self._get_kmeans_init(
                    cluster_count=n_clusters,
                    seed=random_seed,
                    algorithm="plus_plus_dense",
                )
            else:
                alg = self._get_kmeans_init(
                    cluster_count=n_clusters, seed=random_seed, algorithm="plus_plus_csr"
                )
            centers_table = alg.compute_raw(X_table, init_policy, dtype)
        elif isinstance(init, str) and init == "random":
            if not is_csr:
                alg = self._get_kmeans_init(
                    cluster_count=n_clusters, seed=random_seed, algorithm="random_dense"
                )
            else:
                alg = self._get_kmeans_init(
                    cluster_count=n_clusters, seed=random_seed, algorithm="random_csr"
                )
            centers_table = alg.compute_raw(X_table, init_policy, dtype)
        elif _is_arraylike_not_scalar(init):
            if _is_csr(init):
                # oneDAL KMeans only supports Dense Centroids
                centers = init.toarray()
            else:
                centers = np.asarray(init)
            assert centers.shape[0] == n_clusters
            assert centers.shape[1] == X_table.column_count
            # KMeans is implemented on both CPU and GPU for Dense and CSR data
            # The original policy can be used here
            centers = _convert_to_supported(policy, centers)
            centers_table = to_table(centers)
        else:
            raise TypeError("Unsupported type of the `init` value")

        return centers_table

    def _init_centroids_sklearn(self, X, init, random_state, policy, dtype=np.float32):
        """Produce the initial centroid table using scikit-learn routines.

        Used for oneDAL versions < 2023.2 or when ``init`` is a callable.
        ``random_state`` is a ``RandomState`` instance shared across restarts.

        Raises
        ------
        ValueError
            If ``init`` is none of the supported kinds.
        """
        # For oneDAL versions < 2023.2 or callable init,
        # using the scikit-learn implementation
        logging.getLogger("sklearnex").info("Computing KMeansInit with Stock sklearn")
        n_samples = X.shape[0]

        if isinstance(init, str) and init == "k-means++":
            # The private ``_kmeans_plusplus`` requires ``x_squared_norms`` (and,
            # on scikit-learn >= 1.3, ``sample_weight``) positionally; the public
            # helper computes those defaults itself and is stable across versions.
            from sklearn.cluster import kmeans_plusplus

            centers, _ = kmeans_plusplus(
                X,
                self.n_clusters,
                random_state=random_state,
            )
        elif isinstance(init, str) and init == "random":
            seeds = random_state.choice(n_samples, size=self.n_clusters, replace=False)
            centers = X[seeds]
        elif callable(init):
            cc_arr = init(X, self.n_clusters, random_state)
            cc_arr = np.ascontiguousarray(cc_arr, dtype=dtype)
            self._validate_center_shape(X, cc_arr)
            centers = cc_arr
        elif _is_arraylike_not_scalar(init):
            centers = init
        else:
            raise ValueError(
                f"init should be either 'k-means++', 'random', a ndarray or a "
                f"callable, got '{ init }' instead."
            )

        # Row selection on CSR data (or a CSR ``init``) yields CSR centers, but
        # oneDAL KMeans only accepts dense centroids (see _init_centroids_onedal).
        if _is_csr(centers):
            centers = centers.toarray()

        centers = _convert_to_supported(policy, centers)
        return to_table(centers)

    def _fit_backend(
        self, X_table, centroids_table, module, policy, dtype=np.float32, is_csr=False
    ):
        """Run one oneDAL training pass starting from ``centroids_table``.

        Returns
        -------
        tuple
            ``(responses, objective_function_value, model, iteration_count)``
            from the oneDAL training result; ``responses`` is the label table.
        """
        params = self._get_onedal_params(is_csr, dtype)

        meta = _backend.get_table_metadata(X_table)
        assert meta.get_npy_dtype(0) == dtype

        result = module.train(policy, params, X_table, centroids_table)

        return (
            result.responses,
            result.objective_function_value,
            result.model,
            result.iteration_count,
        )

    def _fit(self, X, module, queue=None):
        """Fit the model with ``self._n_init`` restarts, keeping the best run.

        A run replaces the current best when its inertia is lower and its
        clustering differs from the best one. Sets ``model_``, ``n_iter_``,
        ``inertia_``, ``labels_`` and ``n_features_in_``; returns ``self``.
        """
        policy = self._get_policy(queue, X)
        X = _check_array(
            X, dtype=[np.float64, np.float32], accept_sparse="csr", force_all_finite=False
        )
        # Decide on the CSR code path only after validation, which may have
        # converted another sparse format to CSR.
        is_csr = _is_csr(X)
        X = _convert_to_supported(policy, X)
        dtype = get_dtype(X)
        X_table = to_table(X)

        self._check_params_vs_input(X_table, is_csr, policy, dtype=dtype)

        self.n_features_in_ = X_table.column_count

        best_model, best_n_iter = None, None
        best_inertia, best_labels = None, None

        def is_better_iteration(inertia, labels):
            # Reads the enclosing ``best_*`` variables at call time.
            if best_inertia is None:
                return True
            else:
                mod = self._get_backend("kmeans_common", None, None)
                better_inertia = inertia < best_inertia
                same_clusters = mod._is_same_clustering(
                    labels, best_labels, self.n_clusters
                )
                return better_inertia and not same_clusters

        random_state = check_random_state(self.random_state)

        init = self.init
        init_is_array_like = _is_arraylike_not_scalar(init)
        if init_is_array_like:
            init = _check_array(
                init, dtype=dtype, accept_sparse="csr", copy=True, order="C"
            )
            self._validate_center_shape(X, init)

        use_onedal_init = daal_check_version((2023, "P", 200)) and not callable(self.init)

        for _ in range(self._n_init):
            if use_onedal_init:
                random_seed = random_state.randint(np.iinfo("i").max)
                centroids_table = self._init_centroids_onedal(
                    X_table, init, random_seed, policy, is_csr, dtype=dtype
                )
            else:
                centroids_table = self._init_centroids_sklearn(
                    X, init, random_state, policy, dtype=dtype
                )

            if self.verbose:
                print("Initialization complete")

            labels, inertia, model, n_iter = self._fit_backend(
                X_table, centroids_table, module, policy, dtype, is_csr
            )

            if self.verbose:
                print("Iteration {}, inertia {}.".format(n_iter, inertia))

            if is_better_iteration(inertia, labels):
                best_model, best_n_iter = model, n_iter
                best_inertia, best_labels = inertia, labels

        # Types without conversion
        self.model_ = best_model
        # Drop centroids cached by the ``cluster_centers_`` getter (or assigned
        # through its setter) so they are re-read from the newly trained model.
        if hasattr(self, "_cluster_centers_"):
            del self._cluster_centers_

        # Simple types
        self.n_iter_ = best_n_iter
        self.inertia_ = best_inertia

        # Complex type conversion
        self.labels_ = from_table(best_labels).ravel()

        distinct_clusters = len(np.unique(self.labels_))
        if distinct_clusters < self.n_clusters:
            warnings.warn(
                "Number of distinct clusters ({}) found smaller than "
                "n_clusters ({}). Possibly due to duplicate points "
                "in X.".format(distinct_clusters, self.n_clusters),
                ConvergenceWarning,
                stacklevel=2,
            )

        return self

    @property
    def cluster_centers_(self):
        """Cluster centers, lazily converted from ``model_.centroids`` and cached.

        Raises ``NameError`` when neither a cached value nor ``model_`` exists.
        """
        if not hasattr(self, "_cluster_centers_"):
            if hasattr(self, "model_"):
                centroids = self.model_.centroids
                self._cluster_centers_ = from_table(centroids)
            else:
                raise NameError("This model have not been trained")
        return self._cluster_centers_

    @cluster_centers_.setter
    def cluster_centers_(self, cluster_centers):
        # Assigning centers builds a fresh oneDAL model around them and resets
        # the fit statistics; each center is labelled by its own row index.
        self._cluster_centers_ = np.asarray(cluster_centers)

        self.n_iter_ = 0
        self.inertia_ = 0

        self.model_ = self._get_backend("kmeans", "clustering", "model")
        self.model_.centroids = to_table(self._cluster_centers_)
        self.n_features_in_ = self.model_.centroids.column_count
        self.labels_ = np.arange(self.model_.centroids.row_count)

        return self

    @cluster_centers_.deleter
    def cluster_centers_(self):
        del self._cluster_centers_

    def _predict(self, X, module, queue=None, result_options=None):
        """Run oneDAL inference with the fitted ``model_``.

        Returns the flattened label array, or the negated objective function
        value when ``result_options == "compute_exact_objective_function"``.
        """
        policy = self._get_policy(queue, X)
        # Validate exactly as ``_fit`` does so lists, non-float dtypes and
        # non-CSR sparse inputs reach oneDAL in a supported form.
        X = _check_array(
            X, dtype=[np.float64, np.float32], accept_sparse="csr", force_all_finite=False
        )
        is_csr = _is_csr(X)
        X = _convert_to_supported(policy, X)
        X_table, dtype = to_table(X), X.dtype
        params = self._get_onedal_params(is_csr, dtype, result_options)

        result = module.infer(policy, params, self.model_, X_table)

        if (
            result_options == "compute_exact_objective_function"
        ):  # This is only set for score function
            return result.objective_function_value * (-1)
        else:
            return from_table(result.responses).ravel()

    def _score(self, X, module, queue=None):
        """Opposite of the K-means objective on ``X``, computed with ``module``."""
        result_options = "compute_exact_objective_function"

        # Use the backend the caller supplied, like ``_fit`` and ``_predict``.
        return self._predict(X, module, queue, result_options)

    def _transform(self, X):
        """Euclidean distances from each sample of ``X`` to every cluster center."""
        return euclidean_distances(X, self.cluster_centers_)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
class KMeans(_BaseKMeans):
    """oneDAL K-Means estimator using Lloyd iterations.

    Every public method routes to the ``("kmeans", "clustering")`` backend
    module obtained from ``_get_backend``. ``algorithm`` must be ``"lloyd"``;
    ``copy_x`` is stored for API compatibility.
    """

    def __init__(
        self,
        n_clusters=8,
        *,
        init="k-means++",
        n_init="auto",
        max_iter=300,
        tol=1e-4,
        verbose=0,
        random_state=None,
        copy_x=True,
        algorithm="lloyd",
    ):
        super().__init__(
            n_clusters=n_clusters,
            random_state=random_state,
            init=init,
            n_init=n_init,
            verbose=verbose,
            max_iter=max_iter,
            tol=tol,
        )

        self.copy_x = copy_x
        self.algorithm = algorithm
        # The oneDAL backend only provides the Lloyd algorithm.
        assert self.algorithm == "lloyd"

    def fit(self, X, y=None, queue=None):
        """Compute k-means clustering on ``X``.

        ``y`` is ignored; ``queue`` is forwarded to policy selection.
        Returns the fitted estimator.
        """
        clustering_backend = self._get_backend("kmeans", "clustering", None)
        return super()._fit(X, clustering_backend, queue)

    def predict(self, X, queue=None):
        """Predict the closest cluster each sample in X belongs to.

        In the vector quantization literature, `cluster_centers_` is called
        the code book and each value returned by `predict` is the index of
        the closest code in the code book.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to assign to clusters.

        queue : optional
            Forwarded to the oneDAL policy selection.

        Returns
        -------
        labels : ndarray of shape (n_samples,)
            Index of the cluster each sample belongs to.
        """
        clustering_backend = self._get_backend("kmeans", "clustering", None)
        return super()._predict(X, clustering_backend, queue)

    def fit_predict(self, X, y=None, queue=None):
        """Fit the model, then return the cluster index of every training sample.

        Same result as calling ``fit(X)`` and reading ``labels_``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.

        y : Ignored
            Accepted only for API consistency.

        queue : optional
            Forwarded to the oneDAL policy selection.

        Returns
        -------
        labels : ndarray of shape (n_samples,)
            Index of the cluster each sample belongs to.
        """
        fitted = self.fit(X, queue=queue)
        return fitted.labels_

    def fit_transform(self, X, y=None, queue=None):
        """Fit the model, then map ``X`` into cluster-distance space.

        Equivalent to ``fit(X).transform(X)``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.

        y : Ignored
            Accepted only for API consistency.

        queue : optional
            Forwarded to the oneDAL policy selection during fitting.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_clusters)
            Distance from every sample to every cluster center.
        """
        fitted = self.fit(X, queue=queue)
        return fitted._transform(X)

    def transform(self, X):
        """Map ``X`` into cluster-distance space.

        Each output column is the Euclidean distance to one cluster center.
        The result is typically dense even when ``X`` is sparse.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to transform.

        Returns
        -------
        X_new : ndarray of shape (n_samples, n_clusters)
            Distance from every sample to every cluster center.
        """
        distances = self._transform(X)
        return distances

    def score(self, X, queue=None):
        """Return the negated K-means objective evaluated on ``X``.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Samples to score.

        queue : optional
            Forwarded to the oneDAL policy selection.

        Returns
        -------
        score : float
            Opposite of the value of X on the K-means objective.
        """
        clustering_backend = self._get_backend("kmeans", "clustering", None)
        return super()._score(X, clustering_backend, queue)
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def k_means(
    X,
    n_clusters,
    *,
    init="k-means++",
    n_init="auto",
    max_iter=300,
    verbose=False,
    tol=1e-4,
    random_state=None,
    copy_x=True,
    algorithm="lloyd",
    return_n_iter=False,
    queue=None,
):
    """Functional interface to :class:`KMeans`.

    Fits a ``KMeans`` estimator built from the given hyper-parameters on ``X``
    (``queue`` is forwarded to ``fit``).

    Returns
    -------
    tuple
        ``(cluster_centers_, labels_, inertia_)``, with ``n_iter_`` appended
        as a fourth element when ``return_n_iter`` is true.
    """
    fitted = KMeans(
        n_clusters=n_clusters,
        init=init,
        n_init=n_init,
        max_iter=max_iter,
        verbose=verbose,
        tol=tol,
        random_state=random_state,
        copy_x=copy_x,
        algorithm=algorithm,
    ).fit(X, queue=queue)

    outputs = (fitted.cluster_centers_, fitted.labels_, fitted.inertia_)
    if not return_n_iter:
        return outputs
    return outputs + (fitted.n_iter_,)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
from scipy.sparse import issparse
|
|
19
|
+
from sklearn.utils import check_random_state
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn._utils import daal_check_version, get_dtype
|
|
22
|
+
|
|
23
|
+
from ..common._base import BaseEstimator as onedal_BaseEstimator
|
|
24
|
+
from ..datatypes import _convert_to_supported, from_table, to_table
|
|
25
|
+
from ..utils import _check_array
|
|
26
|
+
|
|
27
|
+
if daal_check_version((2023, "P", 200)):

    class KMeansInit(onedal_BaseEstimator):
        """
        oneDAL centroid initialization for K-Means.

        ``algorithm`` selects the oneDAL method (e.g. ``"plus_plus_dense"``);
        when ``local_trials_count`` is None it defaults to
        ``2 + int(np.log(cluster_count))``.
        """

        def __init__(
            self,
            cluster_count,
            seed=777,
            local_trials_count=None,
            algorithm="plus_plus_dense",
        ):
            self.cluster_count = cluster_count
            self.seed = seed
            self.local_trials_count = (
                2 + int(np.log(cluster_count))
                if local_trials_count is None
                else local_trials_count
            )
            self.algorithm = algorithm

        def _get_onedal_params(self, dtype=np.float32):
            """Assemble the oneDAL parameter dict for the given floating dtype."""
            fptype = "float" if dtype == np.float32 else "double"
            return {
                "fptype": fptype,
                "local_trials_count": self.local_trials_count,
                "method": self.algorithm,
                "seed": self.seed,
                "cluster_count": self.cluster_count,
            }

        def _get_params_and_input(self, X, policy):
            """Validate ``X`` and return ``(params, X_table, dtype)``."""
            checked = _check_array(
                X,
                dtype=[np.float64, np.float32],
                accept_sparse="csr",
                force_all_finite=False,
            )
            supported = _convert_to_supported(policy, checked)
            dtype = get_dtype(supported)
            return (self._get_onedal_params(dtype), to_table(supported), dtype)

        def _compute_raw(self, X_table, module, policy, dtype=np.float32):
            """Run the backend ``compute`` and return the raw centroid table."""
            outcome = module.compute(policy, self._get_onedal_params(dtype), X_table)
            return outcome.centroids

        def _compute(self, X, module, queue):
            """Validate ``X``, compute centroids and return them as an array."""
            policy = self._get_policy(queue, X)
            if issparse(X):
                # oneDAL KMeans Init for sparse data does not have GPU support,
                # so fall back to the host policy.
                policy = self._get_policy(None, None)
            _, X_table, dtype = self._get_params_and_input(X, policy)

            return from_table(self._compute_raw(X_table, module, policy, dtype))

        def compute_raw(self, X_table, policy, dtype=np.float32):
            """Compute centroids for an already converted table; returns a table."""
            init_backend = self._get_backend("kmeans_init", "init", None)
            return self._compute_raw(X_table, init_backend, policy, dtype)

        def compute(self, X, queue=None):
            """Compute initial centroids for ``X``; returns an array."""
            init_backend = self._get_backend("kmeans_init", "init", None)
            return self._compute(X, init_backend, queue)

    def kmeans_plusplus(
        X,
        n_clusters,
        *,
        x_squared_norms=None,
        random_state=None,
        n_local_trials=None,
        queue=None,
    ):
        """k-means++ seeding through oneDAL with a scikit-learn-like signature.

        ``x_squared_norms`` is accepted but not used. The returned indices are
        a placeholder array filled with -1; this path does not report which
        samples were chosen (``_compute_raw`` only reads ``result.centroids``).

        Returns
        -------
        tuple
            ``(centers, indices)``.
        """
        seed = check_random_state(random_state).tomaxint()
        initializer = KMeansInit(
            n_clusters, seed=seed, local_trials_count=n_local_trials
        )
        centers = initializer.compute(X, queue)
        indices = np.full(n_clusters, -1)
        return (centers, indices)
|