scikit-learn-intelex 2025.10.0__py313-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__init__.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +338 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +38 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dispatcher.py +572 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +629 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1799 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/__main__.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +101 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +39 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/split.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +34 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/common.py +313 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +170 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/__init__.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +112 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +25 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +26 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +278 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svc.py +306 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svr.py +155 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +418 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
- scikit_learn_intelex-2025.10.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
- scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
- scikit_learn_intelex-2025.10.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1799 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numbers
|
|
18
|
+
import warnings
|
|
19
|
+
from abc import ABC
|
|
20
|
+
from collections.abc import Iterable
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
from scipy import sparse as sp
|
|
24
|
+
from sklearn.base import BaseEstimator, clone
|
|
25
|
+
from sklearn.ensemble import ExtraTreesClassifier as _sklearn_ExtraTreesClassifier
|
|
26
|
+
from sklearn.ensemble import ExtraTreesRegressor as _sklearn_ExtraTreesRegressor
|
|
27
|
+
from sklearn.ensemble import RandomForestClassifier as _sklearn_RandomForestClassifier
|
|
28
|
+
from sklearn.ensemble import RandomForestRegressor as _sklearn_RandomForestRegressor
|
|
29
|
+
from sklearn.ensemble._forest import ForestClassifier as _sklearn_ForestClassifier
|
|
30
|
+
from sklearn.ensemble._forest import ForestRegressor as _sklearn_ForestRegressor
|
|
31
|
+
from sklearn.ensemble._forest import _get_n_samples_bootstrap
|
|
32
|
+
from sklearn.exceptions import DataConversionWarning
|
|
33
|
+
from sklearn.metrics import accuracy_score, r2_score
|
|
34
|
+
from sklearn.tree import (
|
|
35
|
+
DecisionTreeClassifier,
|
|
36
|
+
DecisionTreeRegressor,
|
|
37
|
+
ExtraTreeClassifier,
|
|
38
|
+
ExtraTreeRegressor,
|
|
39
|
+
)
|
|
40
|
+
from sklearn.tree._tree import Tree
|
|
41
|
+
from sklearn.utils import check_random_state, deprecated
|
|
42
|
+
from sklearn.utils.validation import (
|
|
43
|
+
_check_sample_weight,
|
|
44
|
+
check_array,
|
|
45
|
+
check_is_fitted,
|
|
46
|
+
check_X_y,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
50
|
+
from daal4py.sklearn._utils import (
|
|
51
|
+
check_tree_nodes,
|
|
52
|
+
daal_check_version,
|
|
53
|
+
sklearn_check_version,
|
|
54
|
+
)
|
|
55
|
+
from onedal._device_offload import support_input_format
|
|
56
|
+
from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
|
|
57
|
+
from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
|
|
58
|
+
from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
|
|
59
|
+
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
|
|
60
|
+
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
61
|
+
from onedal.utils.validation import _num_features, _num_samples
|
|
62
|
+
from sklearnex._utils import register_hyperparameters
|
|
63
|
+
|
|
64
|
+
from .._config import get_config
|
|
65
|
+
from .._device_offload import dispatch, wrap_output_data
|
|
66
|
+
from .._utils import PatchingConditionsChain
|
|
67
|
+
from ..base import oneDALEstimator
|
|
68
|
+
from ..utils._array_api import get_namespace
|
|
69
|
+
from ..utils.validation import check_n_features, validate_data
|
|
70
|
+
|
|
71
|
+
if sklearn_check_version("1.2"):
|
|
72
|
+
from sklearn.utils._param_validation import Interval
|
|
73
|
+
if sklearn_check_version("1.4"):
|
|
74
|
+
from daal4py.sklearn.utils import _assert_all_finite
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class BaseForest(oneDALEstimator, ABC):
    """Shared oneDAL-backed logic for sklearnex random-forest / extra-trees estimators.

    Concrete subclasses provide:

    * ``_onedal_factory`` -- the oneDAL estimator class instantiated in
      ``_onedal_fit``;
    * ``_err`` -- the oneDAL out-of-bag error-metric mode used when
      ``oob_score`` is enabled;
    * ``_get_tree_state`` -- the primitive used to export one tree of the
      trained oneDAL model so it can be wrapped in a scikit-learn ``Tree``.
    """

    # oneDAL estimator class; must be set by subclasses before fitting.
    _onedal_factory = None

    def _onedal_fit(self, X, y, sample_weight=None, queue=None):
        """Fit the forest with oneDAL and expose scikit-learn fitted attributes.

        Parameters
        ----------
        X : array-like
            Training data. Validated to a 2-D float32/float64 array unless the
            ``use_raw_input`` config option is enabled.
        y : array-like
            Targets; 1-D targets are reshaped to a single column.
        sample_weight : array-like or None, default=None
            Per-sample weights, combined with class weights when those apply.
        queue : object or None, default=None
            Device queue forwarded to the oneDAL estimator's ``fit``.

        Returns
        -------
        self
        """
        use_raw_input = get_config().get("use_raw_input", False) is True
        xp, _ = get_namespace(X)
        if not use_raw_input:
            X, y = validate_data(
                self,
                X,
                y,
                multi_output=True,
                accept_sparse=False,
                dtype=[np.float64, np.float32],
                ensure_all_finite=False,
                ensure_2d=True,
            )

            if sample_weight is not None:
                sample_weight = _check_sample_weight(sample_weight, X)

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

        # Work on a 2-D ``y`` in both the validated and the raw-input paths so
        # the ``y.shape`` unpacking below is always valid. ``reshape`` keeps
        # the data contiguous, unlike indexing with ``[:, np.newaxis]``.
        if y.ndim == 1:
            y = xp.reshape(y, (-1, 1))

        self._n_samples, self.n_outputs_ = y.shape

        if not use_raw_input:
            y, expanded_class_weight = self._validate_y_class_weight(y)

            if expanded_class_weight is not None:
                if sample_weight is not None:
                    sample_weight = sample_weight * expanded_class_weight
                else:
                    sample_weight = expanded_class_weight
            if sample_weight is not None:
                sample_weight = [sample_weight]
        else:
            # Class labels only make sense for classifiers: setting them on a
            # regressor would make ``_estimators_`` build trees with one
            # "class" per distinct target value.
            if isinstance(self, ForestClassifier):
                # try/except needed for raw inputs + array_api data where,
                # unlike numpy, unique values are obtained via `unique_values`.
                # This should be removed when refactored for gpu zero-copy.
                try:
                    self.classes_ = xp.unique(y)
                except AttributeError:
                    self.classes_ = xp.unique_values(y)
                self.n_classes_ = len(self.classes_)
            self.n_features_in_ = X.shape[1]

        onedal_params = {
            "n_estimators": self.n_estimators,
            "criterion": self.criterion,
            "max_depth": self.max_depth,
            "min_samples_split": self.min_samples_split,
            "min_samples_leaf": self.min_samples_leaf,
            "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
            "max_features": self._to_absolute_max_features(
                self.max_features, self.n_features_in_
            ),
            "max_leaf_nodes": self.max_leaf_nodes,
            "min_impurity_decrease": self.min_impurity_decrease,
            "bootstrap": self.bootstrap,
            "oob_score": self.oob_score,
            "n_jobs": self.n_jobs,
            "random_state": self.random_state,
            "verbose": self.verbose,
            "warm_start": self.warm_start,
            "error_metric_mode": self._err if self.oob_score else "none",
            "variable_importance_mode": "mdi",
            "class_weight": self.class_weight,
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "max_samples": self.max_samples,
        }

        onedal_params["min_impurity_split"] = None

        # Lazy evaluation of estimators_: built on first access of the property.
        self._cached_estimators_ = None

        # Compute
        self._onedal_estimator = self._onedal_factory(**onedal_params)
        self._onedal_estimator.fit(X, xp.reshape(y, (-1,)), sample_weight, queue=queue)

        self._save_attributes()

        # Decapsulate classes_ attributes: for a single output, unwrap the
        # per-output containers into a scalar count / a single label array.
        if hasattr(self, "classes_") and self.n_outputs_ == 1:
            self.n_classes_ = (
                self.n_classes_[0]
                if isinstance(self.n_classes_, Iterable)
                else self.n_classes_
            )
            self.classes_ = (
                self.classes_[0]
                if isinstance(self.classes_[0], Iterable)
                else self.classes_
            )

        return self

    def _save_attributes(self):
        """Copy OOB results and bootstrap size from the fitted oneDAL estimator.

        Also calls ``_validate_estimator`` and returns ``self``.
        """
        if self.oob_score:
            self.oob_score_ = self._onedal_estimator.oob_score_
            if hasattr(self._onedal_estimator, "oob_prediction_"):
                self.oob_prediction_ = self._onedal_estimator.oob_prediction_
            if hasattr(self._onedal_estimator, "oob_decision_function_"):
                self.oob_decision_function_ = (
                    self._onedal_estimator.oob_decision_function_
                )
        if self.bootstrap:
            # At least one observation per tree.
            self._n_samples_bootstrap = max(
                round(
                    self._onedal_estimator.observations_per_tree_fraction
                    * self._n_samples
                ),
                1,
            )
        else:
            self._n_samples_bootstrap = None
        self._validate_estimator()
        return self

    def _to_absolute_max_features(self, max_features, n_features):
        """Translate a ``max_features`` specification into a feature count.

        Parameters
        ----------
        max_features : None, str, int or float
            ``None`` -> all features; ``"sqrt"``/``"log2"`` (and ``"auto"`` on
            scikit-learn < 1.3) -> derived counts; integers are returned as is;
            positive floats are a fraction of ``n_features``.
        n_features : int
            Number of features seen during fit.

        Returns
        -------
        int
            The absolute number of features (0 for a non-positive float).

        Raises
        ------
        ValueError
            For an unsupported string value.
        """
        if max_features is None:
            return n_features
        if isinstance(max_features, str):
            if max_features == "auto":
                if not sklearn_check_version("1.3"):
                    if sklearn_check_version("1.1"):
                        warnings.warn(
                            "`max_features='auto'` has been deprecated in 1.1 "
                            "and will be removed in 1.3. To keep the past behaviour, "
                            "explicitly set `max_features=1.0` or remove this "
                            "parameter as it is also the default value for "
                            "RandomForestRegressors and ExtraTreesRegressors.",
                            FutureWarning,
                        )
                    return (
                        max(1, int(np.sqrt(n_features)))
                        if isinstance(self, ForestClassifier)
                        else n_features
                    )
            if max_features == "sqrt":
                return max(1, int(np.sqrt(n_features)))
            if max_features == "log2":
                return max(1, int(np.log2(n_features)))
            allowed_string_values = (
                '"sqrt" or "log2"'
                if sklearn_check_version("1.3")
                else '"auto", "sqrt" or "log2"'
            )
            raise ValueError(
                "Invalid value for max_features. Allowed string "
                f"values are {allowed_string_values}."
            )
        if isinstance(max_features, (numbers.Integral, np.integer)):
            return max_features
        if max_features > 0.0:
            return max(1, int(max_features * n_features))
        return 0

    def _check_parameters(self):
        """Validate hyperparameters manually (used on scikit-learn < 1.2).

        Raises
        ------
        ValueError
            When a hyperparameter is outside its allowed range or has the
            wrong type.
        """
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the integer %s" % self.min_samples_split
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the float %s" % self.min_samples_split
                )
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
        # Legacy parameter: only warn/validate when it was actually set, since
        # comparing its ``None`` default with ``< 0.0`` would raise TypeError.
        min_impurity_split = getattr(self, "min_impurity_split", None)
        if min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            if min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than " "or equal to 0"
                )
        if self.min_impurity_decrease < 0.0:
            raise ValueError(
                "min_impurity_decrease must be greater than " "or equal to 0"
            )
        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" % self.max_leaf_nodes
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
                        self.max_leaf_nodes
                    )
                )
        if isinstance(self.max_bins, numbers.Integral):
            if not 2 <= self.max_bins:
                raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
        else:
            raise ValueError(
                "max_bins must be integral number but was " "%r" % self.max_bins
            )
        if isinstance(self.min_bin_size, numbers.Integral):
            if not 1 <= self.min_bin_size:
                raise ValueError(
                    "min_bin_size must be at least 1, got %s" % self.min_bin_size
                )
        else:
            raise ValueError(
                "min_bin_size must be integral number but was " "%r" % self.min_bin_size
            )

    @property
    def estimators_(self):
        """List of fitted scikit-learn tree estimators, built lazily on first access.

        Raises
        ------
        AttributeError
            If neither fitting nor the setter has initialized the cache.
        """
        if hasattr(self, "_cached_estimators_"):
            if self._cached_estimators_ is None:
                self._estimators_()
            return self._cached_estimators_
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'estimators_'"
            )

    @estimators_.setter
    def estimators_(self, estimators):
        # Needed to allow for proper sklearn operation in fallback mode
        self._cached_estimators_ = estimators

    def _estimators_(self):
        """Convert the oneDAL model into scikit-learn tree estimators.

        Fills ``self._cached_estimators_`` with one clone of ``self.estimator``
        per oneDAL tree, each carrying a ``tree_`` rebuilt from the oneDAL
        tree state.
        """
        # _estimators_ should only be called if _onedal_estimator exists
        check_is_fitted(self, "_onedal_estimator")
        if hasattr(self, "n_classes_"):
            # Scalar for single-output models (possibly a numpy integer),
            # otherwise a per-output sequence.
            n_classes_ = (
                self.n_classes_
                if isinstance(self.n_classes_, (numbers.Integral, np.integer))
                else self.n_classes_[0]
            )
        else:
            n_classes_ = 1

        # convert model to estimators
        params = {
            "criterion": self._onedal_estimator.criterion,
            "max_depth": self._onedal_estimator.max_depth,
            "min_samples_split": self._onedal_estimator.min_samples_split,
            "min_samples_leaf": self._onedal_estimator.min_samples_leaf,
            "min_weight_fraction_leaf": self._onedal_estimator.min_weight_fraction_leaf,
            "max_features": self._onedal_estimator.max_features,
            "max_leaf_nodes": self._onedal_estimator.max_leaf_nodes,
            "min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
            "random_state": None,
        }
        est = self.estimator.__class__(**params)
        # we need to set est.tree_ field with Trees constructed from
        # oneAPI Data Analytics Library solution
        estimators_ = []

        random_state_checked = check_random_state(self.random_state)

        for i in range(self._onedal_estimator.n_estimators):
            est_i = clone(est)
            est_i.set_params(
                random_state=random_state_checked.randint(np.iinfo(np.int32).max)
            )
            est_i.n_features_in_ = self.n_features_in_
            est_i.n_outputs_ = self.n_outputs_
            est_i.n_classes_ = n_classes_
            tree_i_state_class = self._get_tree_state(
                self._onedal_estimator._onedal_model, i, n_classes_
            )
            tree_i_state_dict = {
                "max_depth": tree_i_state_class.max_depth,
                "node_count": tree_i_state_class.node_count,
                "nodes": check_tree_nodes(tree_i_state_class.node_ar),
                "values": tree_i_state_class.value_ar,
            }
            # Note: only on host.
            est_i.tree_ = Tree(
                self.n_features_in_,
                np.array([n_classes_], dtype=np.intp),
                self.n_outputs_,
            )
            est_i.tree_.__setstate__(tree_i_state_dict)
            estimators_.append(est_i)

        self._cached_estimators_ = estimators_

    if not sklearn_check_version("1.2"):

        @property
        def base_estimator(self):
            # Alias of ``estimator`` for scikit-learn < 1.2.
            return self.estimator

        @base_estimator.setter
        def base_estimator(self, estimator):
            self.estimator = estimator
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
class ForestClassifier(BaseForest, _sklearn_ForestClassifier):
    """oneDAL-accelerated forest classifier base, dispatching to scikit-learn on fallback."""

    # Surprisingly, even though scikit-learn warns against using
    # their ForestClassifier directly, it actually has a more stable
    # API than the user-facing objects (over time). If they change it
    # significantly at some point then this may need to be versioned.

    _err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
    _get_tree_state = staticmethod(get_tree_state_cls)

    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        class_weight=None,
        max_samples=None,
    ):
        super().__init__(
            estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            class_weight=class_weight,
            max_samples=max_samples,
        )

        # The estimator is checked against the class attribute for conformance.
        # This should only trigger if the user uses this class directly.
        if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
            self._onedal_factory, onedal_RandomForestClassifier
        ):
            self._onedal_factory = onedal_RandomForestClassifier
        elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
            self._onedal_factory, onedal_ExtraTreesClassifier
        ):
            self._onedal_factory = onedal_ExtraTreesClassifier

        if self._onedal_factory is None:
            raise TypeError(" oneDAL estimator has not been set.")

    decision_path = support_input_format(_sklearn_ForestClassifier.decision_path)
    apply = support_input_format(_sklearn_ForestClassifier.apply)

    def _estimators_(self):
        """Build the scikit-learn trees and attach the fitted ``classes_`` to each."""
        super()._estimators_()
        for est in self._cached_estimators_:
            est.classes_ = self.classes_

    def fit(self, X, y, sample_weight=None):
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": _sklearn_ForestClassifier.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
        """Validate parameters and record whether oneDAL can handle this fit.

        Returns
        -------
        tuple
            ``(patching_status, X, y, sample_weight)``; ``X``/``y`` are the
            checked arrays when all conditions held up to the input check.

        Raises
        ------
        ValueError
            For sparse ``y``, ``oob_score`` without ``bootstrap``,
            ``max_samples`` without ``bootstrap``, or invalid parameters.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        patching_status.and_conditions(
            [
                (
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion == "gini",
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'gini' criterion is supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    self.n_estimators <= 6024,
                    "More than 6024 estimators is not supported.",
                ),
            ]
        )

        if self.bootstrap:
            patching_status.and_conditions(
                [
                    (
                        self.class_weight != "balanced_subsample",
                        "'balanced_subsample' for class_weight is not supported",
                    )
                ]
            )

        if patching_status.get_status() and sklearn_check_version("1.4"):
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            patching_status.and_conditions(
                [
                    (input_is_finite, "Non-finite input is not supported."),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if patching_status.get_status():
            # `force_all_finite` was renamed to `ensure_all_finite` in sklearn 1.6.
            if sklearn_check_version("1.6"):
                X, y = check_X_y(
                    X,
                    y,
                    multi_output=True,
                    accept_sparse=True,
                    dtype=[np.float64, np.float32],
                    ensure_all_finite=False,
                )
            else:
                X, y = check_X_y(
                    X,
                    y,
                    multi_output=True,
                    accept_sparse=True,
                    dtype=[np.float64, np.float32],
                    force_all_finite=False,
                )

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            if y.ndim == 1:
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]

            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                    (
                        y.dtype in [np.float32, np.float64, np.int32, np.int64],
                        f"Datatype ({y.dtype}) for y is not supported.",
                    ),
                ]
            )
            # TODO: Fix to support integers as input

            if self.n_outputs_ == 1:
                xp, is_array_api_compliant = get_namespace(y)
                sety = xp.unique_values(y) if is_array_api_compliant else np.unique(y)
                num_classes = sety.shape[0]
                patching_status.and_conditions(
                    [
                        (
                            num_classes >= 2,
                            "Number of classes must be at least 2.",
                        ),
                    ]
                )

            # Report the bootstrap/max_samples conflict before validating the
            # max_samples value itself, matching scikit-learn's ordering.
            if not self.bootstrap and self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )

            if self.bootstrap:
                # Called only for its validation of max_samples against
                # n_samples; the returned size is not needed here.
                _get_n_samples_bootstrap(
                    n_samples=X.shape[0], max_samples=self.max_samples
                )

        if (
            patching_status.get_status()
            and (self.random_state is not None)
            and (not daal_check_version((2024, "P", 0)))
        ):
            warnings.warn(
                "Setting 'random_state' value is not supported. "
                "State set by oneDAL to default value (777).",
                RuntimeWarning,
            )

        return patching_status, X, y, sample_weight

    @wrap_output_data
    def predict(self, X):
        check_is_fitted(self)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": _sklearn_ForestClassifier.predict,
            },
            X,
        )

    @wrap_output_data
    def predict_proba(self, X):
        # TODO:
        # _check_proba()
        # self._check_proba()
        check_is_fitted(self)
        return dispatch(
            self,
            "predict_proba",
            {
                "onedal": self.__class__._onedal_predict_proba,
                "sklearn": _sklearn_ForestClassifier.predict_proba,
            },
            X,
        )

    def predict_log_proba(self, X):
        xp, _ = get_namespace(X)
        proba = self.predict_proba(X)

        if self.n_outputs_ == 1:
            return xp.log(proba)

        else:
            # Multi-output: one probability array per output.
            for k in range(self.n_outputs_):
                proba[k] = xp.log(proba[k])

            return proba

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        check_is_fitted(self)
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": _sklearn_ForestClassifier.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    fit.__doc__ = _sklearn_ForestClassifier.fit.__doc__
    predict.__doc__ = _sklearn_ForestClassifier.predict.__doc__
    predict_proba.__doc__ = _sklearn_ForestClassifier.predict_proba.__doc__
    predict_log_proba.__doc__ = _sklearn_ForestClassifier.predict_log_proba.__doc__
    score.__doc__ = _sklearn_ForestClassifier.score.__doc__

    def _onedal_cpu_supported(self, method_name, *data):
        """Return the patching conditions for running ``method_name`` on CPU via oneDAL.

        Raises
        ------
        RuntimeError
            For a method name this estimator does not dispatch.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, X, y, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeClassifier,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                    (
                        not sp.issparse(sample_weight),
                        "sample_weight is sparse. " "Sparse input is not supported.",
                    ),
                ]
            )

        elif method_name in ["predict", "predict_proba", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeClassifier,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                ]
            )

            if method_name == "predict_proba":
                patching_status.and_conditions(
                    [
                        (
                            daal_check_version((2021, "P", 400)),
                            "oneDAL version is lower than 2021.4.",
                        )
                    ]
                )

            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_gpu_supported(self, method_name, *data):
        """Return the patching conditions for running ``method_name`` on GPU via oneDAL.

        Raises
        ------
        RuntimeError
            For a method name this estimator does not dispatch.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, X, y, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeClassifier,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                    (
                        not self.oob_score,
                        "oob_scores using r2 or accuracy not implemented.",
                    ),
                    (sample_weight is None, "sample_weight is not supported."),
                ]
            )

        elif method_name in ["predict", "predict_proba", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained"),
                    (
                        not sp.issparse(X),
                        "X is sparse. Sparse input is not supported.",
                    ),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 100)),
                        "ExtraTrees supported starting from oneDAL version 2023.1",
                    ),
                ]
            )
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_predict(self, X, queue=None):
        """Predict class labels with the oneDAL model.

        oneDAL returns class indices; they are mapped back to ``classes_``.
        """
        xp, _ = get_namespace(X)
        # Same config lookup as the other oneDAL entry points of this module.
        use_raw_input = get_config().get("use_raw_input", False) is True
        if not use_raw_input:
            X = validate_data(
                self,
                X,
                dtype=[np.float64, np.float32],
                ensure_all_finite=False,
                reset=False,
                ensure_2d=True,
            )
            if hasattr(self, "n_features_in_"):
                try:
                    num_features = _num_features(X)
                except TypeError:
                    num_features = _num_samples(X)
                if num_features != self.n_features_in_:
                    raise ValueError(
                        (
                            f"X has {num_features} features, "
                            f"but {self.__class__.__name__} is expecting "
                            f"{self.n_features_in_} features as input"
                        )
                    )
            check_n_features(self, X, reset=False)

        res = self._onedal_estimator.predict(X, queue=queue)
        try:
            # Device (array API) path: keep the labels on the result's queue.
            return xp.take(
                xp.asarray(self.classes_, device=res.sycl_queue),
                xp.astype(xp.reshape(res, (-1,)), xp.int64),
            )
        except AttributeError:
            # Host/numpy path.
            return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))

    def _onedal_predict_proba(self, X, queue=None):
        """Return class probabilities from the oneDAL model."""
        use_raw_input = get_config().get("use_raw_input", False) is True
        if not use_raw_input:
            X = validate_data(
                self,
                X,
                dtype=[np.float64, np.float32],
                ensure_all_finite=False,
                reset=False,
                ensure_2d=True,
            )

        return self._onedal_estimator.predict_proba(X, queue=queue)

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Mean accuracy of oneDAL predictions on ``X`` against ``y``."""
        return accuracy_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )
