scikit-learn-intelex 2025.10.0__py313-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__init__.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +338 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +38 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dispatcher.py +572 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +629 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1799 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/__main__.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +101 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +39 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/split.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +34 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/common.py +313 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +170 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/__init__.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +112 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +25 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +26 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +278 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svc.py +306 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svr.py +155 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +418 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
- scikit_learn_intelex-2025.10.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
- scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
- scikit_learn_intelex-2025.10.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
"""Tools to support array_api."""
|
|
18
|
+
|
|
19
|
+
import math
|
|
20
|
+
from collections.abc import Callable
|
|
21
|
+
from typing import Union
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import scipy.linalg as linalg
|
|
25
|
+
from sklearn.covariance import log_likelihood as _sklearn_log_likelihood
|
|
26
|
+
|
|
27
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
28
|
+
from onedal.utils._array_api import _get_sycl_namespace, _is_numpy_namespace
|
|
29
|
+
|
|
30
|
+
from ..base import oneDALEstimator
|
|
31
|
+
|
|
32
|
+
if sklearn_check_version("1.6"):
|
|
33
|
+
from ..base import Tags
|
|
34
|
+
|
|
35
|
+
if sklearn_check_version("1.2"):
|
|
36
|
+
from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_namespace(*arrays):
    """Get namespace of arrays.

    Introspect ``arrays`` arguments and return their common Array API
    compatible namespace object, if any.

    Arrays that define the ``__sycl_usm_array_interface__`` attribute
    (SYCL-related arrays) always have their namespace returned, regardless
    of array_api support, the ``array_api_dispatch`` configuration, or the
    installed scikit-learn version.

    For all other inputs the lookup is delegated to scikit-learn's own
    ``get_namespace`` (scikit-learn >= 1.2). Namespace support there is not
    enabled by default; to enable it, call::

        sklearn.set_config(array_api_dispatch=True)

    or::

        with sklearn.config_context(array_api_dispatch=True):
            # your code here

    Otherwise scikit-learn returns a NumPy compatibility wrapper,
    irrespective of whether the arrays implement the
    ``__array_namespace__`` protocol. With scikit-learn < 1.2 the plain
    ``numpy`` module is returned.

    See: https://numpy.org/neps/nep-0047-array-api-standard.html

    Parameters
    ----------
    *arrays : array objects
        Array objects.

    Returns
    -------
    namespace : module
        Namespace shared by array objects.

    is_array_api : bool
        True if the arrays are containers that implement the Array API spec.
    """
    sycl_type, namespace, is_compliant = _get_sycl_namespace(*arrays)

    # SYCL USM inputs take priority over any scikit-learn configuration.
    if sycl_type:
        return namespace, is_compliant

    if not sklearn_check_version("1.2"):
        # No array API machinery in old scikit-learn: plain NumPy, not compliant.
        return np, False

    return sklearn_get_namespace(*arrays)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _enable_array_api(original_class: type[oneDALEstimator]) -> type[oneDALEstimator]:
    """Tag ``original_class`` as natively supporting array API inputs.

    Sets the ``onedal_array_api`` tag to True, through ``__sklearn_tags__``
    for sklearn >= 1.6 or through ``_more_tags`` for sklearn >= 1.3. With
    older sklearn versions the class is returned unmodified. The class is
    modified in place and returned, so this works as a class decorator.

    A tag method defined directly on ``original_class`` is wrapped rather
    than replaced, so tags the class sets itself are preserved.

    Parameters
    ----------
    original_class : oneDALEstimator subclass
        Estimator class to tag.

    Returns
    -------
    original_class : oneDALEstimator subclass
        The same class object, with its tag method installed.
    """
    if sklearn_check_version("1.6"):
        # Only a method defined on the class itself needs to be kept:
        # inherited implementations are reached through ``super`` below.
        own_sklearn_tags = original_class.__dict__.get("__sklearn_tags__")

        def __sklearn_tags__(self) -> Tags:
            if own_sklearn_tags is not None:
                sktags = own_sklearn_tags(self)
            else:
                sktags = super(original_class, self).__sklearn_tags__()
            sktags.onedal_array_api = True
            return sktags

        original_class.__sklearn_tags__ = __sklearn_tags__

    elif sklearn_check_version("1.3"):
        # Tags from parent classes' ``_more_tags`` are still collected by
        # sklearn through the MRO; only the class's own method would be lost.
        own_more_tags = original_class.__dict__.get("_more_tags")

        def _more_tags(self) -> dict:
            tags = {} if own_more_tags is None else dict(own_more_tags(self))
            tags["onedal_array_api"] = True
            return tags

        original_class._more_tags = _more_tags

    return original_class
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def enable_array_api(
    class_or_str: Union[type[oneDALEstimator], str],
) -> Union[type[oneDALEstimator], Callable]:
    """Enable sklearnex to use dpctl, dpnp or array API inputs in oneDAL offloading.

    This wrapper sets the proper flags/tags for the sklearnex infrastructure
    to maintain the data framework, as the estimator can use it natively.

    Parameters
    ----------
    class_or_str : oneDALEstimator subclass or str
        Class which should enable data zero-copy support in sklearnex. By
        default it will enable for sklearn versions >=1.3. If the wrapper is
        decorated with an argument, it must be a string defining the oldest
        sklearn version (inclusive) where array API support begins.

    Returns
    -------
    cls or wrapper : modified oneDALEstimator subclass or wrapper
        Estimator class or wrapper. When a version string is given and the
        installed sklearn is older than it, an identity wrapper is returned,
        which leaves the decorated class unmodified.

    Examples
    --------
    @enable_array_api  # default array API support
    class PCA():
        ...

    @enable_array_api("1.5")  # array API support for sklearn >= 1.5
    class Ridge():
        ...
    """
    if isinstance(class_or_str, str):
        # Decorator-factory form: enable array_api for the estimator only
        # if the installed sklearn is at least the given version.
        if sklearn_check_version(class_or_str):
            return _enable_array_api
        # do not apply the wrapper as it is not supported
        return lambda x: x

    # default setting (apply array_api enablement for sklearn >=1.3)
    return _enable_array_api(class_or_str)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _pinvh(a, atol=None, rtol=None, lower=True, return_rank=False, check_finite=True):
    """Compute the pseudo-inverse of a Hermitian matrix with array API support.

    Direct translation of ``scipy.linalg.pinvh``. This should be considered
    a temporary stopgap until it is implemented in oneDAL.

    Inputs whose namespace is NumPy are forwarded unchanged to
    ``scipy.linalg.pinvh``. For any other namespace:

    * ``check_finite`` must be False, because finite checking is not implemented;
    * ``lower`` is accepted for signature compatibility but is not used,
      since ``xp.linalg.eigh`` is called without it.

    Parameters
    ----------
    a : array
        Square Hermitian matrix to be pseudo-inverted.
    atol : float, default=None
        Absolute threshold for small eigenvalues; None means 0.
    rtol : float, default=None
        Relative threshold for small eigenvalues; None means
        ``max(a.shape) * eps`` of the eigenvector dtype.
    lower : bool, default=True
        Forwarded to SciPy for NumPy inputs only.
    return_rank : bool, default=False
        If True, also return the effective rank.
    check_finite : bool, default=True
        Forwarded to SciPy for NumPy inputs; must be False otherwise.

    Returns
    -------
    B : array
        The pseudo-inverse of ``a``.
    rank : int
        Number of eigenvalues whose magnitude exceeds the cutoff
        ``atol + max|s| * rtol``. Only returned when ``return_rank`` is True.

    Raises
    ------
    NotImplementedError
        If ``check_finite`` is True for a non-NumPy namespace.
    ValueError
        If ``atol`` or ``rtol`` is negative.
    """
    xp, _ = get_namespace(a)
    # fall back to scipy if the namespace is of a numpy origin
    if _is_numpy_namespace(xp):
        return linalg.pinvh(
            a,
            atol=atol,
            rtol=rtol,
            lower=lower,
            return_rank=return_rank,
            check_finite=check_finite,
        )

    if check_finite:
        raise NotImplementedError("finite checking does not occur in sklearnex's pinvh")

    s, u = xp.linalg.eigh(a)
    # Eigenvalue magnitudes are needed both for the spectral radius and for
    # the cutoff mask, so compute them once.
    abs_s = xp.abs(s)
    maxS = xp.max(abs_s)

    atol = 0.0 if atol is None else atol
    rtol = max(a.shape) * xp.finfo(u.dtype).eps if (rtol is None) else rtol

    if (atol < 0.0) or (rtol < 0.0):
        raise ValueError("atol and rtol values must be positive.")

    val = atol + maxS * rtol
    above_cutoff = xp.nonzero(abs_s > val)[0]

    # Invert only the retained eigenvalues and keep the matching eigenvectors.
    psigma_diag = 1.0 / xp.take(s, above_cutoff)
    u = xp.take(u, above_cutoff, axis=1)

    uconj = xp.conj(u) if xp.isdtype(u.dtype, kind="complex floating") else u

    B = (u * psigma_diag) @ uconj.T

    if return_rank:
        return B, len(psigma_diag)
    else:
        return B
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def log_likelihood(emp_cov, precision):
    # sklearn's log_likelihood has no array API support even though its
    # helper ``fast_logdet`` does, so the fast_logdet logic is inlined here to
    # make the computation usable with dpnp/dpctl and other array API inputs.
    xp, _ = get_namespace(emp_cov, precision)
    n_features = precision.shape[0]

    sign, logdet = xp.linalg.slogdet(precision)
    # As in fast_logdet: a determinant without positive sign yields -inf.
    logdet = logdet if sign > 0 else -xp.inf

    result = -xp.sum(emp_cov * precision) + logdet
    result -= n_features * math.log(2 * math.pi)
    result /= 2.0
    return result


