scikit-learn-intelex 2024.7.0__py39-none-win_amd64.whl → 2025.0.1__py39-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/_daal4py.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/doc/third-party-programs.txt +424 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/mb/__init__.py +19 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/mb/model_builders.py +377 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +242 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +241 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +597 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- {scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn}/decomposition/__init__.py +2 -2
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +524 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1397 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +272 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +325 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +405 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +155 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +503 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +139 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +74 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +734 -0
- {scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/covariance → scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/utils}/__init__.py +5 -3
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +75 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +693 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/__init__.py +83 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/_config.py +53 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/_device_offload.py +229 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/_onedal_py_host.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +107 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/cluster/dbscan.py +110 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/cluster/kmeans.py +560 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +115 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/common/_base.py +38 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/common/_policy.py +59 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/common/_spmd_policy.py +30 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/common/hyperparameters.py +116 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/common/tests/test_policy.py +75 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/covariance/covariance.py +125 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +146 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +122 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/datatypes/__init__.py +19 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +95 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +235 -0
- {scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +204 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/decomposition/pca.py +186 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +198 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/ensemble/forest.py +720 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +258 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +329 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +249 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +149 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +778 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/primitives/get_tree.py +25 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +153 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/svm/svm.py +556 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +351 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/tests/test_common.py +41 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +168 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +107 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/utils/__init__.py +49 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/utils/_array_api.py +91 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/utils/validation.py +432 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +36 -13
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +30 -8
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +49 -16
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +383 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +153 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +28 -10
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +11 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +1 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +19 -9
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +2 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +7 -7
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +45 -26
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +418 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +4 -4
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +13 -10
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +5 -4
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +3 -3
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +4 -2
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +2 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +1 -1
- {scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +1 -3
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +8 -8
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +2 -2
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +4 -4
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- {scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition}/__init__.py +4 -1
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +4 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +1 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +1 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/_utils_spmd.py +18 -5
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +2 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -1
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +12 -11
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +1 -2
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py → scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +5 -20
- {scikit_learn_intelex-2024.7.0.dist-info → scikit_learn_intelex-2025.0.1.dist-info}/METADATA +3 -2
- scikit_learn_intelex-2025.0.1.dist-info/RECORD +255 -0
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -25
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -42
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
- scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -303
- scikit_learn_intelex-2024.7.0.dist-info/RECORD +0 -122
- {scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal}/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/conftest.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/_utils.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_common.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +0 -0
- {scikit_learn_intelex-2024.7.0.data → scikit_learn_intelex-2025.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.7.0.dist-info → scikit_learn_intelex-2025.0.1.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.7.0.dist-info → scikit_learn_intelex-2025.0.1.dist-info}/WHEEL +0 -0
- {scikit_learn_intelex-2024.7.0.dist-info → scikit_learn_intelex-2025.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
|
|
20
|
+
from onedal.common._policy import _get_policy
|
|
21
|
+
from onedal.tests.utils._device_selection import (
|
|
22
|
+
device_type_to_str,
|
|
23
|
+
get_memory_usm,
|
|
24
|
+
get_queues,
|
|
25
|
+
is_dpctl_available,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.mark.parametrize("queue", get_queues())
def test_queue_passed_directly(queue):
    """Check that a policy built from a queue alone reports that queue's device."""
    expected = device_type_to_str(queue)
    policy = _get_policy(queue)
    assert policy.get_device_name() == expected
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@pytest.mark.parametrize("queue", get_queues())
def test_with_numpy_data(queue):
    """Check that NumPy inputs leave the device implied by `queue` unchanged."""
    X_host = np.zeros((5, 3))
    y_host = np.zeros(3)

    expected = device_type_to_str(queue)
    actual = _get_policy(queue, X_host, y_host).get_device_name()
    assert actual == expected
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@pytest.mark.skipif(not is_dpctl_available(), reason="depends on dpctl")
@pytest.mark.parametrize("queue", get_queues("cpu,gpu"))
@pytest.mark.parametrize("memtype", get_memory_usm())
def test_with_usm_ndarray_data(queue, memtype):
    """Check that the policy device is taken from usm_ndarray inputs.

    No queue is handed to `_get_policy`; the only device information
    available to it is carried by the USM-backed arrays.
    """
    if queue is None:
        pytest.skip(
            "dpctl Memory object with queue=None uses cached default (gpu if available)"
        )

    from dpctl.tensor import usm_ndarray

    expected = device_type_to_str(queue)

    # Buffer sizes are given in bytes: element count times 8.
    # NOTE(review): 8 presumably matches usm_ndarray's default element size - confirm.
    x_buffer = memtype(5 * 3 * 8, queue=queue)
    X = usm_ndarray((5, 3), buffer=x_buffer)
    y_buffer = memtype(3 * 8, queue=queue)
    y = usm_ndarray((3,), buffer=y_buffer)

    actual = _get_policy(None, X, y).get_device_name()
    assert actual == expected
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@pytest.mark.skipif(
    not is_dpctl_available(["cpu", "gpu"]), reason="test uses multiple devices"
)
@pytest.mark.parametrize("memtype", get_memory_usm())
def test_queue_parameter_with_usm_ndarray(memtype):
    """Check that an explicit queue wins over the queue the data was allocated on."""
    from dpctl import SyclQueue
    from dpctl.tensor import usm_ndarray

    data_queue = SyclQueue("cpu")
    requested_queue = SyclQueue("gpu")

    # The array lives on the CPU queue while the GPU queue is requested explicitly.
    buffer = memtype(5 * 3 * 8, queue=data_queue)
    X = usm_ndarray((5, 3), buffer=buffer)

    actual = _get_policy(requested_queue, X).get_device_name()
    expected = device_type_to_str(requested_queue)
    assert actual == expected
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
from .covariance import EmpiricalCovariance
|
|
18
|
+
from .incremental_covariance import IncrementalEmpiricalCovariance
|
|
19
|
+
|
|
20
|
+
__all__ = ["EmpiricalCovariance", "IncrementalEmpiricalCovariance"]
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
from abc import ABCMeta
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from daal4py.sklearn._utils import daal_check_version, get_dtype
|
|
21
|
+
from onedal.utils import _check_array
|
|
22
|
+
|
|
23
|
+
from ..common._base import BaseEstimator
|
|
24
|
+
from ..common.hyperparameters import get_hyperparameters
|
|
25
|
+
from ..datatypes import _convert_to_supported, from_table, to_table
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BaseEmpiricalCovariance(BaseEstimator, metaclass=ABCMeta):
    """Shared settings and backend parameter assembly for covariance estimators."""

    def __init__(self, method="dense", bias=False, assume_centered=False):
        self.method = method
        self.bias = bias
        self.assume_centered = assume_centered

    def _get_onedal_params(self, dtype=np.float32):
        """Build the parameter dictionary handed to the backend.

        `fptype` is "float" for float32 and "double" for anything else.
        `bias` and `assumeCentered` are only included when the installed
        backend version passes the corresponding `daal_check_version` test.
        """
        fptype = "double"
        if dtype == np.float32:
            fptype = "float"

        params = dict(fptype=fptype, method=self.method)
        if daal_check_version((2024, "P", 1)):
            params.update(bias=self.bias)
        if daal_check_version((2024, "P", 400)):
            params.update(assumeCentered=self.assume_centered)

        return params
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class EmpiricalCovariance(BaseEmpiricalCovariance):
    """Covariance estimator.

    Computes the sample covariance matrix and the per-feature means.

    Parameters
    ----------
    method : string, default="dense"
        Specifies computation method. Available methods: "dense".

    bias : bool, default=False
        If True, the biased covariance estimate is computed; it equals the
        unbiased one multiplied by (n_samples - 1) / n_samples.

    assume_centered : bool, default=False
        If True, data are not centered before computation.
        Useful when working with data whose mean is almost, but not exactly
        zero.
        If False (default), data are centered before computation.

    Attributes
    ----------
    location_ : ndarray of shape (n_features,)
        Estimated location, i.e., the estimated mean.

    covariance_ : ndarray of shape (n_features, n_features)
        Estimated covariance matrix
    """

    def fit(self, X, y=None, queue=None):
        """Estimate the covariance matrix and the means of X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples, and
            `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        queue : dpctl.SyclQueue
            If not None, use this queue for computations.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        policy = self._get_policy(queue, X)
        X = _check_array(X, dtype=[np.float64, np.float32])
        X = _convert_to_supported(policy, X)
        params = self._get_onedal_params(get_dtype(X))

        # Non-default hyperparameters are passed to the backend as an extra
        # argument placed right before the data table.
        hparams = get_hyperparameters("covariance", "compute")
        backend_args = [policy, params]
        if hparams is not None and not hparams.is_default:
            backend_args.append(hparams.backend)
        backend_args.append(to_table(X))
        result = self._get_backend("covariance", None, "compute", *backend_args)

        # `bias` is only forwarded to the backend from version (2024, "P", 1)
        # on; for older backends the biased estimate is obtained by rescaling.
        covariance = from_table(result.cov_matrix)
        if not daal_check_version((2024, "P", 1)) and self.bias:
            covariance = covariance * (X.shape[0] - 1) / X.shape[0]
        self.covariance_ = covariance

        self.location_ = from_table(result.means).ravel()

        return self
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from daal4py.sklearn._utils import daal_check_version, get_dtype
|
|
19
|
+
|
|
20
|
+
from ..datatypes import _convert_to_supported, from_table, to_table
|
|
21
|
+
from ..utils import _check_array
|
|
22
|
+
from .covariance import BaseEmpiricalCovariance
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class IncrementalEmpiricalCovariance(BaseEmpiricalCovariance):
    """
    Covariance estimator based on oneDAL implementation.

    Computes sample covariance matrix incrementally: batches are accumulated
    with `partial_fit` and the estimates are produced by `finalize_fit`.

    Parameters
    ----------
    method : string, default="dense"
        Specifies computation method. Available methods: "dense".

    bias : bool, default=False
        If True biased estimation of covariance is computed which equals to
        the unbiased one multiplied by (n_samples - 1) / n_samples.

    assume_centered : bool, default=False
        If True, data are not centered before computation.
        Useful when working with data whose mean is almost, but not exactly
        zero.
        If False (default), data are centered before computation.

    Attributes
    ----------
    location_ : ndarray of shape (n_features,)
        Estimated location, i.e., the estimated mean.
        Set by `finalize_fit`.

    covariance_ : ndarray of shape (n_features, n_features)
        Estimated covariance matrix.
        Set by `finalize_fit`.
    """

    def __init__(self, method="dense", bias=False, assume_centered=False):
        super().__init__(method, bias, assume_centered)
        self._reset()

    def _reset(self):
        """Replace the accumulated state with a fresh backend partial result."""
        self._partial_result = self._get_backend(
            "covariance", None, "partial_compute_result"
        )

    def partial_fit(self, X, y=None, queue=None):
        """
        Computes partial data for the covariance matrix
        from data batch X and saves it to `_partial_result`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data batch, where `n_samples` is the number of samples
            in the batch, and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        queue : dpctl.SyclQueue
            If not None, use this queue for computations.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)

        # Remembered so that `finalize_fit` can fall back to the same queue.
        self._queue = queue

        policy = self._get_policy(queue, X)

        X = _convert_to_supported(policy, X)

        # The floating point type is fixed by the first batch and reused for
        # every later batch and for finalization.
        if not hasattr(self, "_dtype"):
            self._dtype = get_dtype(X)

        params = self._get_onedal_params(self._dtype)
        table_X = to_table(X)
        self._partial_result = self._get_backend(
            "covariance",
            None,
            "partial_compute",
            policy,
            params,
            self._partial_result,
            table_X,
        )

        # Honour the documented contract (previously the method returned None).
        return self

    def finalize_fit(self, queue=None):
        """
        Finalizes covariance matrix and obtains `covariance_` and `location_`
        attributes from the current `_partial_result`.

        Reads `_dtype` and `_queue`, which are set by `partial_fit`.

        Parameters
        ----------
        queue : dpctl.SyclQueue
            If not None, use this queue for computations. Otherwise the queue
            of the most recent `partial_fit` call is used.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        params = self._get_onedal_params(self._dtype)
        policy = self._get_policy(queue if queue is not None else self._queue)

        result = self._get_backend(
            "covariance",
            None,
            "finalize_compute",
            policy,
            params,
            self._partial_result,
        )
        if daal_check_version((2024, "P", 1)) or (not self.bias):
            self.covariance_ = from_table(result.cov_matrix)
        else:
            # Older backends are not given `bias`, so the biased estimate is
            # obtained by rescaling the returned matrix.
            # NOTE(review): `partial_n_rows` is used directly in arithmetic;
            # confirm it is a numeric value rather than a backend table.
            n_rows = self._partial_result.partial_n_rows
            self.covariance_ = from_table(result.cov_matrix) * (n_rows - 1) / n_rows

        self.location_ = from_table(result.means).ravel()

        return self
|
scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from numpy.testing import assert_allclose
|
|
20
|
+
|
|
21
|
+
from onedal.tests.utils._device_selection import get_queues
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.mark.parametrize("queue", get_queues())
def test_onedal_import_covariance(queue):
    """Check EmpiricalCovariance against hand-computed covariances and means."""
    from onedal.covariance import EmpiricalCovariance

    # Each case: (estimator kwargs, input rows, expected covariance, expected means)
    cases = [
        ({}, [[0, 1], [0, 1]], [[0, 0], [0, 0]], [0, 1]),
        ({}, [[1, 2], [3, 6]], [[2, 4], [4, 8]], [2, 4]),
        ({"bias": True}, [[1, 2], [3, 6]], [[1, 2], [2, 4]], [2, 4]),
    ]

    for kwargs, rows, covariance, means in cases:
        X = np.array(rows)
        result = EmpiricalCovariance(**kwargs).fit(X, queue=queue)

        assert_allclose(np.array(covariance), result.covariance_)
        assert_allclose(np.array(means), result.location_)
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from numpy.testing import assert_allclose
|
|
20
|
+
|
|
21
|
+
from onedal.tests.utils._device_selection import get_queues
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_on_gold_data_unbiased(queue, dtype):
    """Check unbiased incremental covariance and means on small known inputs.

    Each dataset is cast to the parametrized dtype, split into two batches,
    fed through `partial_fit`, and the finalized result is compared with
    hand-computed values.
    """
    from onedal.covariance import IncrementalEmpiricalCovariance

    X = np.array([[0, 1], [0, 1]])
    X = X.astype(dtype)
    X_split = np.array_split(X, 2)
    inccov = IncrementalEmpiricalCovariance()

    for batch in X_split:
        inccov.partial_fit(batch, queue=queue)
    result = inccov.finalize_fit()

    expected_covariance = np.array([[0, 0], [0, 0]])
    expected_means = np.array([0, 1])

    assert_allclose(expected_covariance, result.covariance_)
    assert_allclose(expected_means, result.location_)

    X = np.array([[1, 2], [3, 6]])
    # Cast before splitting: splitting first would hand integer batches to
    # `partial_fit` and leave the parametrized dtype untested.
    X = X.astype(dtype)
    X_split = np.array_split(X, 2)
    inccov = IncrementalEmpiricalCovariance()

    for batch in X_split:
        inccov.partial_fit(batch, queue=queue)
    result = inccov.finalize_fit()

    expected_covariance = np.array([[2, 4], [4, 8]])
    expected_means = np.array([2, 4])

    assert_allclose(expected_covariance, result.covariance_)
    assert_allclose(expected_means, result.location_)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_on_gold_data_biased(queue, dtype):
    """Check biased incremental covariance and means on small known inputs."""
    from onedal.covariance import IncrementalEmpiricalCovariance

    # Each case: (input rows, expected biased covariance, expected means)
    cases = [
        ([[0, 1], [0, 1]], [[0, 0], [0, 0]], [0, 1]),
        ([[1, 2], [3, 6]], [[1, 2], [2, 4]], [2, 4]),
    ]

    for rows, covariance, means in cases:
        X = np.array(rows).astype(dtype)
        batches = np.array_split(X, 2)
        inccov = IncrementalEmpiricalCovariance(bias=True)

        for batch in batches:
            inccov.partial_fit(batch, queue=queue)
        result = inccov.finalize_fit()

        assert_allclose(np.array(covariance), result.covariance_)
        assert_allclose(np.array(means), result.location_)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("num_batches", [2, 4, 6, 8, 10])
@pytest.mark.parametrize("row_count", [100, 1000, 2000])
@pytest.mark.parametrize("column_count", [10, 100, 200])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_partial_fit_on_random_data(
    queue, num_batches, row_count, column_count, bias, dtype
):
    """Compare batched incremental results with NumPy on seeded random data."""
    from onedal.covariance import IncrementalEmpiricalCovariance

    rng = np.random.default_rng(77)
    X = rng.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    X = X.astype(dtype)

    inccov = IncrementalEmpiricalCovariance(bias=bias)
    for batch in np.array_split(X, num_batches):
        inccov.partial_fit(batch, queue=queue)
    result = inccov.finalize_fit()

    # Reference values are computed by NumPy on the full, unsplit data.
    reference_covariance = np.cov(X.T, bias=bias)
    reference_means = np.mean(X, axis=0)

    assert_allclose(reference_covariance, result.covariance_, atol=1e-6)
    assert_allclose(reference_means, result.location_, atol=1e-6)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from ._data_conversion import _convert_to_supported, from_table, to_table
|
|
18
|
+
|
|
19
|
+
__all__ = ["from_table", "to_table", "_convert_to_supported"]
|
scikit_learn_intelex-2025.0.1.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import warnings
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn._utils import make2d
|
|
22
|
+
from onedal import _backend, _is_dpc_backend
|
|
23
|
+
|
|
24
|
+
from ..utils import _is_csr
|
|
25
|
+
|
|
26
|
+
try:
    import re

    import dpctl
    import dpctl.tensor as dpt

    # Compare (major, minor) numerically. A plain string comparison is
    # lexicographic, so e.g. "0.9" >= "0.14" would wrongly evaluate to True.
    # A version string without a leading "<digits>.<digits>" counts as
    # unsupported.
    _dpctl_version = re.match(r"(\d+)\.(\d+)", dpctl.__version__)
    dpctl_available = _dpctl_version is not None and tuple(
        int(part) for part in _dpctl_version.groups()
    ) >= (0, 14)
except ImportError:
    dpctl_available = False
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _apply_and_pass(func, *args):
    """Apply `func` to every positional argument.

    A single argument yields the bare result of `func`; any other number of
    arguments (including none) yields a tuple of results in argument order.
    """
    if len(args) != 1:
        return tuple(func(item) for item in args)
    (only,) = args
    return func(only)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def from_table(*args):
    """Run `_backend.from_table` on each argument via `_apply_and_pass`.

    One argument gives a single converted object; several give a tuple.
    """
    converter = _backend.from_table
    return _apply_and_pass(converter, *args)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def convert_one_to_table(arg):
    """Convert a single object into a backend table.

    When dpctl is usable and `arg` is a `usm_ndarray`, it goes through
    `_backend.dpctl_to_table`. Anything else goes through
    `_backend.to_table`, after `make2d` unless `_is_csr` reports true.
    """
    if dpctl_available and isinstance(arg, dpt.usm_ndarray):
        return _backend.dpctl_to_table(arg)

    if _is_csr(arg):
        return _backend.to_table(arg)
    return _backend.to_table(make2d(arg))
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def to_table(*args):
    """Run `convert_one_to_table` on each argument via `_apply_and_pass`.

    One argument gives a single table; several give a tuple of tables.
    """
    converter = convert_one_to_table
    return _apply_and_pass(converter, *args)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
if _is_dpc_backend:
    from ..common._policy import _HostInteropPolicy

    def _convert_to_supported(policy, *data):
        """Return `data` in a floating point type the policy's device supports.

        Inputs pass through unchanged for a host policy and for devices that
        report fp64 support. Otherwise float64 inputs are cast to float32
        with a RuntimeWarning; None and other dtypes pass through untouched.
        """

        def _identity(x):
            return x

        def _downcast_fp64(x):
            needs_downcast = (x is not None) and (x.dtype == np.float64)
            if not needs_downcast:
                return x
            warnings.warn(
                "Data will be converted into float32 from "
                "float64 because device does not support it",
                RuntimeWarning,
            )
            return x.astype(np.float32)

        # CPUs support FP64 by default
        if isinstance(policy, _HostInteropPolicy):
            return _apply_and_pass(_identity, *data)

        # It can be either SPMD or DPCPP policy
        device = policy._queue.sycl_device
        converter = _identity if device.has_aspect_fp64 else _downcast_fp64

        return _apply_and_pass(converter, *data)

else:

    def _convert_to_supported(policy, *data):
        """Return `data` unchanged; `policy` is not consulted in this variant."""

        def _identity(x):
            return x

        return _apply_and_pass(_identity, *data)
|