scikit-learn-intelex 2024.4.0__py312-none-win_amd64.whl → 2025.10.0__py312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp312-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp312-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn}/decomposition/__init__.py +2 -2
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn}/linear_model/__init__.py +29 -28
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +3 -3
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +4 -2
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
- {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/covariance → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils}/__init__.py +5 -3
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp312-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp312-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
- {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal}/neighbors/__init__.py +19 -19
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/__init__.py +7 -3
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/__main__.py +2 -2
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +128 -78
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +101 -32
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +38 -29
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/conftest.py +20 -1
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +199 -21
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +207 -2
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -17
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +288 -440
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +1 -1
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +17 -3
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +11 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +3 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +3 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +30 -62
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +56 -9
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +45 -101
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +63 -94
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +49 -25
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +6 -4
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +54 -8
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +9 -4
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +3 -4
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +6 -4
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +7 -4
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +1 -3
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +99 -117
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +55 -16
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +95 -113
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +51 -16
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +43 -20
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +5 -4
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +122 -75
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/validation.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -1
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
- scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
- scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
- {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2025.10.0.dist-info}/WHEEL +1 -1
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/_config.py +0 -110
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -250
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/_utils.py +0 -109
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -130
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -143
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -335
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -56
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -113
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -316
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -17
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +0 -385
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -117
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -91
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -26
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -303
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -133
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -50
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -71
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -185
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +0 -164
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -39
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -227
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -99
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -20
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py +0 -97
- scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -59
- scikit_learn_intelex-2024.4.0.dist-info/METADATA +0 -230
- scikit_learn_intelex-2024.4.0.dist-info/RECORD +0 -101
- {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal}/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
- {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2025.10.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2025.10.0.dist-info}/top_level.txt +0 -0
|
@@ -17,19 +17,20 @@
|
|
|
17
17
|
import numbers
|
|
18
18
|
import warnings
|
|
19
19
|
from abc import ABC
|
|
20
|
+
from collections.abc import Iterable
|
|
20
21
|
|
|
21
22
|
import numpy as np
|
|
22
23
|
from scipy import sparse as sp
|
|
23
|
-
from sklearn.base import clone
|
|
24
|
-
from sklearn.ensemble import ExtraTreesClassifier as
|
|
25
|
-
from sklearn.ensemble import ExtraTreesRegressor as
|
|
26
|
-
from sklearn.ensemble import RandomForestClassifier as
|
|
27
|
-
from sklearn.ensemble import RandomForestRegressor as
|
|
28
|
-
from sklearn.ensemble._forest import ForestClassifier as
|
|
29
|
-
from sklearn.ensemble._forest import ForestRegressor as
|
|
24
|
+
from sklearn.base import BaseEstimator, clone
|
|
25
|
+
from sklearn.ensemble import ExtraTreesClassifier as _sklearn_ExtraTreesClassifier
|
|
26
|
+
from sklearn.ensemble import ExtraTreesRegressor as _sklearn_ExtraTreesRegressor
|
|
27
|
+
from sklearn.ensemble import RandomForestClassifier as _sklearn_RandomForestClassifier
|
|
28
|
+
from sklearn.ensemble import RandomForestRegressor as _sklearn_RandomForestRegressor
|
|
29
|
+
from sklearn.ensemble._forest import ForestClassifier as _sklearn_ForestClassifier
|
|
30
|
+
from sklearn.ensemble._forest import ForestRegressor as _sklearn_ForestRegressor
|
|
30
31
|
from sklearn.ensemble._forest import _get_n_samples_bootstrap
|
|
31
32
|
from sklearn.exceptions import DataConversionWarning
|
|
32
|
-
from sklearn.metrics import accuracy_score
|
|
33
|
+
from sklearn.metrics import accuracy_score, r2_score
|
|
33
34
|
from sklearn.tree import (
|
|
34
35
|
DecisionTreeClassifier,
|
|
35
36
|
DecisionTreeRegressor,
|
|
@@ -38,7 +39,12 @@ from sklearn.tree import (
|
|
|
38
39
|
)
|
|
39
40
|
from sklearn.tree._tree import Tree
|
|
40
41
|
from sklearn.utils import check_random_state, deprecated
|
|
41
|
-
from sklearn.utils.validation import
|
|
42
|
+
from sklearn.utils.validation import (
|
|
43
|
+
_check_sample_weight,
|
|
44
|
+
check_array,
|
|
45
|
+
check_is_fitted,
|
|
46
|
+
check_X_y,
|
|
47
|
+
)
|
|
42
48
|
|
|
43
49
|
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
44
50
|
from daal4py.sklearn._utils import (
|
|
@@ -46,16 +52,21 @@ from daal4py.sklearn._utils import (
|
|
|
46
52
|
daal_check_version,
|
|
47
53
|
sklearn_check_version,
|
|
48
54
|
)
|
|
55
|
+
from onedal._device_offload import support_input_format
|
|
49
56
|
from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
|
|
50
57
|
from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
|
|
51
58
|
from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
|
|
52
59
|
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
|
|
53
60
|
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
54
|
-
from onedal.utils import _num_features, _num_samples
|
|
55
|
-
from sklearnex.
|
|
61
|
+
from onedal.utils.validation import _num_features, _num_samples
|
|
62
|
+
from sklearnex._utils import register_hyperparameters
|
|
56
63
|
|
|
64
|
+
from .._config import get_config
|
|
57
65
|
from .._device_offload import dispatch, wrap_output_data
|
|
58
66
|
from .._utils import PatchingConditionsChain
|
|
67
|
+
from ..base import oneDALEstimator
|
|
68
|
+
from ..utils._array_api import get_namespace
|
|
69
|
+
from ..utils.validation import check_n_features, validate_data
|
|
59
70
|
|
|
60
71
|
if sklearn_check_version("1.2"):
|
|
61
72
|
from sklearn.utils._param_validation import Interval
|
|
@@ -63,21 +74,26 @@ if sklearn_check_version("1.4"):
|
|
|
63
74
|
from daal4py.sklearn.utils import _assert_all_finite
|
|
64
75
|
|
|
65
76
|
|
|
66
|
-
class BaseForest(ABC):
|
|
77
|
+
class BaseForest(oneDALEstimator, ABC):
|
|
67
78
|
_onedal_factory = None
|
|
68
79
|
|
|
69
80
|
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
81
|
+
use_raw_input = get_config().get("use_raw_input", False) is True
|
|
82
|
+
xp, _ = get_namespace(X)
|
|
83
|
+
if not use_raw_input:
|
|
84
|
+
X, y = validate_data(
|
|
85
|
+
self,
|
|
86
|
+
X,
|
|
87
|
+
y,
|
|
88
|
+
multi_output=True,
|
|
89
|
+
accept_sparse=False,
|
|
90
|
+
dtype=[np.float64, np.float32],
|
|
91
|
+
ensure_all_finite=False,
|
|
92
|
+
ensure_2d=True,
|
|
93
|
+
)
|
|
78
94
|
|
|
79
|
-
|
|
80
|
-
|
|
95
|
+
if sample_weight is not None:
|
|
96
|
+
sample_weight = _check_sample_weight(sample_weight, X)
|
|
81
97
|
|
|
82
98
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
83
99
|
warnings.warn(
|
|
@@ -91,21 +107,30 @@ class BaseForest(ABC):
|
|
|
91
107
|
if y.ndim == 1:
|
|
92
108
|
# reshape is necessary to preserve the data contiguity against vs
|
|
93
109
|
# [:, np.newaxis] that does not.
|
|
94
|
-
y =
|
|
110
|
+
y = xp.reshape(y, (-1, 1))
|
|
95
111
|
|
|
96
112
|
self._n_samples, self.n_outputs_ = y.shape
|
|
97
113
|
|
|
98
|
-
|
|
114
|
+
if not use_raw_input:
|
|
115
|
+
y, expanded_class_weight = self._validate_y_class_weight(y)
|
|
99
116
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
117
|
+
if expanded_class_weight is not None:
|
|
118
|
+
if sample_weight is not None:
|
|
119
|
+
sample_weight = sample_weight * expanded_class_weight
|
|
120
|
+
else:
|
|
121
|
+
sample_weight = expanded_class_weight
|
|
103
122
|
if sample_weight is not None:
|
|
104
|
-
sample_weight = sample_weight
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
123
|
+
sample_weight = [sample_weight]
|
|
124
|
+
else:
|
|
125
|
+
# try catch needed for raw_inputs + array_api data where unlike
|
|
126
|
+
# numpy the way to yield unique values is via `unique_values`
|
|
127
|
+
# This should be removed when refactored for gpu zero-copy
|
|
128
|
+
try:
|
|
129
|
+
self.classes_ = xp.unique(y)
|
|
130
|
+
except AttributeError:
|
|
131
|
+
self.classes_ = xp.unique_values(y)
|
|
132
|
+
self.n_classes_ = len(self.classes_)
|
|
133
|
+
self.n_features_in_ = X.shape[1]
|
|
109
134
|
|
|
110
135
|
onedal_params = {
|
|
111
136
|
"n_estimators": self.n_estimators,
|
|
@@ -114,7 +139,9 @@ class BaseForest(ABC):
|
|
|
114
139
|
"min_samples_split": self.min_samples_split,
|
|
115
140
|
"min_samples_leaf": self.min_samples_leaf,
|
|
116
141
|
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
117
|
-
"max_features": self.
|
|
142
|
+
"max_features": self._to_absolute_max_features(
|
|
143
|
+
self.max_features, self.n_features_in_
|
|
144
|
+
),
|
|
118
145
|
"max_leaf_nodes": self.max_leaf_nodes,
|
|
119
146
|
"min_impurity_decrease": self.min_impurity_decrease,
|
|
120
147
|
"bootstrap": self.bootstrap,
|
|
@@ -131,24 +158,29 @@ class BaseForest(ABC):
|
|
|
131
158
|
"max_samples": self.max_samples,
|
|
132
159
|
}
|
|
133
160
|
|
|
134
|
-
|
|
135
|
-
onedal_params["min_impurity_split"] = self.min_impurity_split
|
|
136
|
-
else:
|
|
137
|
-
onedal_params["min_impurity_split"] = None
|
|
161
|
+
onedal_params["min_impurity_split"] = None
|
|
138
162
|
|
|
139
163
|
# Lazy evaluation of estimators_
|
|
140
164
|
self._cached_estimators_ = None
|
|
141
165
|
|
|
142
166
|
# Compute
|
|
143
167
|
self._onedal_estimator = self._onedal_factory(**onedal_params)
|
|
144
|
-
self._onedal_estimator.fit(X,
|
|
168
|
+
self._onedal_estimator.fit(X, xp.reshape(y, (-1,)), sample_weight, queue=queue)
|
|
145
169
|
|
|
146
170
|
self._save_attributes()
|
|
147
171
|
|
|
148
172
|
# Decapsulate classes_ attributes
|
|
149
173
|
if hasattr(self, "classes_") and self.n_outputs_ == 1:
|
|
150
|
-
self.n_classes_ =
|
|
151
|
-
|
|
174
|
+
self.n_classes_ = (
|
|
175
|
+
self.n_classes_[0]
|
|
176
|
+
if isinstance(self.n_classes_, Iterable)
|
|
177
|
+
else self.n_classes_
|
|
178
|
+
)
|
|
179
|
+
self.classes_ = (
|
|
180
|
+
self.classes_[0]
|
|
181
|
+
if isinstance(self.classes_[0], Iterable)
|
|
182
|
+
else self.classes_
|
|
183
|
+
)
|
|
152
184
|
|
|
153
185
|
return self
|
|
154
186
|
|
|
@@ -174,6 +206,45 @@ class BaseForest(ABC):
|
|
|
174
206
|
self._validate_estimator()
|
|
175
207
|
return self
|
|
176
208
|
|
|
209
|
+
def _to_absolute_max_features(self, max_features, n_features):
|
|
210
|
+
if max_features is None:
|
|
211
|
+
return n_features
|
|
212
|
+
if isinstance(max_features, str):
|
|
213
|
+
if max_features == "auto":
|
|
214
|
+
if not sklearn_check_version("1.3"):
|
|
215
|
+
if sklearn_check_version("1.1"):
|
|
216
|
+
warnings.warn(
|
|
217
|
+
"`max_features='auto'` has been deprecated in 1.1 "
|
|
218
|
+
"and will be removed in 1.3. To keep the past behaviour, "
|
|
219
|
+
"explicitly set `max_features=1.0` or remove this "
|
|
220
|
+
"parameter as it is also the default value for "
|
|
221
|
+
"RandomForestRegressors and ExtraTreesRegressors.",
|
|
222
|
+
FutureWarning,
|
|
223
|
+
)
|
|
224
|
+
return (
|
|
225
|
+
max(1, int(np.sqrt(n_features)))
|
|
226
|
+
if isinstance(self, ForestClassifier)
|
|
227
|
+
else n_features
|
|
228
|
+
)
|
|
229
|
+
if max_features == "sqrt":
|
|
230
|
+
return max(1, int(np.sqrt(n_features)))
|
|
231
|
+
if max_features == "log2":
|
|
232
|
+
return max(1, int(np.log2(n_features)))
|
|
233
|
+
allowed_string_values = (
|
|
234
|
+
'"sqrt" or "log2"'
|
|
235
|
+
if sklearn_check_version("1.3")
|
|
236
|
+
else '"auto", "sqrt" or "log2"'
|
|
237
|
+
)
|
|
238
|
+
raise ValueError(
|
|
239
|
+
"Invalid value for max_features. Allowed string "
|
|
240
|
+
f"values are {allowed_string_values}."
|
|
241
|
+
)
|
|
242
|
+
if isinstance(max_features, (numbers.Integral, np.integer)):
|
|
243
|
+
return max_features
|
|
244
|
+
if max_features > 0.0:
|
|
245
|
+
return max(1, int(max_features * n_features))
|
|
246
|
+
return 0
|
|
247
|
+
|
|
177
248
|
def _check_parameters(self):
|
|
178
249
|
if isinstance(self.min_samples_leaf, numbers.Integral):
|
|
179
250
|
if not 1 <= self.min_samples_leaf:
|
|
@@ -249,38 +320,6 @@ class BaseForest(ABC):
|
|
|
249
320
|
"min_bin_size must be integral number but was " "%r" % self.min_bin_size
|
|
250
321
|
)
|
|
251
322
|
|
|
252
|
-
def check_sample_weight(self, sample_weight, X, dtype=None):
|
|
253
|
-
n_samples = _num_samples(X)
|
|
254
|
-
|
|
255
|
-
if dtype is not None and dtype not in [np.float32, np.float64]:
|
|
256
|
-
dtype = np.float64
|
|
257
|
-
|
|
258
|
-
if sample_weight is None:
|
|
259
|
-
sample_weight = np.ones(n_samples, dtype=dtype)
|
|
260
|
-
elif isinstance(sample_weight, numbers.Number):
|
|
261
|
-
sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
|
|
262
|
-
else:
|
|
263
|
-
if dtype is None:
|
|
264
|
-
dtype = [np.float64, np.float32]
|
|
265
|
-
sample_weight = check_array(
|
|
266
|
-
sample_weight,
|
|
267
|
-
accept_sparse=False,
|
|
268
|
-
ensure_2d=False,
|
|
269
|
-
dtype=dtype,
|
|
270
|
-
order="C",
|
|
271
|
-
force_all_finite=False,
|
|
272
|
-
)
|
|
273
|
-
if sample_weight.ndim != 1:
|
|
274
|
-
raise ValueError("Sample weights must be 1D array or scalar")
|
|
275
|
-
|
|
276
|
-
if sample_weight.shape != (n_samples,):
|
|
277
|
-
raise ValueError(
|
|
278
|
-
"sample_weight.shape == {}, expected {}!".format(
|
|
279
|
-
sample_weight.shape, (n_samples,)
|
|
280
|
-
)
|
|
281
|
-
)
|
|
282
|
-
return sample_weight
|
|
283
|
-
|
|
284
323
|
@property
|
|
285
324
|
def estimators_(self):
|
|
286
325
|
if hasattr(self, "_cached_estimators_"):
|
|
@@ -321,10 +360,8 @@ class BaseForest(ABC):
|
|
|
321
360
|
"min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
|
|
322
361
|
"random_state": None,
|
|
323
362
|
}
|
|
324
|
-
if not sklearn_check_version("1.0"):
|
|
325
|
-
params["min_impurity_split"] = self._onedal_estimator.min_impurity_split
|
|
326
363
|
est = self.estimator.__class__(**params)
|
|
327
|
-
# we need to set est.tree_ field with Trees constructed from
|
|
364
|
+
# we need to set est.tree_ field with Trees constructed from
|
|
328
365
|
# oneAPI Data Analytics Library solution
|
|
329
366
|
estimators_ = []
|
|
330
367
|
|
|
@@ -335,10 +372,7 @@ class BaseForest(ABC):
|
|
|
335
372
|
est_i.set_params(
|
|
336
373
|
random_state=random_state_checked.randint(np.iinfo(np.int32).max)
|
|
337
374
|
)
|
|
338
|
-
|
|
339
|
-
est_i.n_features_in_ = self.n_features_in_
|
|
340
|
-
else:
|
|
341
|
-
est_i.n_features_ = self.n_features_in_
|
|
375
|
+
est_i.n_features_in_ = self.n_features_in_
|
|
342
376
|
est_i.n_outputs_ = self.n_outputs_
|
|
343
377
|
est_i.n_classes_ = n_classes_
|
|
344
378
|
tree_i_state_class = self._get_tree_state(
|
|
@@ -350,6 +384,7 @@ class BaseForest(ABC):
|
|
|
350
384
|
"nodes": check_tree_nodes(tree_i_state_class.node_ar),
|
|
351
385
|
"values": tree_i_state_class.value_ar,
|
|
352
386
|
}
|
|
387
|
+
# Note: only on host.
|
|
353
388
|
est_i.tree_ = Tree(
|
|
354
389
|
self.n_features_in_,
|
|
355
390
|
np.array([n_classes_], dtype=np.intp),
|
|
@@ -360,16 +395,6 @@ class BaseForest(ABC):
|
|
|
360
395
|
|
|
361
396
|
self._cached_estimators_ = estimators_
|
|
362
397
|
|
|
363
|
-
if sklearn_check_version("1.0"):
|
|
364
|
-
|
|
365
|
-
@deprecated(
|
|
366
|
-
"Attribute `n_features_` was deprecated in version 1.0 and will be "
|
|
367
|
-
"removed in 1.2. Use `n_features_in_` instead."
|
|
368
|
-
)
|
|
369
|
-
@property
|
|
370
|
-
def n_features_(self):
|
|
371
|
-
return self.n_features_in_
|
|
372
|
-
|
|
373
398
|
if not sklearn_check_version("1.2"):
|
|
374
399
|
|
|
375
400
|
@property
|
|
@@ -381,7 +406,7 @@ class BaseForest(ABC):
|
|
|
381
406
|
self.estimator = estimator
|
|
382
407
|
|
|
383
408
|
|
|
384
|
-
class ForestClassifier(
|
|
409
|
+
class ForestClassifier(BaseForest, _sklearn_ForestClassifier):
|
|
385
410
|
# Surprisingly, even though scikit-learn warns against using
|
|
386
411
|
# their ForestClassifier directly, it actually has a more stable
|
|
387
412
|
# API than the user-facing objects (over time). If they change it
|
|
@@ -433,11 +458,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
433
458
|
if self._onedal_factory is None:
|
|
434
459
|
raise TypeError(f" oneDAL estimator has not been set.")
|
|
435
460
|
|
|
461
|
+
decision_path = support_input_format(_sklearn_ForestClassifier.decision_path)
|
|
462
|
+
apply = support_input_format(_sklearn_ForestClassifier.apply)
|
|
463
|
+
|
|
436
464
|
def _estimators_(self):
|
|
437
465
|
super()._estimators_()
|
|
438
|
-
classes_ = self.classes_[0]
|
|
439
466
|
for est in self._cached_estimators_:
|
|
440
|
-
est.classes_ = classes_
|
|
467
|
+
est.classes_ = self.classes_
|
|
441
468
|
|
|
442
469
|
def fit(self, X, y, sample_weight=None):
|
|
443
470
|
dispatch(
|
|
@@ -445,7 +472,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
445
472
|
"fit",
|
|
446
473
|
{
|
|
447
474
|
"onedal": self.__class__._onedal_fit,
|
|
448
|
-
"sklearn":
|
|
475
|
+
"sklearn": _sklearn_ForestClassifier.fit,
|
|
449
476
|
},
|
|
450
477
|
X,
|
|
451
478
|
y,
|
|
@@ -518,14 +545,24 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
518
545
|
)
|
|
519
546
|
|
|
520
547
|
if patching_status.get_status():
|
|
521
|
-
|
|
522
|
-
X,
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
548
|
+
if sklearn_check_version("1.6"):
|
|
549
|
+
X, y = check_X_y(
|
|
550
|
+
X,
|
|
551
|
+
y,
|
|
552
|
+
multi_output=True,
|
|
553
|
+
accept_sparse=True,
|
|
554
|
+
dtype=[np.float64, np.float32],
|
|
555
|
+
ensure_all_finite=False,
|
|
556
|
+
)
|
|
557
|
+
else:
|
|
558
|
+
X, y = check_X_y(
|
|
559
|
+
X,
|
|
560
|
+
y,
|
|
561
|
+
multi_output=True,
|
|
562
|
+
accept_sparse=True,
|
|
563
|
+
dtype=[np.float64, np.float32],
|
|
564
|
+
force_all_finite=False,
|
|
565
|
+
)
|
|
529
566
|
|
|
530
567
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
531
568
|
warnings.warn(
|
|
@@ -555,6 +592,19 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
555
592
|
)
|
|
556
593
|
# TODO: Fix to support integers as input
|
|
557
594
|
|
|
595
|
+
if self.n_outputs_ == 1:
|
|
596
|
+
xp, is_array_api_compliant = get_namespace(y)
|
|
597
|
+
sety = xp.unique_values(y) if is_array_api_compliant else np.unique(y)
|
|
598
|
+
num_classes = sety.shape[0]
|
|
599
|
+
patching_status.and_conditions(
|
|
600
|
+
[
|
|
601
|
+
(
|
|
602
|
+
num_classes >= 2,
|
|
603
|
+
"Number of classes must be at least 2.",
|
|
604
|
+
),
|
|
605
|
+
]
|
|
606
|
+
)
|
|
607
|
+
|
|
558
608
|
_get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
|
|
559
609
|
|
|
560
610
|
if not self.bootstrap and self.max_samples is not None:
|
|
@@ -579,12 +629,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
579
629
|
|
|
580
630
|
@wrap_output_data
|
|
581
631
|
def predict(self, X):
|
|
632
|
+
check_is_fitted(self)
|
|
582
633
|
return dispatch(
|
|
583
634
|
self,
|
|
584
635
|
"predict",
|
|
585
636
|
{
|
|
586
637
|
"onedal": self.__class__._onedal_predict,
|
|
587
|
-
"sklearn":
|
|
638
|
+
"sklearn": _sklearn_ForestClassifier.predict,
|
|
588
639
|
},
|
|
589
640
|
X,
|
|
590
641
|
)
|
|
@@ -594,27 +645,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
594
645
|
# TODO:
|
|
595
646
|
# _check_proba()
|
|
596
647
|
# self._check_proba()
|
|
597
|
-
|
|
598
|
-
self._check_feature_names(X, reset=False)
|
|
599
|
-
if hasattr(self, "n_features_in_"):
|
|
600
|
-
try:
|
|
601
|
-
num_features = _num_features(X)
|
|
602
|
-
except TypeError:
|
|
603
|
-
num_features = _num_samples(X)
|
|
604
|
-
if num_features != self.n_features_in_:
|
|
605
|
-
raise ValueError(
|
|
606
|
-
(
|
|
607
|
-
f"X has {num_features} features, "
|
|
608
|
-
f"but {self.__class__.__name__} is expecting "
|
|
609
|
-
f"{self.n_features_in_} features as input"
|
|
610
|
-
)
|
|
611
|
-
)
|
|
648
|
+
check_is_fitted(self)
|
|
612
649
|
return dispatch(
|
|
613
650
|
self,
|
|
614
651
|
"predict_proba",
|
|
615
652
|
{
|
|
616
653
|
"onedal": self.__class__._onedal_predict_proba,
|
|
617
|
-
"sklearn":
|
|
654
|
+
"sklearn": _sklearn_ForestClassifier.predict_proba,
|
|
618
655
|
},
|
|
619
656
|
X,
|
|
620
657
|
)
|
|
@@ -634,23 +671,24 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
634
671
|
|
|
635
672
|
@wrap_output_data
|
|
636
673
|
def score(self, X, y, sample_weight=None):
|
|
674
|
+
check_is_fitted(self)
|
|
637
675
|
return dispatch(
|
|
638
676
|
self,
|
|
639
677
|
"score",
|
|
640
678
|
{
|
|
641
679
|
"onedal": self.__class__._onedal_score,
|
|
642
|
-
"sklearn":
|
|
680
|
+
"sklearn": _sklearn_ForestClassifier.score,
|
|
643
681
|
},
|
|
644
682
|
X,
|
|
645
683
|
y,
|
|
646
684
|
sample_weight=sample_weight,
|
|
647
685
|
)
|
|
648
686
|
|
|
649
|
-
fit.__doc__ =
|
|
650
|
-
predict.__doc__ =
|
|
651
|
-
predict_proba.__doc__ =
|
|
652
|
-
predict_log_proba.__doc__ =
|
|
653
|
-
score.__doc__ =
|
|
687
|
+
fit.__doc__ = _sklearn_ForestClassifier.fit.__doc__
|
|
688
|
+
predict.__doc__ = _sklearn_ForestClassifier.predict.__doc__
|
|
689
|
+
predict_proba.__doc__ = _sklearn_ForestClassifier.predict_proba.__doc__
|
|
690
|
+
predict_log_proba.__doc__ = _sklearn_ForestClassifier.predict_log_proba.__doc__
|
|
691
|
+
score.__doc__ = _sklearn_ForestClassifier.score.__doc__
|
|
654
692
|
|
|
655
693
|
def _onedal_cpu_supported(self, method_name, *data):
|
|
656
694
|
class_name = self.__class__.__name__
|
|
@@ -738,6 +776,10 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
738
776
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
739
777
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
740
778
|
),
|
|
779
|
+
(
|
|
780
|
+
not self.oob_score,
|
|
781
|
+
"oob_scores using r2 or accuracy not implemented.",
|
|
782
|
+
),
|
|
741
783
|
(sample_weight is None, "sample_weight is not supported."),
|
|
742
784
|
]
|
|
743
785
|
)
|
|
@@ -777,26 +819,52 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
777
819
|
return patching_status
|
|
778
820
|
|
|
779
821
|
def _onedal_predict(self, X, queue=None):
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
822
|
+
xp, _ = get_namespace(X)
|
|
823
|
+
if not get_config()["use_raw_input"]:
|
|
824
|
+
X = validate_data(
|
|
825
|
+
self,
|
|
826
|
+
X,
|
|
827
|
+
dtype=[np.float64, np.float32],
|
|
828
|
+
ensure_all_finite=False,
|
|
829
|
+
reset=False,
|
|
830
|
+
ensure_2d=True,
|
|
831
|
+
)
|
|
832
|
+
if hasattr(self, "n_features_in_"):
|
|
833
|
+
try:
|
|
834
|
+
num_features = _num_features(X)
|
|
835
|
+
except TypeError:
|
|
836
|
+
num_features = _num_samples(X)
|
|
837
|
+
if num_features != self.n_features_in_:
|
|
838
|
+
raise ValueError(
|
|
839
|
+
(
|
|
840
|
+
f"X has {num_features} features, "
|
|
841
|
+
f"but {self.__class__.__name__} is expecting "
|
|
842
|
+
f"{self.n_features_in_} features as input"
|
|
843
|
+
)
|
|
844
|
+
)
|
|
845
|
+
check_n_features(self, X, reset=False)
|
|
789
846
|
|
|
790
847
|
res = self._onedal_estimator.predict(X, queue=queue)
|
|
791
|
-
|
|
848
|
+
try:
|
|
849
|
+
return xp.take(
|
|
850
|
+
xp.asarray(self.classes_, device=res.sycl_queue),
|
|
851
|
+
xp.astype(xp.reshape(res, (-1,)), xp.int64),
|
|
852
|
+
)
|
|
853
|
+
except AttributeError:
|
|
854
|
+
return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
|
|
792
855
|
|
|
793
856
|
def _onedal_predict_proba(self, X, queue=None):
|
|
794
|
-
|
|
795
|
-
|
|
857
|
+
use_raw_input = get_config().get("use_raw_input", False) is True
|
|
858
|
+
if not use_raw_input:
|
|
859
|
+
X = validate_data(
|
|
860
|
+
self,
|
|
861
|
+
X,
|
|
862
|
+
dtype=[np.float64, np.float32],
|
|
863
|
+
ensure_all_finite=False,
|
|
864
|
+
reset=False,
|
|
865
|
+
ensure_2d=True,
|
|
866
|
+
)
|
|
796
867
|
|
|
797
|
-
self._check_n_features(X, reset=False)
|
|
798
|
-
if sklearn_check_version("1.0"):
|
|
799
|
-
self._check_feature_names(X, reset=False)
|
|
800
868
|
return self._onedal_estimator.predict_proba(X, queue=queue)
|
|
801
869
|
|
|
802
870
|
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
@@ -805,7 +873,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
805
873
|
)
|
|
806
874
|
|
|
807
875
|
|
|
808
|
-
class ForestRegressor(
|
|
876
|
+
class ForestRegressor(BaseForest, _sklearn_ForestRegressor):
|
|
809
877
|
_err = "out_of_bag_error_r2|out_of_bag_error_prediction"
|
|
810
878
|
_get_tree_state = staticmethod(get_tree_state_reg)
|
|
811
879
|
|
|
@@ -850,6 +918,9 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
850
918
|
if self._onedal_factory is None:
|
|
851
919
|
raise TypeError(f" oneDAL estimator has not been set.")
|
|
852
920
|
|
|
921
|
+
decision_path = support_input_format(_sklearn_ForestRegressor.decision_path)
|
|
922
|
+
apply = support_input_format(_sklearn_ForestRegressor.apply)
|
|
923
|
+
|
|
853
924
|
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
|
|
854
925
|
if sp.issparse(y):
|
|
855
926
|
raise ValueError("sparse multilabel-indicator for y is not supported.")
|
|
@@ -862,7 +933,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
862
933
|
if not self.bootstrap and self.oob_score:
|
|
863
934
|
raise ValueError("Out of bag estimation only available" " if bootstrap=True")
|
|
864
935
|
|
|
865
|
-
if sklearn_check_version("1.
|
|
936
|
+
if not sklearn_check_version("1.2") and self.criterion == "mse":
|
|
866
937
|
warnings.warn(
|
|
867
938
|
"Criterion 'mse' was deprecated in v1.0 and will be "
|
|
868
939
|
"removed in version 1.2. Use `criterion='squared_error'` "
|
|
@@ -913,14 +984,24 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
913
984
|
)
|
|
914
985
|
|
|
915
986
|
if patching_status.get_status():
|
|
916
|
-
|
|
917
|
-
X,
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
987
|
+
if sklearn_check_version("1.6"):
|
|
988
|
+
X, y = check_X_y(
|
|
989
|
+
X,
|
|
990
|
+
y,
|
|
991
|
+
multi_output=True,
|
|
992
|
+
accept_sparse=True,
|
|
993
|
+
dtype=[np.float64, np.float32],
|
|
994
|
+
ensure_all_finite=False,
|
|
995
|
+
)
|
|
996
|
+
else:
|
|
997
|
+
X, y = check_X_y(
|
|
998
|
+
X,
|
|
999
|
+
y,
|
|
1000
|
+
multi_output=True,
|
|
1001
|
+
accept_sparse=True,
|
|
1002
|
+
dtype=[np.float64, np.float32],
|
|
1003
|
+
force_all_finite=False,
|
|
1004
|
+
)
|
|
924
1005
|
|
|
925
1006
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
926
1007
|
warnings.warn(
|
|
@@ -995,7 +1076,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
995
1076
|
]
|
|
996
1077
|
)
|
|
997
1078
|
|
|
998
|
-
elif method_name
|
|
1079
|
+
elif method_name in ["predict", "score"]:
|
|
999
1080
|
X = data[0]
|
|
1000
1081
|
|
|
1001
1082
|
patching_status.and_conditions(
|
|
@@ -1045,11 +1126,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1045
1126
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
1046
1127
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
1047
1128
|
),
|
|
1129
|
+
(not self.oob_score, "oob_score value is not sklearn conformant."),
|
|
1048
1130
|
(sample_weight is None, "sample_weight is not supported."),
|
|
1049
1131
|
]
|
|
1050
1132
|
)
|
|
1051
1133
|
|
|
1052
|
-
elif method_name
|
|
1134
|
+
elif method_name in ["predict", "score"]:
|
|
1053
1135
|
X = data[0]
|
|
1054
1136
|
|
|
1055
1137
|
patching_status.and_conditions(
|
|
@@ -1082,23 +1164,33 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1082
1164
|
return patching_status
|
|
1083
1165
|
|
|
1084
1166
|
def _onedal_predict(self, X, queue=None):
|
|
1085
|
-
X = check_array(
|
|
1086
|
-
X, dtype=[np.float64, np.float32], force_all_finite=False
|
|
1087
|
-
) # Warning, order of dtype matters
|
|
1088
1167
|
check_is_fitted(self, "_onedal_estimator")
|
|
1168
|
+
use_raw_input = get_config().get("use_raw_input", False) is True
|
|
1089
1169
|
|
|
1090
|
-
if
|
|
1091
|
-
|
|
1170
|
+
if not use_raw_input:
|
|
1171
|
+
X = validate_data(
|
|
1172
|
+
self,
|
|
1173
|
+
X,
|
|
1174
|
+
dtype=[np.float64, np.float32],
|
|
1175
|
+
ensure_all_finite=False,
|
|
1176
|
+
reset=False,
|
|
1177
|
+
ensure_2d=True,
|
|
1178
|
+
) # Warning, order of dtype matters
|
|
1092
1179
|
|
|
1093
1180
|
return self._onedal_estimator.predict(X, queue=queue)
|
|
1094
1181
|
|
|
1182
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
1183
|
+
return r2_score(
|
|
1184
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
1185
|
+
)
|
|
1186
|
+
|
|
1095
1187
|
def fit(self, X, y, sample_weight=None):
|
|
1096
1188
|
dispatch(
|
|
1097
1189
|
self,
|
|
1098
1190
|
"fit",
|
|
1099
1191
|
{
|
|
1100
1192
|
"onedal": self.__class__._onedal_fit,
|
|
1101
|
-
"sklearn":
|
|
1193
|
+
"sklearn": _sklearn_ForestRegressor.fit,
|
|
1102
1194
|
},
|
|
1103
1195
|
X,
|
|
1104
1196
|
y,
|
|
@@ -1108,28 +1200,46 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1108
1200
|
|
|
1109
1201
|
@wrap_output_data
|
|
1110
1202
|
def predict(self, X):
|
|
1203
|
+
check_is_fitted(self)
|
|
1111
1204
|
return dispatch(
|
|
1112
1205
|
self,
|
|
1113
1206
|
"predict",
|
|
1114
1207
|
{
|
|
1115
1208
|
"onedal": self.__class__._onedal_predict,
|
|
1116
|
-
"sklearn":
|
|
1209
|
+
"sklearn": _sklearn_ForestRegressor.predict,
|
|
1117
1210
|
},
|
|
1118
1211
|
X,
|
|
1119
1212
|
)
|
|
1120
1213
|
|
|
1121
|
-
|
|
1122
|
-
|
|
1214
|
+
@wrap_output_data
|
|
1215
|
+
def score(self, X, y, sample_weight=None):
|
|
1216
|
+
check_is_fitted(self)
|
|
1217
|
+
return dispatch(
|
|
1218
|
+
self,
|
|
1219
|
+
"score",
|
|
1220
|
+
{
|
|
1221
|
+
"onedal": self.__class__._onedal_score,
|
|
1222
|
+
"sklearn": _sklearn_ForestRegressor.score,
|
|
1223
|
+
},
|
|
1224
|
+
X,
|
|
1225
|
+
y,
|
|
1226
|
+
sample_weight=sample_weight,
|
|
1227
|
+
)
|
|
1123
1228
|
|
|
1229
|
+
fit.__doc__ = _sklearn_ForestRegressor.fit.__doc__
|
|
1230
|
+
predict.__doc__ = _sklearn_ForestRegressor.predict.__doc__
|
|
1231
|
+
score.__doc__ = _sklearn_ForestRegressor.score.__doc__
|
|
1124
1232
|
|
|
1233
|
+
|
|
1234
|
+
@register_hyperparameters({"predict": ("decision_forest", "infer")})
|
|
1125
1235
|
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1126
1236
|
class RandomForestClassifier(ForestClassifier):
|
|
1127
|
-
__doc__ =
|
|
1237
|
+
__doc__ = _sklearn_RandomForestClassifier.__doc__
|
|
1128
1238
|
_onedal_factory = onedal_RandomForestClassifier
|
|
1129
1239
|
|
|
1130
1240
|
if sklearn_check_version("1.2"):
|
|
1131
1241
|
_parameter_constraints: dict = {
|
|
1132
|
-
**
|
|
1242
|
+
**_sklearn_RandomForestClassifier._parameter_constraints,
|
|
1133
1243
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1134
1244
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1135
1245
|
}
|
|
@@ -1200,69 +1310,6 @@ class RandomForestClassifier(ForestClassifier):
|
|
|
1200
1310
|
self.min_bin_size = min_bin_size
|
|
1201
1311
|
self.monotonic_cst = monotonic_cst
|
|
1202
1312
|
|
|
1203
|
-
elif sklearn_check_version("1.0"):
|
|
1204
|
-
|
|
1205
|
-
def __init__(
|
|
1206
|
-
self,
|
|
1207
|
-
n_estimators=100,
|
|
1208
|
-
*,
|
|
1209
|
-
criterion="gini",
|
|
1210
|
-
max_depth=None,
|
|
1211
|
-
min_samples_split=2,
|
|
1212
|
-
min_samples_leaf=1,
|
|
1213
|
-
min_weight_fraction_leaf=0.0,
|
|
1214
|
-
max_features="sqrt" if sklearn_check_version("1.1") else "auto",
|
|
1215
|
-
max_leaf_nodes=None,
|
|
1216
|
-
min_impurity_decrease=0.0,
|
|
1217
|
-
bootstrap=True,
|
|
1218
|
-
oob_score=False,
|
|
1219
|
-
n_jobs=None,
|
|
1220
|
-
random_state=None,
|
|
1221
|
-
verbose=0,
|
|
1222
|
-
warm_start=False,
|
|
1223
|
-
class_weight=None,
|
|
1224
|
-
ccp_alpha=0.0,
|
|
1225
|
-
max_samples=None,
|
|
1226
|
-
max_bins=256,
|
|
1227
|
-
min_bin_size=1,
|
|
1228
|
-
):
|
|
1229
|
-
super().__init__(
|
|
1230
|
-
DecisionTreeClassifier(),
|
|
1231
|
-
n_estimators,
|
|
1232
|
-
estimator_params=(
|
|
1233
|
-
"criterion",
|
|
1234
|
-
"max_depth",
|
|
1235
|
-
"min_samples_split",
|
|
1236
|
-
"min_samples_leaf",
|
|
1237
|
-
"min_weight_fraction_leaf",
|
|
1238
|
-
"max_features",
|
|
1239
|
-
"max_leaf_nodes",
|
|
1240
|
-
"min_impurity_decrease",
|
|
1241
|
-
"random_state",
|
|
1242
|
-
"ccp_alpha",
|
|
1243
|
-
),
|
|
1244
|
-
bootstrap=bootstrap,
|
|
1245
|
-
oob_score=oob_score,
|
|
1246
|
-
n_jobs=n_jobs,
|
|
1247
|
-
random_state=random_state,
|
|
1248
|
-
verbose=verbose,
|
|
1249
|
-
warm_start=warm_start,
|
|
1250
|
-
class_weight=class_weight,
|
|
1251
|
-
max_samples=max_samples,
|
|
1252
|
-
)
|
|
1253
|
-
|
|
1254
|
-
self.criterion = criterion
|
|
1255
|
-
self.max_depth = max_depth
|
|
1256
|
-
self.min_samples_split = min_samples_split
|
|
1257
|
-
self.min_samples_leaf = min_samples_leaf
|
|
1258
|
-
self.min_weight_fraction_leaf = min_weight_fraction_leaf
|
|
1259
|
-
self.max_features = max_features
|
|
1260
|
-
self.max_leaf_nodes = max_leaf_nodes
|
|
1261
|
-
self.min_impurity_decrease = min_impurity_decrease
|
|
1262
|
-
self.ccp_alpha = ccp_alpha
|
|
1263
|
-
self.max_bins = max_bins
|
|
1264
|
-
self.min_bin_size = min_bin_size
|
|
1265
|
-
|
|
1266
1313
|
else:
|
|
1267
1314
|
|
|
1268
1315
|
def __init__(
|
|
@@ -1274,10 +1321,9 @@ class RandomForestClassifier(ForestClassifier):
|
|
|
1274
1321
|
min_samples_split=2,
|
|
1275
1322
|
min_samples_leaf=1,
|
|
1276
1323
|
min_weight_fraction_leaf=0.0,
|
|
1277
|
-
max_features="auto",
|
|
1324
|
+
max_features="sqrt" if sklearn_check_version("1.1") else "auto",
|
|
1278
1325
|
max_leaf_nodes=None,
|
|
1279
1326
|
min_impurity_decrease=0.0,
|
|
1280
|
-
min_impurity_split=None,
|
|
1281
1327
|
bootstrap=True,
|
|
1282
1328
|
oob_score=False,
|
|
1283
1329
|
n_jobs=None,
|
|
@@ -1302,7 +1348,6 @@ class RandomForestClassifier(ForestClassifier):
|
|
|
1302
1348
|
"max_features",
|
|
1303
1349
|
"max_leaf_nodes",
|
|
1304
1350
|
"min_impurity_decrease",
|
|
1305
|
-
"min_impurity_split",
|
|
1306
1351
|
"random_state",
|
|
1307
1352
|
"ccp_alpha",
|
|
1308
1353
|
),
|
|
@@ -1324,22 +1369,19 @@ class RandomForestClassifier(ForestClassifier):
|
|
|
1324
1369
|
self.max_features = max_features
|
|
1325
1370
|
self.max_leaf_nodes = max_leaf_nodes
|
|
1326
1371
|
self.min_impurity_decrease = min_impurity_decrease
|
|
1327
|
-
self.min_impurity_split = min_impurity_split
|
|
1328
1372
|
self.ccp_alpha = ccp_alpha
|
|
1329
1373
|
self.max_bins = max_bins
|
|
1330
1374
|
self.min_bin_size = min_bin_size
|
|
1331
|
-
self.max_bins = max_bins
|
|
1332
|
-
self.min_bin_size = min_bin_size
|
|
1333
1375
|
|
|
1334
1376
|
|
|
1335
|
-
@control_n_jobs(decorated_methods=["fit", "predict"])
|
|
1377
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
|
|
1336
1378
|
class RandomForestRegressor(ForestRegressor):
|
|
1337
|
-
__doc__ =
|
|
1379
|
+
__doc__ = _sklearn_RandomForestRegressor.__doc__
|
|
1338
1380
|
_onedal_factory = onedal_RandomForestRegressor
|
|
1339
1381
|
|
|
1340
1382
|
if sklearn_check_version("1.2"):
|
|
1341
1383
|
_parameter_constraints: dict = {
|
|
1342
|
-
**
|
|
1384
|
+
**_sklearn_RandomForestRegressor._parameter_constraints,
|
|
1343
1385
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1344
1386
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1345
1387
|
}
|
|
@@ -1408,7 +1450,7 @@ class RandomForestRegressor(ForestRegressor):
|
|
|
1408
1450
|
self.min_bin_size = min_bin_size
|
|
1409
1451
|
self.monotonic_cst = monotonic_cst
|
|
1410
1452
|
|
|
1411
|
-
|
|
1453
|
+
else:
|
|
1412
1454
|
|
|
1413
1455
|
def __init__(
|
|
1414
1456
|
self,
|
|
@@ -1469,78 +1511,15 @@ class RandomForestRegressor(ForestRegressor):
|
|
|
1469
1511
|
self.max_bins = max_bins
|
|
1470
1512
|
self.min_bin_size = min_bin_size
|
|
1471
1513
|
|
|
1472
|
-
else:
|
|
1473
|
-
|
|
1474
|
-
def __init__(
|
|
1475
|
-
self,
|
|
1476
|
-
n_estimators=100,
|
|
1477
|
-
*,
|
|
1478
|
-
criterion="mse",
|
|
1479
|
-
max_depth=None,
|
|
1480
|
-
min_samples_split=2,
|
|
1481
|
-
min_samples_leaf=1,
|
|
1482
|
-
min_weight_fraction_leaf=0.0,
|
|
1483
|
-
max_features="auto",
|
|
1484
|
-
max_leaf_nodes=None,
|
|
1485
|
-
min_impurity_decrease=0.0,
|
|
1486
|
-
min_impurity_split=None,
|
|
1487
|
-
bootstrap=True,
|
|
1488
|
-
oob_score=False,
|
|
1489
|
-
n_jobs=None,
|
|
1490
|
-
random_state=None,
|
|
1491
|
-
verbose=0,
|
|
1492
|
-
warm_start=False,
|
|
1493
|
-
ccp_alpha=0.0,
|
|
1494
|
-
max_samples=None,
|
|
1495
|
-
max_bins=256,
|
|
1496
|
-
min_bin_size=1,
|
|
1497
|
-
):
|
|
1498
|
-
super().__init__(
|
|
1499
|
-
DecisionTreeRegressor(),
|
|
1500
|
-
n_estimators=n_estimators,
|
|
1501
|
-
estimator_params=(
|
|
1502
|
-
"criterion",
|
|
1503
|
-
"max_depth",
|
|
1504
|
-
"min_samples_split",
|
|
1505
|
-
"min_samples_leaf",
|
|
1506
|
-
"min_weight_fraction_leaf",
|
|
1507
|
-
"max_features",
|
|
1508
|
-
"max_leaf_nodes",
|
|
1509
|
-
"min_impurity_decrease",
|
|
1510
|
-
"min_impurity_split" "random_state",
|
|
1511
|
-
"ccp_alpha",
|
|
1512
|
-
),
|
|
1513
|
-
bootstrap=bootstrap,
|
|
1514
|
-
oob_score=oob_score,
|
|
1515
|
-
n_jobs=n_jobs,
|
|
1516
|
-
random_state=random_state,
|
|
1517
|
-
verbose=verbose,
|
|
1518
|
-
warm_start=warm_start,
|
|
1519
|
-
max_samples=max_samples,
|
|
1520
|
-
)
|
|
1521
|
-
|
|
1522
|
-
self.criterion = criterion
|
|
1523
|
-
self.max_depth = max_depth
|
|
1524
|
-
self.min_samples_split = min_samples_split
|
|
1525
|
-
self.min_samples_leaf = min_samples_leaf
|
|
1526
|
-
self.min_weight_fraction_leaf = min_weight_fraction_leaf
|
|
1527
|
-
self.max_features = max_features
|
|
1528
|
-
self.max_leaf_nodes = max_leaf_nodes
|
|
1529
|
-
self.min_impurity_decrease = min_impurity_decrease
|
|
1530
|
-
self.min_impurity_split = min_impurity_split
|
|
1531
|
-
self.ccp_alpha = ccp_alpha
|
|
1532
|
-
self.max_bins = max_bins
|
|
1533
|
-
self.min_bin_size = min_bin_size
|
|
1534
|
-
|
|
1535
1514
|
|
|
1536
1515
|
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1537
1516
|
class ExtraTreesClassifier(ForestClassifier):
|
|
1538
|
-
__doc__ =
|
|
1517
|
+
__doc__ = _sklearn_ExtraTreesClassifier.__doc__
|
|
1539
1518
|
_onedal_factory = onedal_ExtraTreesClassifier
|
|
1540
1519
|
|
|
1541
1520
|
if sklearn_check_version("1.2"):
|
|
1542
1521
|
_parameter_constraints: dict = {
|
|
1543
|
-
**
|
|
1522
|
+
**_sklearn_ExtraTreesClassifier._parameter_constraints,
|
|
1544
1523
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1545
1524
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1546
1525
|
}
|
|
@@ -1611,69 +1590,6 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1611
1590
|
self.min_bin_size = min_bin_size
|
|
1612
1591
|
self.monotonic_cst = monotonic_cst
|
|
1613
1592
|
|
|
1614
|
-
elif sklearn_check_version("1.0"):
|
|
1615
|
-
|
|
1616
|
-
def __init__(
|
|
1617
|
-
self,
|
|
1618
|
-
n_estimators=100,
|
|
1619
|
-
*,
|
|
1620
|
-
criterion="gini",
|
|
1621
|
-
max_depth=None,
|
|
1622
|
-
min_samples_split=2,
|
|
1623
|
-
min_samples_leaf=1,
|
|
1624
|
-
min_weight_fraction_leaf=0.0,
|
|
1625
|
-
max_features="sqrt" if sklearn_check_version("1.1") else "auto",
|
|
1626
|
-
max_leaf_nodes=None,
|
|
1627
|
-
min_impurity_decrease=0.0,
|
|
1628
|
-
bootstrap=False,
|
|
1629
|
-
oob_score=False,
|
|
1630
|
-
n_jobs=None,
|
|
1631
|
-
random_state=None,
|
|
1632
|
-
verbose=0,
|
|
1633
|
-
warm_start=False,
|
|
1634
|
-
class_weight=None,
|
|
1635
|
-
ccp_alpha=0.0,
|
|
1636
|
-
max_samples=None,
|
|
1637
|
-
max_bins=256,
|
|
1638
|
-
min_bin_size=1,
|
|
1639
|
-
):
|
|
1640
|
-
super().__init__(
|
|
1641
|
-
ExtraTreeClassifier(),
|
|
1642
|
-
n_estimators,
|
|
1643
|
-
estimator_params=(
|
|
1644
|
-
"criterion",
|
|
1645
|
-
"max_depth",
|
|
1646
|
-
"min_samples_split",
|
|
1647
|
-
"min_samples_leaf",
|
|
1648
|
-
"min_weight_fraction_leaf",
|
|
1649
|
-
"max_features",
|
|
1650
|
-
"max_leaf_nodes",
|
|
1651
|
-
"min_impurity_decrease",
|
|
1652
|
-
"random_state",
|
|
1653
|
-
"ccp_alpha",
|
|
1654
|
-
),
|
|
1655
|
-
bootstrap=bootstrap,
|
|
1656
|
-
oob_score=oob_score,
|
|
1657
|
-
n_jobs=n_jobs,
|
|
1658
|
-
random_state=random_state,
|
|
1659
|
-
verbose=verbose,
|
|
1660
|
-
warm_start=warm_start,
|
|
1661
|
-
class_weight=class_weight,
|
|
1662
|
-
max_samples=max_samples,
|
|
1663
|
-
)
|
|
1664
|
-
|
|
1665
|
-
self.criterion = criterion
|
|
1666
|
-
self.max_depth = max_depth
|
|
1667
|
-
self.min_samples_split = min_samples_split
|
|
1668
|
-
self.min_samples_leaf = min_samples_leaf
|
|
1669
|
-
self.min_weight_fraction_leaf = min_weight_fraction_leaf
|
|
1670
|
-
self.max_features = max_features
|
|
1671
|
-
self.max_leaf_nodes = max_leaf_nodes
|
|
1672
|
-
self.min_impurity_decrease = min_impurity_decrease
|
|
1673
|
-
self.ccp_alpha = ccp_alpha
|
|
1674
|
-
self.max_bins = max_bins
|
|
1675
|
-
self.min_bin_size = min_bin_size
|
|
1676
|
-
|
|
1677
1593
|
else:
|
|
1678
1594
|
|
|
1679
1595
|
def __init__(
|
|
@@ -1685,10 +1601,9 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1685
1601
|
min_samples_split=2,
|
|
1686
1602
|
min_samples_leaf=1,
|
|
1687
1603
|
min_weight_fraction_leaf=0.0,
|
|
1688
|
-
max_features="auto",
|
|
1604
|
+
max_features="sqrt" if sklearn_check_version("1.1") else "auto",
|
|
1689
1605
|
max_leaf_nodes=None,
|
|
1690
1606
|
min_impurity_decrease=0.0,
|
|
1691
|
-
min_impurity_split=None,
|
|
1692
1607
|
bootstrap=False,
|
|
1693
1608
|
oob_score=False,
|
|
1694
1609
|
n_jobs=None,
|
|
@@ -1713,7 +1628,6 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1713
1628
|
"max_features",
|
|
1714
1629
|
"max_leaf_nodes",
|
|
1715
1630
|
"min_impurity_decrease",
|
|
1716
|
-
"min_impurity_split",
|
|
1717
1631
|
"random_state",
|
|
1718
1632
|
"ccp_alpha",
|
|
1719
1633
|
),
|
|
@@ -1735,22 +1649,19 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1735
1649
|
self.max_features = max_features
|
|
1736
1650
|
self.max_leaf_nodes = max_leaf_nodes
|
|
1737
1651
|
self.min_impurity_decrease = min_impurity_decrease
|
|
1738
|
-
self.min_impurity_split = min_impurity_split
|
|
1739
1652
|
self.ccp_alpha = ccp_alpha
|
|
1740
1653
|
self.max_bins = max_bins
|
|
1741
1654
|
self.min_bin_size = min_bin_size
|
|
1742
|
-
self.max_bins = max_bins
|
|
1743
|
-
self.min_bin_size = min_bin_size
|
|
1744
1655
|
|
|
1745
1656
|
|
|
1746
|
-
@control_n_jobs(decorated_methods=["fit", "predict"])
|
|
1657
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
|
|
1747
1658
|
class ExtraTreesRegressor(ForestRegressor):
|
|
1748
|
-
__doc__ =
|
|
1659
|
+
__doc__ = _sklearn_ExtraTreesRegressor.__doc__
|
|
1749
1660
|
_onedal_factory = onedal_ExtraTreesRegressor
|
|
1750
1661
|
|
|
1751
1662
|
if sklearn_check_version("1.2"):
|
|
1752
1663
|
_parameter_constraints: dict = {
|
|
1753
|
-
**
|
|
1664
|
+
**_sklearn_ExtraTreesRegressor._parameter_constraints,
|
|
1754
1665
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1755
1666
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1756
1667
|
}
|
|
@@ -1819,7 +1730,7 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1819
1730
|
self.min_bin_size = min_bin_size
|
|
1820
1731
|
self.monotonic_cst = monotonic_cst
|
|
1821
1732
|
|
|
1822
|
-
|
|
1733
|
+
else:
|
|
1823
1734
|
|
|
1824
1735
|
def __init__(
|
|
1825
1736
|
self,
|
|
@@ -1880,72 +1791,9 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1880
1791
|
self.max_bins = max_bins
|
|
1881
1792
|
self.min_bin_size = min_bin_size
|
|
1882
1793
|
|
|
1883
|
-
else:
|
|
1884
|
-
|
|
1885
|
-
def __init__(
|
|
1886
|
-
self,
|
|
1887
|
-
n_estimators=100,
|
|
1888
|
-
*,
|
|
1889
|
-
criterion="mse",
|
|
1890
|
-
max_depth=None,
|
|
1891
|
-
min_samples_split=2,
|
|
1892
|
-
min_samples_leaf=1,
|
|
1893
|
-
min_weight_fraction_leaf=0.0,
|
|
1894
|
-
max_features="auto",
|
|
1895
|
-
max_leaf_nodes=None,
|
|
1896
|
-
min_impurity_decrease=0.0,
|
|
1897
|
-
min_impurity_split=None,
|
|
1898
|
-
bootstrap=False,
|
|
1899
|
-
oob_score=False,
|
|
1900
|
-
n_jobs=None,
|
|
1901
|
-
random_state=None,
|
|
1902
|
-
verbose=0,
|
|
1903
|
-
warm_start=False,
|
|
1904
|
-
ccp_alpha=0.0,
|
|
1905
|
-
max_samples=None,
|
|
1906
|
-
max_bins=256,
|
|
1907
|
-
min_bin_size=1,
|
|
1908
|
-
):
|
|
1909
|
-
super().__init__(
|
|
1910
|
-
ExtraTreeRegressor(),
|
|
1911
|
-
n_estimators=n_estimators,
|
|
1912
|
-
estimator_params=(
|
|
1913
|
-
"criterion",
|
|
1914
|
-
"max_depth",
|
|
1915
|
-
"min_samples_split",
|
|
1916
|
-
"min_samples_leaf",
|
|
1917
|
-
"min_weight_fraction_leaf",
|
|
1918
|
-
"max_features",
|
|
1919
|
-
"max_leaf_nodes",
|
|
1920
|
-
"min_impurity_decrease",
|
|
1921
|
-
"min_impurity_split" "random_state",
|
|
1922
|
-
"ccp_alpha",
|
|
1923
|
-
),
|
|
1924
|
-
bootstrap=bootstrap,
|
|
1925
|
-
oob_score=oob_score,
|
|
1926
|
-
n_jobs=n_jobs,
|
|
1927
|
-
random_state=random_state,
|
|
1928
|
-
verbose=verbose,
|
|
1929
|
-
warm_start=warm_start,
|
|
1930
|
-
max_samples=max_samples,
|
|
1931
|
-
)
|
|
1932
|
-
|
|
1933
|
-
self.criterion = criterion
|
|
1934
|
-
self.max_depth = max_depth
|
|
1935
|
-
self.min_samples_split = min_samples_split
|
|
1936
|
-
self.min_samples_leaf = min_samples_leaf
|
|
1937
|
-
self.min_weight_fraction_leaf = min_weight_fraction_leaf
|
|
1938
|
-
self.max_features = max_features
|
|
1939
|
-
self.max_leaf_nodes = max_leaf_nodes
|
|
1940
|
-
self.min_impurity_decrease = min_impurity_decrease
|
|
1941
|
-
self.min_impurity_split = min_impurity_split
|
|
1942
|
-
self.ccp_alpha = ccp_alpha
|
|
1943
|
-
self.max_bins = max_bins
|
|
1944
|
-
self.min_bin_size = min_bin_size
|
|
1945
|
-
|
|
1946
1794
|
|
|
1947
1795
|
# Allow for isinstance calls without inheritance changes using ABCMeta
|
|
1948
|
-
|
|
1949
|
-
|
|
1950
|
-
|
|
1951
|
-
|
|
1796
|
+
_sklearn_RandomForestClassifier.register(RandomForestClassifier)
|
|
1797
|
+
_sklearn_RandomForestRegressor.register(RandomForestRegressor)
|
|
1798
|
+
_sklearn_ExtraTreesClassifier.register(ExtraTreesClassifier)
|
|
1799
|
+
_sklearn_ExtraTreesRegressor.register(ExtraTreesRegressor)
|