# Reuse the upstream scikit-learn documentation verbatim.
log_likelihood.__doc__ = _sklearn_log_likelihood.__doc__
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright contributors to the oneDAL Project
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from sklearn.preprocessing import LabelEncoder as _sklearn_LabelEncoder
|
|
18
|
+
|
|
19
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
20
|
+
|
|
21
|
+
from ._array_api import get_namespace
|
|
22
|
+
from .validation import _check_sample_weight
|
|
23
|
+
|
|
24
|
+
if sklearn_check_version("1.7"):
    from sklearn.utils.class_weight import compute_class_weight
else:
    from sklearn.utils.class_weight import (
        compute_class_weight as _sklearn_compute_class_weight,
    )

    def compute_class_weight(class_weight, *, classes, y, sample_weight=None):
        """Expose the sklearn >= 1.7 signature on older sklearn versions.

        ``sample_weight`` is accepted for signature compatibility only and is
        not forwarded, because the wrapped function does not take it.
        """
        return _sklearn_compute_class_weight(class_weight, classes=classes, y=y)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _compute_class_weight(class_weight, *, classes, y, sample_weight=None):
    """Estimate class weights, with support for array API (dpnp/dpctl) inputs.

    This duplicates scikit-learn's ``compute_class_weight`` in order to
    enable it for array API. Inputs for which ``get_namespace`` reports no
    array API compliance are delegated to ``compute_class_weight``.
    Note that the use of LabelEncoder for ``class_weight="balanced"`` is
    only valid for sklearn versions >= 1.6.

    Parameters
    ----------
    class_weight : dict, "balanced" or None
        None (or an empty dict) gives uniform weights. "balanced" weights
        classes inversely to their (sample-weighted) frequency in ``y``.
        A dict maps class labels to weights; unlisted classes get 1.
    classes : array
        Classes for which weights are returned.
    y : array
        Class labels per sample.
    sample_weight : array or None, default=None
        Per-sample weights, only used with ``class_weight="balanced"``.

    Returns
    -------
    weight : array of shape (classes.shape[0],)
        ``weight[i]`` is the weight of ``classes[i]``.

    Raises
    ------
    ValueError
        If ``y`` contains labels missing from ``classes``; if, for
        "balanced", ``classes`` contains labels absent from ``y``; or if a
        user-defined dict leaves classes unweighted while its key count does
        not match the number of weighted classes (mirrors scikit-learn).
    RuntimeError
        If "balanced" is used with array API inputs on sklearn < 1.6.
    """
    xp, is_array_api_compliant = get_namespace(classes, y, sample_weight)

    if not is_array_api_compliant:
        # Use the sklearn version for standard use. ``classes`` and ``y`` are
        # keyword-only in ``compute_class_weight`` and must be passed by name.
        return compute_class_weight(
            class_weight, classes=classes, y=y, sample_weight=sample_weight
        )

    # Mirror scikit-learn's check that every label in y is listed in
    # ``classes``. Labels are compared as Python floats (like the user-defined
    # branch below) so that integer and floating dtypes compare equal, as
    # Python sets do in the scikit-learn implementation.
    known_classes = {float(c) for c in classes}
    if any(float(label) not in known_classes for label in xp.unique_values(y)):
        raise ValueError("classes should include all valid labels that can be in y")

    if class_weight is None or len(class_weight) == 0:
        # uniform class weights
        weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device)
    elif class_weight == "balanced":
        if not sklearn_check_version("1.6"):
            raise RuntimeError(
                "array API support with 'balanced' keyword not supported for sklearn <1.6"
            )
        # Find the weight of each class as present in y.
        le = _sklearn_LabelEncoder()
        y_ind = le.fit_transform(y)
        if not all(item in le.classes_ for item in classes):
            raise ValueError("classes should have valid labels that are in y")

        sample_weight = _check_sample_weight(sample_weight, y)
        # scikit-learn uses numpy.bincount here. LabelEncoder encodes y as
        # 0 .. n_classes - 1 with every code present at least once, so the
        # bincount length is exactly the number of encoded classes. No
        # min/max reduction over y_ind is needed. Core bincount logic:
        # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/compiled_base.c
        n_encoded = le.classes_.shape[0]
        weighted_class_counts = xp.zeros(
            (n_encoded,), dtype=sample_weight.dtype, device=y.device
        )

        # use a more GPU-friendly summation approach for collecting weighted_class_counts
        for w_idx in range(n_encoded):
            weighted_class_counts[w_idx] = xp.sum(sample_weight[y_ind == w_idx])

        recip_freq = xp.sum(weighted_class_counts) / (n_encoded * weighted_class_counts)

        weight = xp.take(recip_freq, le.transform(classes))
    else:
        # user-defined dictionary
        weight = xp.ones((classes.shape[0],), dtype=xp.float64, device=classes.device)
        unweighted_classes = []
        for i, c in enumerate(classes):
            # array API has only numeric datatypes, convert to float for generality;
            # complex values should never be observed by this function
            if (fc := float(c)) in class_weight:
                weight[i] = class_weight[fc]
            else:
                unweighted_classes.append(c)

        n_weighted_classes = classes.shape[0] - len(unweighted_classes)
        if unweighted_classes and n_weighted_classes != len(class_weight):
            raise ValueError(
                f"The classes, {unweighted_classes}, are not in class_weight"
            )

    return weight
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
import warnings
|
|
18
|
+
from functools import update_wrapper
|
|
19
|
+
|
|
20
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
21
|
+
|
|
22
|
+
from .._config import config_context, get_config
|
|
23
|
+
|
|
24
|
+
# Replacement of _FuncWrapper is required to correctly propagate
|
|
25
|
+
# the scikit-learn-intelex configuration functions to the joblib workers.
|
|
26
|
+
if sklearn_check_version("1.7"):

    class _FuncWrapper:
        """Call ``function`` under the configuration and warning filters attached to it."""

        def __init__(self, function):
            self.function = function
            update_wrapper(self, self.function)

        def with_config_and_warning_filters(self, config, warning_filters):
            # Attach the state to restore in the worker; returns self for chaining.
            self.config, self.warning_filters = config, warning_filters
            return self

        def __call__(self, *args, **kwargs):
            cfg = getattr(self, "config", {})
            filters = getattr(self, "warning_filters", [])
            if not (cfg and filters):
                warnings.warn(
                    "`sklearn.utils.parallel.delayed` should be used with"
                    " `sklearn.utils.parallel.Parallel` to make it possible to"
                    " propagate the scikit-learn configuration of the current thread to"
                    " the joblib workers.",
                    UserWarning,
                )

            # catch_warnings restores the global filter list on exit.
            with config_context(**cfg), warnings.catch_warnings():
                warnings.filters = filters
                return self.function(*args, **kwargs)

