scikit-learn-intelex 2024.2.0__py39-none-win_amd64.whl → 2025.1.0__py39-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/_daal4py.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/doc/third-party-programs.txt +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/model_builders.py +377 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +248 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +597 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn}/decomposition/__init__.py +2 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +524 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1397 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn}/linear_model/__init__.py +29 -28
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +272 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +325 -0
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +3 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +405 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +236 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/_models_info.py +13 -22
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/test_patching.py +10 -56
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/utils/_launch_algorithms.py +4 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +503 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +139 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +74 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +734 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd/covariance → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils}/__init__.py +5 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +75 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +693 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/__init__.py +83 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_config.py +54 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_device_offload.py +222 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp39-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +110 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +564 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +115 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_base.py +38 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_policy.py +59 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_spmd_policy.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/tests/test_policy.py +76 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +146 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +122 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +154 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +126 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +414 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +186 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +727 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +258 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +249 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +250 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +767 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +25 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/svm.py +556 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +351 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +176 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/test_common.py +57 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +102 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/__init__.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +81 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_dpep_helpers.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/validation.py +440 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__init__.py +10 -7
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_config.py +22 -16
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +126 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_utils.py +27 -4
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +230 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +16 -7
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +395 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +398 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +111 -17
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +425 -0
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +25 -9
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +222 -42
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +249 -182
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +39 -21
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +482 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +425 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +341 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +182 -102
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +7 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +45 -4
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +4 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +5 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +97 -28
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +53 -6
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +48 -149
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +43 -144
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +50 -93
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +6 -9
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +24 -18
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +233 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model}/__init__.py +19 -19
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +14 -18
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +339 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +168 -73
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +71 -66
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +166 -72
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +64 -63
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +390 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +123 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +379 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +276 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +108 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +385 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +321 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +44 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +371 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/METADATA +231 -230
- scikit_learn_intelex-2025.1.0.dist-info/RECORD +257 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/WHEEL +1 -1
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -223
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -130
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -381
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -17
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -82
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -371
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -374
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -188
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -222
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -240
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -93
- scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
- scikit_learn_intelex-2024.2.0.dist-info/RECORD +0 -101
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/top_level.txt +0 -0
|
@@ -20,13 +20,16 @@ from abc import ABC
|
|
|
20
20
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
from scipy import sparse as sp
|
|
23
|
-
from sklearn.base import clone
|
|
24
|
-
from sklearn.ensemble import ExtraTreesClassifier as
|
|
25
|
-
from sklearn.ensemble import ExtraTreesRegressor as
|
|
26
|
-
from sklearn.ensemble import RandomForestClassifier as
|
|
27
|
-
from sklearn.ensemble import RandomForestRegressor as
|
|
23
|
+
from sklearn.base import BaseEstimator, clone
|
|
24
|
+
from sklearn.ensemble import ExtraTreesClassifier as _sklearn_ExtraTreesClassifier
|
|
25
|
+
from sklearn.ensemble import ExtraTreesRegressor as _sklearn_ExtraTreesRegressor
|
|
26
|
+
from sklearn.ensemble import RandomForestClassifier as _sklearn_RandomForestClassifier
|
|
27
|
+
from sklearn.ensemble import RandomForestRegressor as _sklearn_RandomForestRegressor
|
|
28
|
+
from sklearn.ensemble._forest import ForestClassifier as _sklearn_ForestClassifier
|
|
29
|
+
from sklearn.ensemble._forest import ForestRegressor as _sklearn_ForestRegressor
|
|
28
30
|
from sklearn.ensemble._forest import _get_n_samples_bootstrap
|
|
29
31
|
from sklearn.exceptions import DataConversionWarning
|
|
32
|
+
from sklearn.metrics import accuracy_score, r2_score
|
|
30
33
|
from sklearn.tree import (
|
|
31
34
|
DecisionTreeClassifier,
|
|
32
35
|
DecisionTreeRegressor,
|
|
@@ -36,8 +39,8 @@ from sklearn.tree import (
|
|
|
36
39
|
from sklearn.tree._tree import Tree
|
|
37
40
|
from sklearn.utils import check_random_state, deprecated
|
|
38
41
|
from sklearn.utils.validation import (
|
|
42
|
+
_check_sample_weight,
|
|
39
43
|
check_array,
|
|
40
|
-
check_consistent_length,
|
|
41
44
|
check_is_fitted,
|
|
42
45
|
check_X_y,
|
|
43
46
|
)
|
|
@@ -52,53 +55,43 @@ from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
|
|
|
52
55
|
from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
|
|
53
56
|
from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
|
|
54
57
|
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
|
|
55
|
-
|
|
56
|
-
# try catch needed for changes in structures observed in Scikit-learn around v0.22
|
|
57
|
-
try:
|
|
58
|
-
from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
|
|
59
|
-
from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
|
|
60
|
-
except ModuleNotFoundError:
|
|
61
|
-
from sklearn.ensemble.forest import ForestClassifier as sklearn_ForestClassifier
|
|
62
|
-
from sklearn.ensemble.forest import ForestRegressor as sklearn_ForestRegressor
|
|
63
|
-
|
|
64
58
|
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
65
59
|
from onedal.utils import _num_features, _num_samples
|
|
60
|
+
from sklearnex import get_hyperparameters
|
|
61
|
+
from sklearnex._utils import register_hyperparameters
|
|
66
62
|
|
|
67
|
-
from .._config import get_config
|
|
68
63
|
from .._device_offload import dispatch, wrap_output_data
|
|
69
64
|
from .._utils import PatchingConditionsChain
|
|
65
|
+
from ..utils._array_api import get_namespace
|
|
70
66
|
|
|
71
67
|
if sklearn_check_version("1.2"):
|
|
72
68
|
from sklearn.utils._param_validation import Interval
|
|
73
69
|
if sklearn_check_version("1.4"):
|
|
74
70
|
from daal4py.sklearn.utils import _assert_all_finite
|
|
75
71
|
|
|
72
|
+
if sklearn_check_version("1.6"):
|
|
73
|
+
from sklearn.utils.validation import validate_data
|
|
74
|
+
else:
|
|
75
|
+
validate_data = BaseEstimator._validate_data
|
|
76
|
+
|
|
76
77
|
|
|
77
78
|
class BaseForest(ABC):
|
|
78
79
|
_onedal_factory = None
|
|
79
80
|
|
|
80
81
|
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
X, y = check_X_y(
|
|
92
|
-
X,
|
|
93
|
-
y,
|
|
94
|
-
accept_sparse=False,
|
|
95
|
-
dtype=[np.float64, np.float32],
|
|
96
|
-
multi_output=False,
|
|
97
|
-
force_all_finite=False,
|
|
98
|
-
)
|
|
82
|
+
X, y = validate_data(
|
|
83
|
+
self,
|
|
84
|
+
X,
|
|
85
|
+
y,
|
|
86
|
+
multi_output=True,
|
|
87
|
+
accept_sparse=False,
|
|
88
|
+
dtype=[np.float64, np.float32],
|
|
89
|
+
force_all_finite=False,
|
|
90
|
+
ensure_2d=True,
|
|
91
|
+
)
|
|
99
92
|
|
|
100
93
|
if sample_weight is not None:
|
|
101
|
-
sample_weight =
|
|
94
|
+
sample_weight = _check_sample_weight(sample_weight, X)
|
|
102
95
|
|
|
103
96
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
104
97
|
warnings.warn(
|
|
@@ -118,8 +111,6 @@ class BaseForest(ABC):
|
|
|
118
111
|
|
|
119
112
|
y, expanded_class_weight = self._validate_y_class_weight(y)
|
|
120
113
|
|
|
121
|
-
self.n_features_in_ = X.shape[1]
|
|
122
|
-
|
|
123
114
|
if expanded_class_weight is not None:
|
|
124
115
|
if sample_weight is not None:
|
|
125
116
|
sample_weight = sample_weight * expanded_class_weight
|
|
@@ -135,7 +126,9 @@ class BaseForest(ABC):
|
|
|
135
126
|
"min_samples_split": self.min_samples_split,
|
|
136
127
|
"min_samples_leaf": self.min_samples_leaf,
|
|
137
128
|
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
138
|
-
"max_features": self.
|
|
129
|
+
"max_features": self._to_absolute_max_features(
|
|
130
|
+
self.max_features, self.n_features_in_
|
|
131
|
+
),
|
|
139
132
|
"max_leaf_nodes": self.max_leaf_nodes,
|
|
140
133
|
"min_impurity_decrease": self.min_impurity_decrease,
|
|
141
134
|
"bootstrap": self.bootstrap,
|
|
@@ -173,15 +166,6 @@ class BaseForest(ABC):
|
|
|
173
166
|
|
|
174
167
|
return self
|
|
175
168
|
|
|
176
|
-
def _fit_proba(self, X, y, sample_weight=None, queue=None):
|
|
177
|
-
params = self.get_params()
|
|
178
|
-
self.__class__(**params)
|
|
179
|
-
|
|
180
|
-
# We use stock metaestimators below, so the only way
|
|
181
|
-
# to pass a queue is using config_context.
|
|
182
|
-
cfg = get_config()
|
|
183
|
-
cfg["target_offload"] = queue
|
|
184
|
-
|
|
185
169
|
def _save_attributes(self):
|
|
186
170
|
if self.oob_score:
|
|
187
171
|
self.oob_score_ = self._onedal_estimator.oob_score_
|
|
@@ -204,8 +188,45 @@ class BaseForest(ABC):
|
|
|
204
188
|
self._validate_estimator()
|
|
205
189
|
return self
|
|
206
190
|
|
|
207
|
-
|
|
208
|
-
|
|
191
|
+
def _to_absolute_max_features(self, max_features, n_features):
|
|
192
|
+
if max_features is None:
|
|
193
|
+
return n_features
|
|
194
|
+
if isinstance(max_features, str):
|
|
195
|
+
if max_features == "auto":
|
|
196
|
+
if not sklearn_check_version("1.3"):
|
|
197
|
+
if sklearn_check_version("1.1"):
|
|
198
|
+
warnings.warn(
|
|
199
|
+
"`max_features='auto'` has been deprecated in 1.1 "
|
|
200
|
+
"and will be removed in 1.3. To keep the past behaviour, "
|
|
201
|
+
"explicitly set `max_features=1.0` or remove this "
|
|
202
|
+
"parameter as it is also the default value for "
|
|
203
|
+
"RandomForestRegressors and ExtraTreesRegressors.",
|
|
204
|
+
FutureWarning,
|
|
205
|
+
)
|
|
206
|
+
return (
|
|
207
|
+
max(1, int(np.sqrt(n_features)))
|
|
208
|
+
if isinstance(self, ForestClassifier)
|
|
209
|
+
else n_features
|
|
210
|
+
)
|
|
211
|
+
if max_features == "sqrt":
|
|
212
|
+
return max(1, int(np.sqrt(n_features)))
|
|
213
|
+
if max_features == "log2":
|
|
214
|
+
return max(1, int(np.log2(n_features)))
|
|
215
|
+
allowed_string_values = (
|
|
216
|
+
'"sqrt" or "log2"'
|
|
217
|
+
if sklearn_check_version("1.3")
|
|
218
|
+
else '"auto", "sqrt" or "log2"'
|
|
219
|
+
)
|
|
220
|
+
raise ValueError(
|
|
221
|
+
"Invalid value for max_features. Allowed string "
|
|
222
|
+
f"values are {allowed_string_values}."
|
|
223
|
+
)
|
|
224
|
+
if isinstance(max_features, (numbers.Integral, np.integer)):
|
|
225
|
+
return max_features
|
|
226
|
+
if max_features > 0.0:
|
|
227
|
+
return max(1, int(max_features * n_features))
|
|
228
|
+
return 0
|
|
229
|
+
|
|
209
230
|
def _check_parameters(self):
|
|
210
231
|
if isinstance(self.min_samples_leaf, numbers.Integral):
|
|
211
232
|
if not 1 <= self.min_samples_leaf:
|
|
@@ -281,38 +302,6 @@ class BaseForest(ABC):
|
|
|
281
302
|
"min_bin_size must be integral number but was " "%r" % self.min_bin_size
|
|
282
303
|
)
|
|
283
304
|
|
|
284
|
-
def check_sample_weight(self, sample_weight, X, dtype=None):
|
|
285
|
-
n_samples = _num_samples(X)
|
|
286
|
-
|
|
287
|
-
if dtype is not None and dtype not in [np.float32, np.float64]:
|
|
288
|
-
dtype = np.float64
|
|
289
|
-
|
|
290
|
-
if sample_weight is None:
|
|
291
|
-
sample_weight = np.ones(n_samples, dtype=dtype)
|
|
292
|
-
elif isinstance(sample_weight, numbers.Number):
|
|
293
|
-
sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
|
|
294
|
-
else:
|
|
295
|
-
if dtype is None:
|
|
296
|
-
dtype = [np.float64, np.float32]
|
|
297
|
-
sample_weight = check_array(
|
|
298
|
-
sample_weight,
|
|
299
|
-
accept_sparse=False,
|
|
300
|
-
ensure_2d=False,
|
|
301
|
-
dtype=dtype,
|
|
302
|
-
order="C",
|
|
303
|
-
force_all_finite=False,
|
|
304
|
-
)
|
|
305
|
-
if sample_weight.ndim != 1:
|
|
306
|
-
raise ValueError("Sample weights must be 1D array or scalar")
|
|
307
|
-
|
|
308
|
-
if sample_weight.shape != (n_samples,):
|
|
309
|
-
raise ValueError(
|
|
310
|
-
"sample_weight.shape == {}, expected {}!".format(
|
|
311
|
-
sample_weight.shape, (n_samples,)
|
|
312
|
-
)
|
|
313
|
-
)
|
|
314
|
-
return sample_weight
|
|
315
|
-
|
|
316
305
|
@property
|
|
317
306
|
def estimators_(self):
|
|
318
307
|
if hasattr(self, "_cached_estimators_"):
|
|
@@ -413,7 +402,7 @@ class BaseForest(ABC):
|
|
|
413
402
|
self.estimator = estimator
|
|
414
403
|
|
|
415
404
|
|
|
416
|
-
class ForestClassifier(
|
|
405
|
+
class ForestClassifier(_sklearn_ForestClassifier, BaseForest):
|
|
417
406
|
# Surprisingly, even though scikit-learn warns against using
|
|
418
407
|
# their ForestClassifier directly, it actually has a more stable
|
|
419
408
|
# API than the user-facing objects (over time). If they change it
|
|
@@ -453,14 +442,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
453
442
|
|
|
454
443
|
# The estimator is checked against the class attribute for conformance.
|
|
455
444
|
# This should only trigger if the user uses this class directly.
|
|
456
|
-
if (
|
|
457
|
-
self.
|
|
458
|
-
and self._onedal_factory != onedal_RandomForestClassifier
|
|
445
|
+
if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
|
|
446
|
+
self._onedal_factory, onedal_RandomForestClassifier
|
|
459
447
|
):
|
|
460
448
|
self._onedal_factory = onedal_RandomForestClassifier
|
|
461
|
-
elif (
|
|
462
|
-
self.
|
|
463
|
-
and self._onedal_factory != onedal_ExtraTreesClassifier
|
|
449
|
+
elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
|
|
450
|
+
self._onedal_factory, onedal_ExtraTreesClassifier
|
|
464
451
|
):
|
|
465
452
|
self._onedal_factory = onedal_ExtraTreesClassifier
|
|
466
453
|
|
|
@@ -479,7 +466,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
479
466
|
"fit",
|
|
480
467
|
{
|
|
481
468
|
"onedal": self.__class__._onedal_fit,
|
|
482
|
-
"sklearn":
|
|
469
|
+
"sklearn": _sklearn_ForestClassifier.fit,
|
|
483
470
|
},
|
|
484
471
|
X,
|
|
485
472
|
y,
|
|
@@ -552,18 +539,14 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
552
539
|
)
|
|
553
540
|
|
|
554
541
|
if patching_status.get_status():
|
|
555
|
-
|
|
556
|
-
X,
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
)
|
|
564
|
-
else:
|
|
565
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
566
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
542
|
+
X, y = check_X_y(
|
|
543
|
+
X,
|
|
544
|
+
y,
|
|
545
|
+
multi_output=True,
|
|
546
|
+
accept_sparse=True,
|
|
547
|
+
dtype=[np.float64, np.float32],
|
|
548
|
+
force_all_finite=False,
|
|
549
|
+
)
|
|
567
550
|
|
|
568
551
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
569
552
|
warnings.warn(
|
|
@@ -617,12 +600,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
617
600
|
|
|
618
601
|
@wrap_output_data
|
|
619
602
|
def predict(self, X):
|
|
603
|
+
check_is_fitted(self)
|
|
620
604
|
return dispatch(
|
|
621
605
|
self,
|
|
622
606
|
"predict",
|
|
623
607
|
{
|
|
624
608
|
"onedal": self.__class__._onedal_predict,
|
|
625
|
-
"sklearn":
|
|
609
|
+
"sklearn": _sklearn_ForestClassifier.predict,
|
|
626
610
|
},
|
|
627
611
|
X,
|
|
628
612
|
)
|
|
@@ -632,34 +616,50 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
632
616
|
# TODO:
|
|
633
617
|
# _check_proba()
|
|
634
618
|
# self._check_proba()
|
|
635
|
-
|
|
636
|
-
self._check_feature_names(X, reset=False)
|
|
637
|
-
if hasattr(self, "n_features_in_"):
|
|
638
|
-
try:
|
|
639
|
-
num_features = _num_features(X)
|
|
640
|
-
except TypeError:
|
|
641
|
-
num_features = _num_samples(X)
|
|
642
|
-
if num_features != self.n_features_in_:
|
|
643
|
-
raise ValueError(
|
|
644
|
-
(
|
|
645
|
-
f"X has {num_features} features, "
|
|
646
|
-
f"but {self.__class__.__name__} is expecting "
|
|
647
|
-
f"{self.n_features_in_} features as input"
|
|
648
|
-
)
|
|
649
|
-
)
|
|
619
|
+
check_is_fitted(self)
|
|
650
620
|
return dispatch(
|
|
651
621
|
self,
|
|
652
622
|
"predict_proba",
|
|
653
623
|
{
|
|
654
624
|
"onedal": self.__class__._onedal_predict_proba,
|
|
655
|
-
"sklearn":
|
|
625
|
+
"sklearn": _sklearn_ForestClassifier.predict_proba,
|
|
656
626
|
},
|
|
657
627
|
X,
|
|
658
628
|
)
|
|
659
629
|
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
630
|
+
def predict_log_proba(self, X):
|
|
631
|
+
xp, _ = get_namespace(X)
|
|
632
|
+
proba = self.predict_proba(X)
|
|
633
|
+
|
|
634
|
+
if self.n_outputs_ == 1:
|
|
635
|
+
return xp.log(proba)
|
|
636
|
+
|
|
637
|
+
else:
|
|
638
|
+
for k in range(self.n_outputs_):
|
|
639
|
+
proba[k] = xp.log(proba[k])
|
|
640
|
+
|
|
641
|
+
return proba
|
|
642
|
+
|
|
643
|
+
@wrap_output_data
|
|
644
|
+
def score(self, X, y, sample_weight=None):
|
|
645
|
+
check_is_fitted(self)
|
|
646
|
+
return dispatch(
|
|
647
|
+
self,
|
|
648
|
+
"score",
|
|
649
|
+
{
|
|
650
|
+
"onedal": self.__class__._onedal_score,
|
|
651
|
+
"sklearn": _sklearn_ForestClassifier.score,
|
|
652
|
+
},
|
|
653
|
+
X,
|
|
654
|
+
y,
|
|
655
|
+
sample_weight=sample_weight,
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
fit.__doc__ = _sklearn_ForestClassifier.fit.__doc__
|
|
659
|
+
predict.__doc__ = _sklearn_ForestClassifier.predict.__doc__
|
|
660
|
+
predict_proba.__doc__ = _sklearn_ForestClassifier.predict_proba.__doc__
|
|
661
|
+
predict_log_proba.__doc__ = _sklearn_ForestClassifier.predict_log_proba.__doc__
|
|
662
|
+
score.__doc__ = _sklearn_ForestClassifier.score.__doc__
|
|
663
663
|
|
|
664
664
|
def _onedal_cpu_supported(self, method_name, *data):
|
|
665
665
|
class_name = self.__class__.__name__
|
|
@@ -686,7 +686,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
686
686
|
]
|
|
687
687
|
)
|
|
688
688
|
|
|
689
|
-
elif method_name in ["predict", "predict_proba"]:
|
|
689
|
+
elif method_name in ["predict", "predict_proba", "score"]:
|
|
690
690
|
X = data[0]
|
|
691
691
|
|
|
692
692
|
patching_status.and_conditions(
|
|
@@ -747,11 +747,15 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
747
747
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
748
748
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
749
749
|
),
|
|
750
|
-
(
|
|
750
|
+
(
|
|
751
|
+
not self.oob_score,
|
|
752
|
+
"oob_scores using r2 or accuracy not implemented.",
|
|
753
|
+
),
|
|
754
|
+
(sample_weight is None, "sample_weight is not supported."),
|
|
751
755
|
]
|
|
752
756
|
)
|
|
753
757
|
|
|
754
|
-
elif method_name in ["predict", "predict_proba"]:
|
|
758
|
+
elif method_name in ["predict", "predict_proba", "score"]:
|
|
755
759
|
X = data[0]
|
|
756
760
|
|
|
757
761
|
patching_status.and_conditions(
|
|
@@ -786,31 +790,68 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
786
790
|
return patching_status
|
|
787
791
|
|
|
788
792
|
def _onedal_predict(self, X, queue=None):
|
|
789
|
-
X = check_array(
|
|
790
|
-
X,
|
|
791
|
-
dtype=[np.float64, np.float32],
|
|
792
|
-
force_all_finite=False,
|
|
793
|
-
) # Warning, order of dtype matters
|
|
794
|
-
check_is_fitted(self, "_onedal_estimator")
|
|
795
793
|
|
|
796
794
|
if sklearn_check_version("1.0"):
|
|
797
|
-
|
|
795
|
+
X = validate_data(
|
|
796
|
+
self,
|
|
797
|
+
X,
|
|
798
|
+
dtype=[np.float64, np.float32],
|
|
799
|
+
force_all_finite=False,
|
|
800
|
+
reset=False,
|
|
801
|
+
ensure_2d=True,
|
|
802
|
+
)
|
|
803
|
+
else:
|
|
804
|
+
X = check_array(
|
|
805
|
+
X,
|
|
806
|
+
dtype=[np.float64, np.float32],
|
|
807
|
+
force_all_finite=False,
|
|
808
|
+
) # Warning, order of dtype matters
|
|
809
|
+
if hasattr(self, "n_features_in_"):
|
|
810
|
+
try:
|
|
811
|
+
num_features = _num_features(X)
|
|
812
|
+
except TypeError:
|
|
813
|
+
num_features = _num_samples(X)
|
|
814
|
+
if num_features != self.n_features_in_:
|
|
815
|
+
raise ValueError(
|
|
816
|
+
(
|
|
817
|
+
f"X has {num_features} features, "
|
|
818
|
+
f"but {self.__class__.__name__} is expecting "
|
|
819
|
+
f"{self.n_features_in_} features as input"
|
|
820
|
+
)
|
|
821
|
+
)
|
|
822
|
+
self._check_n_features(X, reset=False)
|
|
798
823
|
|
|
799
824
|
res = self._onedal_estimator.predict(X, queue=queue)
|
|
800
825
|
return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
|
|
801
826
|
|
|
802
827
|
def _onedal_predict_proba(self, X, queue=None):
|
|
803
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
804
|
-
check_is_fitted(self, "_onedal_estimator")
|
|
805
828
|
|
|
806
|
-
if sklearn_check_version("0.23"):
|
|
807
|
-
self._check_n_features(X, reset=False)
|
|
808
829
|
if sklearn_check_version("1.0"):
|
|
809
|
-
|
|
830
|
+
X = validate_data(
|
|
831
|
+
self,
|
|
832
|
+
X,
|
|
833
|
+
dtype=[np.float64, np.float32],
|
|
834
|
+
force_all_finite=False,
|
|
835
|
+
reset=False,
|
|
836
|
+
ensure_2d=True,
|
|
837
|
+
)
|
|
838
|
+
else:
|
|
839
|
+
X = check_array(
|
|
840
|
+
X,
|
|
841
|
+
dtype=[np.float64, np.float32],
|
|
842
|
+
force_all_finite=False,
|
|
843
|
+
) # Warning, order of dtype matters
|
|
844
|
+
self._check_n_features(X, reset=False)
|
|
845
|
+
|
|
810
846
|
return self._onedal_estimator.predict_proba(X, queue=queue)
|
|
811
847
|
|
|
848
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
849
|
+
return accuracy_score(
|
|
850
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
851
|
+
)
|
|
852
|
+
|
|
812
853
|
|
|
813
|
-
class ForestRegressor(
|
|
854
|
+
class ForestRegressor(_sklearn_ForestRegressor, BaseForest):
|
|
814
855
|
_err = "out_of_bag_error_r2|out_of_bag_error_prediction"
|
|
815
856
|
_get_tree_state = staticmethod(get_tree_state_reg)
|
|
816
857
|
|
|
@@ -843,14 +884,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
843
884
|
|
|
844
885
|
# The splitter is checked against the class attribute for conformance
|
|
845
886
|
# This should only trigger if the user uses this class directly.
|
|
846
|
-
if (
|
|
847
|
-
self.
|
|
848
|
-
and self._onedal_factory != onedal_RandomForestRegressor
|
|
887
|
+
if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
|
|
888
|
+
self._onedal_factory, onedal_RandomForestRegressor
|
|
849
889
|
):
|
|
850
890
|
self._onedal_factory = onedal_RandomForestRegressor
|
|
851
|
-
elif (
|
|
852
|
-
self.
|
|
853
|
-
and self._onedal_factory != onedal_ExtraTreesRegressor
|
|
891
|
+
elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
|
|
892
|
+
self._onedal_factory, onedal_ExtraTreesRegressor
|
|
854
893
|
):
|
|
855
894
|
self._onedal_factory = onedal_ExtraTreesRegressor
|
|
856
895
|
|
|
@@ -920,18 +959,14 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
920
959
|
)
|
|
921
960
|
|
|
922
961
|
if patching_status.get_status():
|
|
923
|
-
|
|
924
|
-
X,
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
)
|
|
932
|
-
else:
|
|
933
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
934
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
962
|
+
X, y = check_X_y(
|
|
963
|
+
X,
|
|
964
|
+
y,
|
|
965
|
+
multi_output=True,
|
|
966
|
+
accept_sparse=True,
|
|
967
|
+
dtype=[np.float64, np.float32],
|
|
968
|
+
force_all_finite=False,
|
|
969
|
+
)
|
|
935
970
|
|
|
936
971
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
937
972
|
warnings.warn(
|
|
@@ -1006,7 +1041,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1006
1041
|
]
|
|
1007
1042
|
)
|
|
1008
1043
|
|
|
1009
|
-
elif method_name
|
|
1044
|
+
elif method_name in ["predict", "score"]:
|
|
1010
1045
|
X = data[0]
|
|
1011
1046
|
|
|
1012
1047
|
patching_status.and_conditions(
|
|
@@ -1056,11 +1091,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1056
1091
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
1057
1092
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
1058
1093
|
),
|
|
1059
|
-
(
|
|
1094
|
+
(not self.oob_score, "oob_score value is not sklearn conformant."),
|
|
1095
|
+
(sample_weight is None, "sample_weight is not supported."),
|
|
1060
1096
|
]
|
|
1061
1097
|
)
|
|
1062
1098
|
|
|
1063
|
-
elif method_name
|
|
1099
|
+
elif method_name in ["predict", "score"]:
|
|
1064
1100
|
X = data[0]
|
|
1065
1101
|
|
|
1066
1102
|
patching_status.and_conditions(
|
|
@@ -1093,23 +1129,36 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1093
1129
|
return patching_status
|
|
1094
1130
|
|
|
1095
1131
|
def _onedal_predict(self, X, queue=None):
|
|
1096
|
-
X = check_array(
|
|
1097
|
-
X, dtype=[np.float64, np.float32], force_all_finite=False
|
|
1098
|
-
) # Warning, order of dtype matters
|
|
1099
1132
|
check_is_fitted(self, "_onedal_estimator")
|
|
1100
1133
|
|
|
1101
1134
|
if sklearn_check_version("1.0"):
|
|
1102
|
-
|
|
1135
|
+
X = validate_data(
|
|
1136
|
+
self,
|
|
1137
|
+
X,
|
|
1138
|
+
dtype=[np.float64, np.float32],
|
|
1139
|
+
force_all_finite=False,
|
|
1140
|
+
reset=False,
|
|
1141
|
+
ensure_2d=True,
|
|
1142
|
+
) # Warning, order of dtype matters
|
|
1143
|
+
else:
|
|
1144
|
+
X = check_array(
|
|
1145
|
+
X, dtype=[np.float64, np.float32], force_all_finite=False
|
|
1146
|
+
) # Warning, order of dtype matters
|
|
1103
1147
|
|
|
1104
1148
|
return self._onedal_estimator.predict(X, queue=queue)
|
|
1105
1149
|
|
|
1150
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
1151
|
+
return r2_score(
|
|
1152
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
1153
|
+
)
|
|
1154
|
+
|
|
1106
1155
|
def fit(self, X, y, sample_weight=None):
|
|
1107
1156
|
dispatch(
|
|
1108
1157
|
self,
|
|
1109
1158
|
"fit",
|
|
1110
1159
|
{
|
|
1111
1160
|
"onedal": self.__class__._onedal_fit,
|
|
1112
|
-
"sklearn":
|
|
1161
|
+
"sklearn": _sklearn_ForestRegressor.fit,
|
|
1113
1162
|
},
|
|
1114
1163
|
X,
|
|
1115
1164
|
y,
|
|
@@ -1119,28 +1168,46 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1119
1168
|
|
|
1120
1169
|
@wrap_output_data
|
|
1121
1170
|
def predict(self, X):
|
|
1171
|
+
check_is_fitted(self)
|
|
1122
1172
|
return dispatch(
|
|
1123
1173
|
self,
|
|
1124
1174
|
"predict",
|
|
1125
1175
|
{
|
|
1126
1176
|
"onedal": self.__class__._onedal_predict,
|
|
1127
|
-
"sklearn":
|
|
1177
|
+
"sklearn": _sklearn_ForestRegressor.predict,
|
|
1128
1178
|
},
|
|
1129
1179
|
X,
|
|
1130
1180
|
)
|
|
1131
1181
|
|
|
1132
|
-
|
|
1133
|
-
|
|
1182
|
+
@wrap_output_data
|
|
1183
|
+
def score(self, X, y, sample_weight=None):
|
|
1184
|
+
check_is_fitted(self)
|
|
1185
|
+
return dispatch(
|
|
1186
|
+
self,
|
|
1187
|
+
"score",
|
|
1188
|
+
{
|
|
1189
|
+
"onedal": self.__class__._onedal_score,
|
|
1190
|
+
"sklearn": _sklearn_ForestRegressor.score,
|
|
1191
|
+
},
|
|
1192
|
+
X,
|
|
1193
|
+
y,
|
|
1194
|
+
sample_weight=sample_weight,
|
|
1195
|
+
)
|
|
1196
|
+
|
|
1197
|
+
fit.__doc__ = _sklearn_ForestRegressor.fit.__doc__
|
|
1198
|
+
predict.__doc__ = _sklearn_ForestRegressor.predict.__doc__
|
|
1199
|
+
score.__doc__ = _sklearn_ForestRegressor.score.__doc__
|
|
1134
1200
|
|
|
1135
1201
|
|
|
1136
|
-
@
|
|
1202
|
+
@register_hyperparameters({"infer": get_hyperparameters("decision_forest", "infer")})
|
|
1203
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1137
1204
|
class RandomForestClassifier(ForestClassifier):
|
|
1138
|
-
__doc__ =
|
|
1205
|
+
__doc__ = _sklearn_RandomForestClassifier.__doc__
|
|
1139
1206
|
_onedal_factory = onedal_RandomForestClassifier
|
|
1140
1207
|
|
|
1141
1208
|
if sklearn_check_version("1.2"):
|
|
1142
1209
|
_parameter_constraints: dict = {
|
|
1143
|
-
**
|
|
1210
|
+
**_sklearn_RandomForestClassifier._parameter_constraints,
|
|
1144
1211
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1145
1212
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1146
1213
|
}
|
|
@@ -1343,14 +1410,14 @@ class RandomForestClassifier(ForestClassifier):
|
|
|
1343
1410
|
self.min_bin_size = min_bin_size
|
|
1344
1411
|
|
|
1345
1412
|
|
|
1346
|
-
@control_n_jobs(decorated_methods=["fit", "predict"])
|
|
1413
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
|
|
1347
1414
|
class RandomForestRegressor(ForestRegressor):
|
|
1348
|
-
__doc__ =
|
|
1415
|
+
__doc__ = _sklearn_RandomForestRegressor.__doc__
|
|
1349
1416
|
_onedal_factory = onedal_RandomForestRegressor
|
|
1350
1417
|
|
|
1351
1418
|
if sklearn_check_version("1.2"):
|
|
1352
1419
|
_parameter_constraints: dict = {
|
|
1353
|
-
**
|
|
1420
|
+
**_sklearn_RandomForestRegressor._parameter_constraints,
|
|
1354
1421
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1355
1422
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1356
1423
|
}
|
|
@@ -1544,14 +1611,14 @@ class RandomForestRegressor(ForestRegressor):
|
|
|
1544
1611
|
self.min_bin_size = min_bin_size
|
|
1545
1612
|
|
|
1546
1613
|
|
|
1547
|
-
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
|
|
1614
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1548
1615
|
class ExtraTreesClassifier(ForestClassifier):
|
|
1549
|
-
__doc__ =
|
|
1616
|
+
__doc__ = _sklearn_ExtraTreesClassifier.__doc__
|
|
1550
1617
|
_onedal_factory = onedal_ExtraTreesClassifier
|
|
1551
1618
|
|
|
1552
1619
|
if sklearn_check_version("1.2"):
|
|
1553
1620
|
_parameter_constraints: dict = {
|
|
1554
|
-
**
|
|
1621
|
+
**_sklearn_ExtraTreesClassifier._parameter_constraints,
|
|
1555
1622
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1556
1623
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1557
1624
|
}
|
|
@@ -1754,14 +1821,14 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1754
1821
|
self.min_bin_size = min_bin_size
|
|
1755
1822
|
|
|
1756
1823
|
|
|
1757
|
-
@control_n_jobs(decorated_methods=["fit", "predict"])
|
|
1824
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
|
|
1758
1825
|
class ExtraTreesRegressor(ForestRegressor):
|
|
1759
|
-
__doc__ =
|
|
1826
|
+
__doc__ = _sklearn_ExtraTreesRegressor.__doc__
|
|
1760
1827
|
_onedal_factory = onedal_ExtraTreesRegressor
|
|
1761
1828
|
|
|
1762
1829
|
if sklearn_check_version("1.2"):
|
|
1763
1830
|
_parameter_constraints: dict = {
|
|
1764
|
-
**
|
|
1831
|
+
**_sklearn_ExtraTreesRegressor._parameter_constraints,
|
|
1765
1832
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1766
1833
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1767
1834
|
}
|
|
@@ -1956,7 +2023,7 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1956
2023
|
|
|
1957
2024
|
|
|
1958
2025
|
# Allow for isinstance calls without inheritance changes using ABCMeta
|
|
1959
|
-
|
|
1960
|
-
|
|
1961
|
-
|
|
1962
|
-
|
|
2026
|
+
_sklearn_RandomForestClassifier.register(RandomForestClassifier)
|
|
2027
|
+
_sklearn_RandomForestRegressor.register(RandomForestRegressor)
|
|
2028
|
+
_sklearn_ExtraTreesClassifier.register(ExtraTreesClassifier)
|
|
2029
|
+
_sklearn_ExtraTreesRegressor.register(ExtraTreesRegressor)
|