|
|
874
|
+
|
|
875
|
+
|
|
876
|
+
class ForestRegressor(BaseForest, _sklearn_ForestRegressor):
    """Base class for the oneDAL-accelerated forest regressors.

    ``fit``, ``predict`` and ``score`` go through ``dispatch``, which is given
    both a oneDAL implementation (``_onedal_*`` methods) and the stock
    scikit-learn ``ForestRegressor`` implementation. ``_onedal_cpu_supported``
    and ``_onedal_gpu_supported`` build the patching-condition chains that
    (presumably) let ``dispatch`` decide which implementation runs.
    """

    _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
    _get_tree_state = staticmethod(get_tree_state_reg)

    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
    ):
        super().__init__(
            estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            max_samples=max_samples,
        )

        # The splitter is checked against the class attribute for conformance.
        # This should only trigger if the user uses this class directly, in
        # which case ``_onedal_factory`` may still be ``None`` (see the check
        # below). ``issubclass(None, ...)`` raises TypeError, so only call it
        # when the current factory actually is a class.
        factory = self._onedal_factory
        factory_is_class = isinstance(factory, type)
        if self.estimator.__class__ == DecisionTreeRegressor and not (
            factory_is_class and issubclass(factory, onedal_RandomForestRegressor)
        ):
            self._onedal_factory = onedal_RandomForestRegressor
        elif self.estimator.__class__ == ExtraTreeRegressor and not (
            factory_is_class and issubclass(factory, onedal_ExtraTreesRegressor)
        ):
            self._onedal_factory = onedal_ExtraTreesRegressor

        if self._onedal_factory is None:
            raise TypeError("oneDAL estimator has not been set.")

    decision_path = support_input_format(_sklearn_ForestRegressor.decision_path)
    apply = support_input_format(_sklearn_ForestRegressor.apply)

    def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
        """Validate ``fit`` inputs and extend ``patching_status`` for oneDAL.

        Parameters
        ----------
        patching_status : PatchingConditionsChain
            Chain to which the oneDAL support conditions are appended.
        X, y, sample_weight
            Arguments as passed to ``fit``.

        Returns
        -------
        tuple
            ``(patching_status, X, y, sample_weight)``. While all conditions
            still hold, ``X``/``y`` are the ``check_X_y``-validated arrays and
            ``y`` is reshaped to 2-D (``self.n_outputs_`` is set from it);
            otherwise ``X`` and ``y`` are returned unchanged.

        Raises
        ------
        ValueError
            If ``y`` is sparse, if ``oob_score`` is requested without
            ``bootstrap``, or if ``max_samples`` is set without ``bootstrap``.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        if not sklearn_check_version("1.2") and self.criterion == "mse":
            warnings.warn(
                "Criterion 'mse' was deprecated in v1.0 and will be "
                "removed in version 1.2. Use `criterion='squared_error'` "
                "which is equivalent.",
                FutureWarning,
            )

        patching_status.and_conditions(
            [
                (
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion in ["mse", "squared_error"],
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'mse' and 'squared_error' criteria are supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    self.n_estimators <= 6024,
                    "More than 6024 estimators is not supported.",
                ),
            ]
        )

        if patching_status.get_status() and sklearn_check_version("1.4"):
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            patching_status.and_conditions(
                [
                    (input_is_finite, "Non-finite input is not supported."),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if patching_status.get_status():
            # ``force_all_finite`` was renamed to ``ensure_all_finite`` in
            # scikit-learn 1.6.
            if sklearn_check_version("1.6"):
                X, y = check_X_y(
                    X,
                    y,
                    multi_output=True,
                    accept_sparse=True,
                    dtype=[np.float64, np.float32],
                    ensure_all_finite=False,
                )
            else:
                X, y = check_X_y(
                    X,
                    y,
                    multi_output=True,
                    accept_sparse=True,
                    dtype=[np.float64, np.float32],
                    force_all_finite=False,
                )

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            if y.ndim == 1:
                # reshape is necessary to preserve the data contiguity against vs
                # [:, np.newaxis] that does not.
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]

            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

            # Sklearn function used for doing checks on max_samples attribute
            _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

            if not self.bootstrap and self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )

            if (
                patching_status.get_status()
                and (self.random_state is not None)
                and (not daal_check_version((2024, "P", 0)))
            ):
                warnings.warn(
                    "Setting 'random_state' value is not supported. "
                    "State set by oneDAL to default value (777).",
                    RuntimeWarning,
                )

        return patching_status, X, y, sample_weight

    def _onedal_cpu_supported(self, method_name, *data):
        """Build the patching-condition chain for running ``method_name`` on CPU.

        ``data`` is presumably the positional arguments forwarded by
        ``dispatch``: ``(X, y, sample_weight)`` for ``"fit"`` and ``X`` first
        for ``"predict"``/``"score"``.

        Raises
        ------
        RuntimeError
            For any other ``method_name``.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, _, _, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        # Plain random forests (DecisionTreeRegressor splitter)
                        # are supported on every version; only ExtraTrees needs
                        # oneDAL >= 2023.2.
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                    (
                        not sp.issparse(sample_weight),
                        "sample_weight is sparse. Sparse input is not supported.",
                    ),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                ]
            )
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_gpu_supported(self, method_name, *data):
        """Build the patching-condition chain for running ``method_name`` on GPU.

        Same argument convention as ``_onedal_cpu_supported``. In addition,
        GPU ``fit`` rejects ``oob_score`` and any ``sample_weight``.

        Raises
        ------
        RuntimeError
            For any ``method_name`` other than fit/predict/score.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, _, _, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                    (not self.oob_score, "oob_score value is not sklearn conformant."),
                    (sample_weight is None, "sample_weight is not supported."),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                ]
            )
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_predict(self, X, queue=None):
        """Predict with the fitted oneDAL estimator.

        Unless the ``use_raw_input`` config flag is exactly ``True``, ``X`` is
        first validated against the fitted feature layout (``reset=False``)
        with dtype restricted to float64/float32.
        """
        check_is_fitted(self, "_onedal_estimator")
        use_raw_input = get_config().get("use_raw_input", False) is True

        if not use_raw_input:
            X = validate_data(
                self,
                X,
                dtype=[np.float64, np.float32],
                ensure_all_finite=False,
                reset=False,
                ensure_2d=True,
            )  # Warning, order of dtype matters

        return self._onedal_estimator.predict(X, queue=queue)

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Return the R^2 score of the oneDAL predictions for ``X`` against ``y``."""
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    def fit(self, X, y, sample_weight=None):
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": _sklearn_ForestRegressor.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):
        check_is_fitted(self)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": _sklearn_ForestRegressor.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        check_is_fitted(self)
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": _sklearn_ForestRegressor.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    fit.__doc__ = _sklearn_ForestRegressor.fit.__doc__
    predict.__doc__ = _sklearn_ForestRegressor.predict.__doc__
    score.__doc__ = _sklearn_ForestRegressor.score.__doc__
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
@register_hyperparameters({"predict": ("decision_forest", "infer")})
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class RandomForestClassifier(ForestClassifier):
    __doc__ = _sklearn_RandomForestClassifier.__doc__
    _onedal_factory = onedal_RandomForestClassifier

    # With scikit-learn >= 1.2, extend the upstream declarative parameter
    # constraints with the two oneDAL-only binning hyperparameters.
    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **_sklearn_RandomForestClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # The constructor tracks the installed scikit-learn: from 1.4 on it also
    # accepts ``monotonic_cst`` and lists it among the per-tree parameters.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
            self.monotonic_cst = monotonic_cst

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
|
|
1375
|
+
|
|
1376
|
+
|
|
1377
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class RandomForestRegressor(ForestRegressor):
    __doc__ = _sklearn_RandomForestRegressor.__doc__
    _onedal_factory = onedal_RandomForestRegressor

    # With scikit-learn >= 1.2, extend the upstream declarative parameter
    # constraints with the two oneDAL-only binning hyperparameters.
    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **_sklearn_RandomForestRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # The constructor tracks the installed scikit-learn: from 1.4 on it also
    # accepts ``monotonic_cst`` and lists it among the per-tree parameters.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
            self.monotonic_cst = monotonic_cst

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
|
|
1513
|
+
|
|
1514
|
+
|
|
1515
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class ExtraTreesClassifier(ForestClassifier):
    __doc__ = _sklearn_ExtraTreesClassifier.__doc__
    _onedal_factory = onedal_ExtraTreesClassifier

    # With scikit-learn >= 1.2, extend the upstream declarative parameter
    # constraints with the two oneDAL-only binning hyperparameters.
    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **_sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # The constructor tracks the installed scikit-learn: from 1.4 on it also
    # accepts ``monotonic_cst`` and lists it among the per-tree parameters.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
            self.monotonic_cst = monotonic_cst

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
|
|
1655
|
+
|
|
1656
|
+
|
|
1657
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class ExtraTreesRegressor(ForestRegressor):
    __doc__ = _sklearn_ExtraTreesRegressor.__doc__
    _onedal_factory = onedal_ExtraTreesRegressor

    # With scikit-learn >= 1.2, extend the upstream declarative parameter
    # constraints with the two oneDAL-only binning hyperparameters.
    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **_sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # The constructor tracks the installed scikit-learn: from 1.4 on it also
    # accepts ``monotonic_cst`` and lists it among the per-tree parameters.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
            self.monotonic_cst = monotonic_cst

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Attribute names handed to the base ensemble as ``estimator_params``.
            per_tree = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=per_tree,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining hyperparameters are stored unmodified.
            self.criterion, self.max_depth = criterion, max_depth
            self.min_samples_split, self.min_samples_leaf = (
                min_samples_split,
                min_samples_leaf,
            )
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features, self.max_leaf_nodes = max_features, max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins, self.min_bin_size = max_bins, min_bin_size
|
|
1793
|
+
|
|
1794
|
+
|
|
1795
|
+
# Allow for isinstance calls without inheritance changes using ABCMeta:
# each oneDAL-backed estimator is registered as a virtual subclass of the
# matching scikit-learn estimator class.
for _sklearn_cls, _sklearnex_cls in (
    (_sklearn_RandomForestClassifier, RandomForestClassifier),
    (_sklearn_RandomForestRegressor, RandomForestRegressor),
    (_sklearn_ExtraTreesClassifier, ExtraTreesClassifier),
    (_sklearn_ExtraTreesRegressor, ExtraTreesRegressor),
):
    _sklearn_cls.register(_sklearnex_cls)
# Drop the loop temporaries so the module namespace is unchanged.
del _sklearn_cls, _sklearnex_cls
|