elif sklearn_check_version("1.2.1"):

    class _FuncWrapper:
        """Call ``function`` under the configuration attached via ``with_config``."""

        def __init__(self, function):
            self.function = function
            update_wrapper(self, self.function)

        def with_config(self, config):
            self.config = config
            return self

        def __call__(self, *args, **kwargs):
            cfg = getattr(self, "config", None)
            if cfg is None:
                warnings.warn(
                    "`sklearn.utils.parallel.delayed` should be used with "
                    "`sklearn.utils.parallel.Parallel` to make it possible to propagate "
                    "the scikit-learn configuration of the current thread to the "
                    "joblib workers.",
                    UserWarning,
                )
            # A missing configuration falls back to an empty one.
            with config_context(**({} if cfg is None else cfg)):
                return self.function(*args, **kwargs)

else:

    class _FuncWrapper:
        """Call ``function`` under the configuration captured at wrap time."""

        def __init__(self, function):
            self.function = function
            # Snapshot the configuration active when the wrapper is created.
            self.config = get_config()
            update_wrapper(self, self.function)

        def __call__(self, *args, **kwargs):
            with config_context(**self.config):
                return self.function(*args, **kwargs)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright contributors to the oneDAL Project
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from sklearn.datasets import load_iris
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
22
|
+
from onedal.tests.utils._dataframes_support import (
|
|
23
|
+
_as_numpy,
|
|
24
|
+
_convert_to_dataframe,
|
|
25
|
+
get_dataframes_and_queues,
|
|
26
|
+
)
|
|
27
|
+
from sklearnex import config_context
|
|
28
|
+
from sklearnex.utils.class_weight import _compute_class_weight
|
|
29
|
+
from sklearnex.utils.class_weight import compute_class_weight as sk_compute_class_weight
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pytest.mark.skipif(not sklearn_check_version("1.6"), reason="lacks array API support")
@pytest.mark.parametrize("class_weight", [None, "balanced", "ramp"])
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues("array_api,dpctl"))
def test_compute_class_weight_array_api(class_weight, dataframe, queue):
    """The array API path of ``_compute_class_weight`` must agree with sklearn."""
    _, y = load_iris(return_X_y=True)
    classes = np.unique(y)

    def to_xp(arr):
        # move a numpy array to the parametrized dataframe type / device
        return _convert_to_dataframe(arr, target_df=dataframe, device=queue)

    y_xp = to_xp(y)
    classes_xp = to_xp(classes)

    rng = np.random.default_rng(seed=42)

    # support of sample_weights was added in sklearn 1.7
    use_sample_weight = class_weight == "balanced" and sklearn_check_version("1.7")
    sample_weight = rng.random(y.shape).astype(np.float64) if use_sample_weight else None

    if class_weight == "ramp":
        # weight each class by its own integer label
        class_weight = {int(label): int(label) for label in classes}

    expected = sk_compute_class_weight(
        class_weight, classes=classes, y=y, sample_weight=sample_weight
    )

    sample_weight_xp = to_xp(sample_weight) if use_sample_weight else sample_weight

    # evaluate custom sklearnex array API functionality
    with config_context(array_api_dispatch=True):
        result = _compute_class_weight(
            class_weight, classes=classes_xp, y=y_xp, sample_weight=sample_weight_xp
        )

    np.testing.assert_allclose(_as_numpy(result), expected)
|