scikit-learn-intelex 2024.6.0__py310-none-win_amd64.whl → 2025.0.0__py310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/_daal4py.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/doc/third-party-programs.txt +424 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +19 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/mb/model_builders.py +377 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +242 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +241 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +597 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- {scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn}/decomposition/__init__.py +2 -2
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +524 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1397 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +272 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +325 -0
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- {scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +3 -3
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +405 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +155 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +4 -2
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +503 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +139 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +74 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +734 -0
- {scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/spmd/covariance → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/utils}/__init__.py +5 -3
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +75 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +693 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/__init__.py +83 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/_config.py +53 -0
- {scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal}/_device_offload.py +104 -132
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +107 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +110 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +560 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +115 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/common/_base.py +38 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/common/_policy.py +59 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/common/_spmd_policy.py +30 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +116 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/common/tests/test_policy.py +75 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +125 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +146 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +122 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +19 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +95 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +235 -0
- {scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +204 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +186 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +198 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +720 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +258 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +329 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +249 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +149 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +778 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +25 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +153 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/svm/svm.py +556 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +351 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/tests/test_common.py +41 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +168 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +107 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/utils/__init__.py +49 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +91 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/utils/validation.py +432 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/_config.py +3 -15
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +121 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +140 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +5 -5
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +1 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -1
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +383 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +153 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +68 -17
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +46 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +25 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +113 -9
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +9 -36
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +9 -12
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +2 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +5 -6
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +418 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +2 -34
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +79 -59
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +24 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +13 -10
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +28 -3
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +46 -3
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +21 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +5 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +11 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +45 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +1 -20
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +1 -20
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +31 -7
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +8 -8
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +2 -2
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +19 -17
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +419 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +19 -21
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +1 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +1 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/_utils.py +143 -20
- scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/tests/_utils_spmd.py +198 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +4 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +2 -1
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +12 -4
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +16 -14
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +33 -20
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +1 -2
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +5 -20
- {scikit_learn_intelex-2024.6.0.dist-info → scikit_learn_intelex-2025.0.0.dist-info}/METADATA +3 -2
- scikit_learn_intelex-2025.0.0.dist-info/RECORD +255 -0
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
- scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -303
- scikit_learn_intelex-2024.6.0.dist-info/RECORD +0 -108
- {scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal}/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/conftest.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_common.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +0 -0
- {scikit_learn_intelex-2024.6.0.data → scikit_learn_intelex-2025.0.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.6.0.dist-info → scikit_learn_intelex-2025.0.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.6.0.dist-info → scikit_learn_intelex-2025.0.0.dist-info}/WHEEL +0 -0
- {scikit_learn_intelex-2024.6.0.dist-info → scikit_learn_intelex-2025.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
from scipy.sparse import issparse
|
|
19
|
+
from sklearn.utils import check_random_state
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn._utils import daal_check_version, get_dtype
|
|
22
|
+
|
|
23
|
+
from ..common._base import BaseEstimator as onedal_BaseEstimator
|
|
24
|
+
from ..datatypes import _convert_to_supported, from_table, to_table
|
|
25
|
+
from ..utils import _check_array
|
|
26
|
+
|
|
27
|
+
if daal_check_version((2023, "P", 200)):

    class KMeansInit(onedal_BaseEstimator):
        """
        KMeansInit oneDAL implementation.

        Computes initial K-Means centroids with the oneDAL ``kmeans_init``
        backend.

        Parameters
        ----------
        cluster_count : int
            Number of centroids to compute; passed as ``"cluster_count"``.
        seed : int, default=777
            Seed passed to the backend as ``"seed"``.
        local_trials_count : int or None, default=None
            Passed as ``"local_trials_count"``. ``None`` selects
            ``2 + int(np.log(cluster_count))``.
        algorithm : str, default="plus_plus_dense"
            Backend method name passed as ``"method"``.
        """

        def __init__(
            self,
            cluster_count,
            seed=777,
            local_trials_count=None,
            algorithm="plus_plus_dense",
        ):
            # Resolve the default once instead of assigning the attribute twice.
            if local_trials_count is None:
                local_trials_count = 2 + int(np.log(cluster_count))

            self.cluster_count = cluster_count
            self.seed = seed
            self.local_trials_count = local_trials_count
            self.algorithm = algorithm

        def _get_onedal_params(self, dtype=np.float32):
            """Build the parameter dict passed to the backend ``compute`` call.

            ``fptype`` is ``"float"`` for ``np.float32`` and ``"double"``
            for anything else.
            """
            return {
                "fptype": "float" if dtype == np.float32 else "double",
                "local_trials_count": self.local_trials_count,
                "method": self.algorithm,
                "seed": self.seed,
                "cluster_count": self.cluster_count,
            }

        def _get_params_and_input(self, X, policy):
            """Validate ``X`` and convert it to a oneDAL table.

            ``X`` may be dense (float64/float32) or CSR sparse; non-finite
            values are not rejected here.

            Returns
            -------
            tuple
                ``(params, X_table, dtype)``.
            """
            X = _check_array(
                X,
                dtype=[np.float64, np.float32],
                accept_sparse="csr",
                force_all_finite=False,
            )

            X = _convert_to_supported(policy, X)

            dtype = get_dtype(X)
            params = self._get_onedal_params(dtype)
            return (params, to_table(X), dtype)

        def _compute_raw(self, X_table, module, policy, dtype=np.float32):
            """Run ``module.compute`` on a converted table.

            Returns the raw (backend table) ``centroids`` result.
            """
            params = self._get_onedal_params(dtype)

            result = module.compute(policy, params, X_table)

            return result.centroids

        def _compute(self, X, module, queue):
            """Validate ``X``, run the backend and return centroids as an array."""
            policy = self._get_policy(queue, X)
            # oneDAL KMeans Init for sparse data does not have GPU support
            if issparse(X):
                policy = self._get_policy(None, None)
            _, X_table, dtype = self._get_params_and_input(X, policy)

            centroids = self._compute_raw(X_table, module, policy, dtype)

            return from_table(centroids)

        def compute_raw(self, X_table, policy, dtype=np.float32):
            """Compute centroids from an already converted table (raw result)."""
            return self._compute_raw(
                X_table, self._get_backend("kmeans_init", "init", None), policy, dtype
            )

        def compute(self, X, queue=None):
            """Compute initial centroids for ``X`` on ``queue`` (or host)."""
            return self._compute(X, self._get_backend("kmeans_init", "init", None), queue)

    def kmeans_plusplus(
        X,
        n_clusters,
        *,
        x_squared_norms=None,
        random_state=None,
        n_local_trials=None,
        queue=None,
    ):
        """k-means++ style initialization backed by oneDAL.

        Parameters
        ----------
        X : array-like or CSR sparse matrix
            Input data.
        n_clusters : int
            Number of centers to select.
        x_squared_norms : ignored
            Accepted for signature compatibility; not used here.
        random_state : int, RandomState instance or None
            Converted to an integer seed via ``check_random_state(...).tomaxint()``.
        n_local_trials : int or None
            Forwarded as ``local_trials_count`` to ``KMeansInit``.
        queue : optional
            Execution queue forwarded to ``KMeansInit.compute``.

        Returns
        -------
        tuple
            ``(centers, indices)``. ``indices`` is ``np.full(n_clusters, -1)``:
            this implementation does not report which samples were chosen.
        """
        random_seed = check_random_state(random_state).tomaxint()
        centers = KMeansInit(
            n_clusters, seed=random_seed, local_trials_count=n_local_trials
        ).compute(X, queue)
        return (centers, np.full(n_clusters, -1))
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from sklearn.cluster import DBSCAN as DBSCAN_SKLEARN
|
|
20
|
+
from sklearn.cluster.tests.common import generate_clustered_data
|
|
21
|
+
|
|
22
|
+
from onedal.cluster import DBSCAN as ONEDAL_DBSCAN
|
|
23
|
+
from onedal.tests.utils._device_selection import get_queues
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def generate_data(
    low: float,
    high: float,
    samples_number: int,
    sample_dimension: int,
    seed=None,
) -> tuple:
    """Generate a uniform random dataset and matching sample weights.

    Parameters
    ----------
    low, high : float
        Bounds of the uniform distribution used for the feature values.
    samples_number : int
        Number of rows to generate.
    sample_dimension : int
        Number of features per row.
    seed : int or None, default=None
        Seed for ``np.random.RandomState``. ``None`` keeps the previous
        unseeded behavior.

    Returns
    -------
    tuple of (ndarray, ndarray)
        ``data`` of shape ``(samples_number, sample_dimension)`` drawn from
        ``U(low, high)``, and ``weights`` of shape ``(samples_number,)``
        drawn from ``U(0, 1)``.
    """
    generator = np.random.RandomState(seed)
    # Draw data before weights, as before, so seeded streams stay stable.
    data = generator.uniform(
        low=low, high=high, size=(samples_number, sample_dimension)
    )
    weights = generator.uniform(size=samples_number)
    return data, weights
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def check_labels_equals(left_labels: np.ndarray, right_labels: np.ndarray) -> bool:
    """Check that two 1-D label arrays describe the same partition up to relabeling.

    Returns
    -------
    bool
        ``True`` when every check passes.

    Raises
    ------
    Exception
        With a message naming the first failed check: mismatched shapes,
        non 1-D input, differing number of distinct labels, or a left label
        that pairs with more than one right label.
    """
    if left_labels.shape != right_labels.shape:
        raise Exception("Shapes not equal")
    if len(left_labels.shape) != 1:
        raise Exception("Shapes size not equals 1")
    n_left_clusters = len(set(left_labels))
    n_right_clusters = len(set(right_labels))
    if n_left_clusters != n_right_clusters:
        raise Exception("Cluster counts not equal")

    # Each left label must always co-occur with the same right label.
    # Together with equal distinct-label counts, this makes the relation
    # a one-to-one relabeling.
    left_to_right = {}
    for left, right in zip(left_labels, right_labels):
        if left not in left_to_right:
            left_to_right[left] = right
        elif left_to_right[left] != right:
            raise Exception("Wrong clustering")
    return True
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _test_dbscan_big_data_numpy_gen(
    queue,
    eps: float,
    min_samples: int,
    metric: str,
    use_weights: bool,
    low=-100.0,
    high=100.0,
    samples_number=1000,
    sample_dimension=4,
):
    """Fit oneDAL and scikit-learn DBSCAN on one random dataset and compare labels.

    Sample weights are produced by ``generate_data``; they are dropped only
    when ``use_weights is False``. Label agreement is checked with
    ``check_labels_equals``, which raises on mismatch.
    """
    X, sample_weight = generate_data(
        low=low,
        high=high,
        samples_number=samples_number,
        sample_dimension=sample_dimension,
    )
    if use_weights is False:
        sample_weight = None

    onedal_labels = (
        ONEDAL_DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
        .fit(X=X, sample_weight=sample_weight, queue=queue)
        .labels_
    )
    sklearn_labels = (
        DBSCAN_SKLEARN(metric=metric, eps=eps, min_samples=min_samples)
        .fit(X=X, sample_weight=sample_weight)
        .labels_
    )
    check_labels_equals(onedal_labels, sklearn_labels)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.mark.parametrize("metric", ["euclidean"])
@pytest.mark.parametrize("use_weights", [True, False])
@pytest.mark.parametrize("queue", get_queues())
def test_dbscan_big_data_numpy_gen(queue, metric, use_weights: bool):
    """Compare oneDAL and scikit-learn DBSCAN at eps=35.0, min_samples=6."""
    _test_dbscan_big_data_numpy_gen(
        queue,
        eps=35.0,
        min_samples=6,
        metric=metric,
        use_weights=use_weights,
    )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _test_across_grid_parameter_numpy_gen(queue, metric, use_weights: bool):
    """Run the DBSCAN comparison over a grid of ``eps`` and ``min_samples``.

    ``eps`` spans ``np.arange(0.05, 0.5, 0.05)`` and ``min_samples`` spans
    ``range(5, 15)``. Every combination is checked.
    """
    eps_grid = np.arange(0.05, 0.5, 0.05)
    min_samples_grid = range(5, 15, 1)
    for eps in eps_grid:
        for min_samples in min_samples_grid:
            _test_dbscan_big_data_numpy_gen(
                queue,
                eps=eps,
                min_samples=min_samples,
                metric=metric,
                use_weights=use_weights,
            )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@pytest.mark.parametrize("metric", ["euclidean"])
@pytest.mark.parametrize("use_weights", [True, False])
@pytest.mark.parametrize("queue", get_queues())
def test_across_grid_parameter_numpy_gen(queue, metric, use_weights: bool):
    """Compare oneDAL and scikit-learn DBSCAN across a parameter grid."""
    _test_across_grid_parameter_numpy_gen(
        queue,
        metric=metric,
        use_weights=use_weights,
    )
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from numpy.testing import assert_array_equal
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn._utils import daal_check_version
|
|
22
|
+
|
|
23
|
+
if daal_check_version((2023, "P", 200)):
    from sklearn.cluster import kmeans_plusplus as init_external
    from sklearn.neighbors import NearestNeighbors

    from onedal.cluster import KMeans
    from onedal.cluster import kmeans_plusplus as init_internal
    from onedal.tests.utils._device_selection import get_queues

    def generate_dataset(n_dim, n_cluster, n_points=None, seed=777, dtype=np.float32):
        """Generate ``n_cluster`` Gaussian blobs around random centers.

        Returns
        -------
        tuple of (ndarray, ndarray, ndarray)
            ``cs``: centers of shape ``(n_cluster, n_dim)`` drawn from
            ``U(-1, 1)``. ``vs``: per-cluster scale, one third of the
            distance from each center to its nearest other center.
            ``data``: shuffled samples, ``n_points`` per cluster, cast to
            ``dtype``.
        """
        # We need some reference value of points for each cluster
        n_points = (n_dim * n_cluster) if n_points is None else n_points

        # Creating generator and generating cluster points
        gen = np.random.Generator(np.random.MT19937(seed))
        cs = gen.uniform(low=-1.0, high=+1.0, size=(n_cluster, n_dim))

        # Scale each cluster so that 3 sigma equals the distance to the
        # nearest other center.
        nn = NearestNeighbors(n_neighbors=2)
        d, i = nn.fit(cs).kneighbors(cs)
        # Every center must be its own nearest neighbor (no duplicate centers).
        assert_array_equal(i[:, 0], np.arange(n_cluster))
        vs = d[:, 1] / 3

        # Generating dataset
        def gen_one(c):
            params = {"loc": cs[c, :], "scale": vs[c], "size": (n_points, n_dim)}
            return gen.normal(**params)

        data = np.concatenate([gen_one(c) for c in range(n_cluster)], axis=0)
        gen.shuffle(data, axis=0)

        data = data.astype(dtype)

        return (cs, vs, data)

    @pytest.mark.parametrize("queue", get_queues())
    @pytest.mark.parametrize("dtype", [np.float32, np.float64])
    @pytest.mark.parametrize("n_dim", [3, 4, 17, 24])
    @pytest.mark.parametrize("n_cluster", [9, 11, 32])
    @pytest.mark.parametrize("pipeline", ["implicit", "external", "internal"])
    def test_generated_dataset(queue, dtype, n_dim, n_cluster, pipeline):
        """Check that fitted centroids land near the generating centers.

        Centroids are initialized with scikit-learn (``external``), oneDAL
        (``internal``), or KMeans' own ``k-means++`` (``implicit``).
        """
        seed = 777 * n_dim * n_cluster
        cs, vs, X = generate_dataset(n_dim, n_cluster, seed=seed, dtype=dtype)

        if pipeline == "external":
            init_data, _ = init_external(X, n_cluster)
            m = KMeans(n_cluster, init=init_data, max_iter=5)
        elif pipeline == "internal":
            init_data, _ = init_internal(X, n_cluster, queue=queue)
            m = KMeans(n_cluster, init=init_data, max_iter=5)
        else:
            m = KMeans(n_cluster, init="k-means++", max_iter=5)

        m.fit(X, queue=queue)

        # Distance from every generating center to its nearest fitted centroid.
        rs_centroids = m.cluster_centers_
        nn = NearestNeighbors(n_neighbors=1)
        d, _ = nn.fit(rs_centroids).kneighbors(cs)
        # We have applied 3 sigma rule once (0.9973 = 3-sigma probability mass)
        desired_accuracy = int(0.9973 * n_cluster)
        correctness = d.reshape(-1) <= (vs * 3)
        exp_accuracy = np.count_nonzero(correctness)

        # TODO: investigate accuracy with kmeans++ init and remove - 1
        assert desired_accuracy - 1 <= exp_accuracy
|
scikit_learn_intelex-2025.0.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from numpy.testing import assert_array_equal
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn._utils import daal_check_version
|
|
22
|
+
|
|
23
|
+
if daal_check_version((2023, "P", 200)):
    from sklearn.datasets import load_breast_cancer
    from sklearn.metrics import davies_bouldin_score
    from sklearn.neighbors import NearestNeighbors

    from onedal.cluster import KMeans, kmeans_plusplus
    from onedal.tests.utils._device_selection import get_queues

    @pytest.mark.parametrize("queue", get_queues())
    @pytest.mark.parametrize("dtype", [np.float32, np.float64])
    @pytest.mark.parametrize("n_cluster", [2, 5, 11, 128])
    def test_breast_cancer(queue, dtype, n_cluster):
        """Score a 1-iteration KMeans seeded by oneDAL ``kmeans_plusplus``."""
        X, _ = load_breast_cancer(return_X_y=True)
        X = np.asarray(X).astype(dtype=dtype)
        init_data, _ = kmeans_plusplus(X, n_cluster, random_state=777, queue=queue)
        m = KMeans(n_cluster, init=init_data, max_iter=1)
        # NOTE(review): fit/predict are not given ``queue`` here.
        res = davies_bouldin_score(X, m.fit(X).predict(X))
        # NOTE(review): Davies-Bouldin is lower-is-better, so this asserts a
        # lower bound on the score; confirm the intended direction.
        thr = 0.45 if n_cluster < 20 else 0.55
        assert res > thr

    def generate_dataset(n_dim, n_cluster, n_points=None, seed=777, dtype=np.float32):
        """Generate ``n_cluster`` Gaussian blobs around random centers.

        Returns
        -------
        tuple of (ndarray, ndarray, ndarray)
            ``cs``: centers of shape ``(n_cluster, n_dim)`` drawn from
            ``U(-1, 1)``. ``vs``: per-cluster scale, one third of the
            distance from each center to its nearest other center.
            ``data``: shuffled samples, ``n_points`` per cluster, cast to
            ``dtype``.
        """
        # We need some reference value of points for each cluster
        n_points = (n_dim * n_cluster) if n_points is None else n_points

        # Creating generator and generating cluster points
        gen = np.random.Generator(np.random.MT19937(seed))
        cs = gen.uniform(low=-1.0, high=+1.0, size=(n_cluster, n_dim))

        # Scale each cluster so that 3 sigma equals the distance to the
        # nearest other center.
        nn = NearestNeighbors(n_neighbors=2)
        d, i = nn.fit(cs).kneighbors(cs)
        # Every center must be its own nearest neighbor (no duplicate centers).
        assert_array_equal(i[:, 0], np.arange(n_cluster))
        vs = d[:, 1] / 3

        # Generating dataset
        def gen_one(c):
            params = {"loc": cs[c, :], "scale": vs[c], "size": (n_points, n_dim)}
            return gen.normal(**params)

        data = np.concatenate([gen_one(c) for c in range(n_cluster)], axis=0)
        gen.shuffle(data, axis=0)

        data = data.astype(dtype)

        return (cs, vs, data)

    @pytest.mark.parametrize("queue", get_queues())
    @pytest.mark.parametrize("dtype", [np.float32, np.float64])
    @pytest.mark.parametrize("n_dim", [3, 12, 17])
    @pytest.mark.parametrize("n_cluster", [2, 15, 61])
    def test_generated_dataset(queue, dtype, n_dim, n_cluster):
        """Check that Lloyd KMeans seeded by oneDAL ``kmeans_plusplus`` recovers the centers."""
        seed = 777 * n_dim * n_cluster
        cs, vs, X = generate_dataset(n_dim, n_cluster, seed=seed, dtype=dtype)

        init_data, _ = kmeans_plusplus(X, n_cluster, random_state=seed, queue=queue)
        m = KMeans(n_cluster, init=init_data, max_iter=3, algorithm="lloyd").fit(X)

        # Distance from every generating center to its nearest fitted centroid.
        rs_centroids = m.cluster_centers_
        nn = NearestNeighbors(n_neighbors=1)
        d, _ = nn.fit(rs_centroids).kneighbors(cs)
        # We have applied 3 sigma rule once (0.9973 = 3-sigma probability mass)
        desired_accuracy = int(0.9973 * n_cluster)
        if d.dtype == np.float64:
            desired_accuracy = desired_accuracy - 1
        correctness = d.reshape(-1) <= (vs * 3)
        exp_accuracy = np.count_nonzero(correctness)

        assert desired_accuracy <= exp_accuracy
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from abc import ABC
|
|
18
|
+
|
|
19
|
+
from onedal import _backend
|
|
20
|
+
|
|
21
|
+
from ._policy import _get_policy
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _get_backend(backend, module, submodule=None, method=None, *args, **kwargs):
|
|
25
|
+
result = getattr(backend, module)
|
|
26
|
+
if submodule:
|
|
27
|
+
result = getattr(result, submodule)
|
|
28
|
+
if method:
|
|
29
|
+
return getattr(result, method)(*args, **kwargs)
|
|
30
|
+
return result
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseEstimator(ABC):
    """Common base for oneDAL estimators: backend lookup and policy selection."""

    def _get_backend(self, module, submodule=None, method=None, *args, **kwargs):
        """Resolve (and optionally call) an attribute of ``onedal._backend``.

        Delegates to the module-level ``_get_backend`` helper with ``_backend``
        as the root object.
        """
        root = _backend
        return _get_backend(root, module, submodule, method, *args, **kwargs)

    def _get_policy(self, queue, *data):
        """Return the execution policy chosen by ``._policy._get_policy``."""
        policy = _get_policy(queue, *data)
        return policy
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2022 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _check_is_fitted(estimator, attributes=None, *, msg=None):
|
|
19
|
+
if msg is None:
|
|
20
|
+
msg = (
|
|
21
|
+
"This %(name)s instance is not fitted yet. Call 'fit' with "
|
|
22
|
+
"appropriate arguments before using this estimator."
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
if not (
|
|
26
|
+
hasattr(estimator, "fit")
|
|
27
|
+
or (hasattr(estimator, "partial_fit") and hasattr(estimator, "finalize_fit"))
|
|
28
|
+
):
|
|
29
|
+
raise TypeError("%s is not an estimator instance." % (estimator))
|
|
30
|
+
|
|
31
|
+
if attributes is not None:
|
|
32
|
+
if not isinstance(attributes, (list, tuple)):
|
|
33
|
+
attributes = [attributes]
|
|
34
|
+
attrs = all([hasattr(estimator, attr) for attr in attributes])
|
|
35
|
+
else:
|
|
36
|
+
attrs = [v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")]
|
|
37
|
+
|
|
38
|
+
if not attrs:
|
|
39
|
+
raise AttributeError(msg % {"name": type(estimator).__name__})
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _is_classifier(estimator):
|
|
43
|
+
return getattr(estimator, "_estimator_type", None) == "classifier"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _is_regressor(estimator):
|
|
47
|
+
return getattr(estimator, "_estimator_type", None) == "regressor"
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ClusterMixin:
    """Mixin for clustering estimators (``_estimator_type = "clusterer"``)."""

    _estimator_type = "clusterer"

    def fit_predict(self, X, y=None, queue=None, **kwargs):
        """Fit on ``X`` and return ``self.labels_``. ``y`` is ignored."""
        self.fit(X, queue=queue, **kwargs)
        labels = self.labels_
        return labels

    def _more_tags(self):
        tags = dict(preserves_dtype=[])
        return tags
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ClassifierMixin:
    """Mixin for classifiers; ``score`` reports accuracy."""

    _estimator_type = "classifier"

    def score(self, X, y, sample_weight=None, queue=None):
        """Return the (optionally weighted) accuracy of ``self.predict(X)`` against ``y``."""
        from sklearn.metrics import accuracy_score

        y_pred = self.predict(X, queue=queue)
        return accuracy_score(y, y_pred, sample_weight=sample_weight)

    def _more_tags(self):
        tags = dict(requires_y=True)
        return tags
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class RegressorMixin:
    """Mixin for regressors; ``score`` reports the R^2 coefficient."""

    _estimator_type = "regressor"

    def score(self, X, y, sample_weight=None, queue=None):
        """Return the (optionally weighted) R^2 of ``self.predict(X)`` against ``y``."""
        from sklearn.metrics import r2_score

        y_pred = self.predict(X, queue=queue)
        return r2_score(y, y_pred, sample_weight=sample_weight)

    def _more_tags(self):
        tags = dict(requires_y=True)
        return tags
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class TransformerMixin:
    """Mixin for transformers providing ``fit_transform``."""

    _estimator_type = "transformer"

    def fit_transform(self, X, y=None, queue=None, **fit_params):
        """Fit on ``X`` (and ``y`` if given), then transform ``X``.

        ``y`` is forwarded to ``fit`` only when it is not ``None``.
        """
        fit_args = (X,) if y is None else (X, y)
        fitted = self.fit(*fit_args, queue=queue, **fit_params)
        return fitted.transform(X, queue=queue)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import sys
|
|
18
|
+
|
|
19
|
+
from onedal import _backend, _is_dpc_backend
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _get_policy(queue, *data):
    """Select the oneDAL execution policy for an explicit ``queue`` or the data.

    With the DPC backend, an explicit ``queue`` wins. Otherwise the queue
    found on the first data argument is used, and the host policy is the
    fallback. Without the DPC backend, any requested or detected queue raises
    ``RuntimeError``.
    """
    data_queue = _get_queue(*data)
    if not _is_dpc_backend:
        if queue is not None or data_queue is not None:
            raise RuntimeError(
                "Operation using the requested SYCL queue requires the DPC backend"
            )
        return _HostInteropPolicy()
    if queue is not None:
        return _DataParallelInteropPolicy(queue)
    if data_queue is not None:
        return _DataParallelInteropPolicy(data_queue)
    return _HostInteropPolicy()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _get_queue(*data):
|
|
39
|
+
if len(data) > 0 and hasattr(data[0], "__sycl_usm_array_interface__"):
|
|
40
|
+
# Assume that all data reside on the same device
|
|
41
|
+
return data[0].__sycl_usm_array_interface__["syclobj"]
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class _HostInteropPolicy(_backend.host_policy):
    """Host execution policy built on ``_backend.host_policy``.

    ``_get_policy`` returns this when no SYCL queue is requested or found.
    """

    def __init__(self):
        # Initialize the backend base class explicitly; no arguments are taken.
        super(_HostInteropPolicy, self).__init__()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
if _is_dpc_backend:
    from onedal._device_offload import DummySyclQueue

    class _DataParallelInteropPolicy(_backend.data_parallel_policy):
        """Device execution policy built on ``_backend.data_parallel_policy``.

        The given queue is kept on ``self._queue``. For a ``DummySyclQueue``,
        the backend receives the device filter string; any other queue object
        is passed through unchanged.
        """

        def __init__(self, queue):
            self._queue = queue
            if isinstance(queue, DummySyclQueue):
                backend_arg = self._queue.sycl_device.get_filter_string()
            else:
                backend_arg = self._queue
            super().__init__(backend_arg)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from onedal import _is_spmd_backend
|
|
18
|
+
|
|
19
|
+
if _is_spmd_backend:
    from onedal import _spmd_backend

    class _SPMDDataParallelInteropPolicy(_spmd_backend.spmd_data_parallel_policy):
        """SPMD data-parallel policy bound to a queue (kept on ``self._queue``)."""

        def __init__(self, queue):
            self._queue = queue
            super().__init__(queue)

    def _get_spmd_policy(queue):
        """Wrap ``queue`` in an SPMD data-parallel policy."""
        # TODO:
        # cases when queue is None
        policy = _SPMDDataParallelInteropPolicy(queue)
        return policy
|