scikit-learn-intelex 2025.10.0__py313-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__init__.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +338 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +38 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dispatcher.py +572 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +629 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1799 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/__main__.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +101 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +39 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/split.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +34 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/common.py +313 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +170 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/__init__.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +112 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +25 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +26 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +278 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svc.py +306 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svr.py +155 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +418 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
- scikit_learn_intelex-2025.10.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
- scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
- scikit_learn_intelex-2025.10.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.10.0.dist-info/top_level.txt +1 -0
scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py
ADDED
|
@@ -0,0 +1,1285 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2014 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numbers
|
|
18
|
+
import warnings
|
|
19
|
+
from math import ceil
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
from scipy import sparse as sp
|
|
23
|
+
from sklearn.base import clone
|
|
24
|
+
from sklearn.ensemble import RandomForestClassifier as RandomForestClassifier_original
|
|
25
|
+
from sklearn.ensemble import RandomForestRegressor as RandomForestRegressor_original
|
|
26
|
+
from sklearn.exceptions import DataConversionWarning
|
|
27
|
+
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
|
28
|
+
from sklearn.tree._tree import Tree
|
|
29
|
+
from sklearn.utils import check_array, check_random_state, deprecated
|
|
30
|
+
from sklearn.utils.validation import (
|
|
31
|
+
_num_samples,
|
|
32
|
+
check_consistent_length,
|
|
33
|
+
check_is_fitted,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
import daal4py
|
|
37
|
+
from daal4py.sklearn._utils import (
|
|
38
|
+
PatchingConditionsChain,
|
|
39
|
+
check_tree_nodes,
|
|
40
|
+
daal_check_version,
|
|
41
|
+
getFPType,
|
|
42
|
+
sklearn_check_version,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
from .._n_jobs_support import control_n_jobs
|
|
46
|
+
from ..utils.validation import _daal_num_features, check_feature_names, check_n_features
|
|
47
|
+
|
|
48
|
+
if sklearn_check_version("1.2"):
|
|
49
|
+
from sklearn.utils._param_validation import Interval, StrOptions
|
|
50
|
+
if sklearn_check_version("1.4"):
|
|
51
|
+
from daal4py.sklearn.utils import _assert_all_finite
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _to_absolute_max_features(max_features, n_features, is_classification=False):
|
|
55
|
+
if max_features is None:
|
|
56
|
+
return n_features
|
|
57
|
+
if isinstance(max_features, str):
|
|
58
|
+
if max_features == "auto":
|
|
59
|
+
if not sklearn_check_version("1.3"):
|
|
60
|
+
if sklearn_check_version("1.1"):
|
|
61
|
+
warnings.warn(
|
|
62
|
+
"`max_features='auto'` has been deprecated in 1.1 "
|
|
63
|
+
"and will be removed in 1.3. To keep the past behaviour, "
|
|
64
|
+
"explicitly set `max_features=1.0` or remove this "
|
|
65
|
+
"parameter as it is also the default value for "
|
|
66
|
+
"RandomForestRegressors and ExtraTreesRegressors.",
|
|
67
|
+
FutureWarning,
|
|
68
|
+
)
|
|
69
|
+
return (
|
|
70
|
+
max(1, int(np.sqrt(n_features))) if is_classification else n_features
|
|
71
|
+
)
|
|
72
|
+
if max_features == "sqrt":
|
|
73
|
+
return max(1, int(np.sqrt(n_features)))
|
|
74
|
+
if max_features == "log2":
|
|
75
|
+
return max(1, int(np.log2(n_features)))
|
|
76
|
+
allowed_string_values = (
|
|
77
|
+
'"sqrt" or "log2"'
|
|
78
|
+
if sklearn_check_version("1.3")
|
|
79
|
+
else '"auto", "sqrt" or "log2"'
|
|
80
|
+
)
|
|
81
|
+
raise ValueError(
|
|
82
|
+
"Invalid value for max_features. Allowed string "
|
|
83
|
+
f"values are {allowed_string_values}."
|
|
84
|
+
)
|
|
85
|
+
if isinstance(max_features, (numbers.Integral, np.integer)):
|
|
86
|
+
return max_features
|
|
87
|
+
if max_features > 0.0:
|
|
88
|
+
return max(1, int(max_features * n_features))
|
|
89
|
+
return 0
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _get_n_samples_bootstrap(n_samples, max_samples):
|
|
93
|
+
if max_samples is None:
|
|
94
|
+
return 1.0
|
|
95
|
+
|
|
96
|
+
if isinstance(max_samples, numbers.Integral):
|
|
97
|
+
if not sklearn_check_version("1.2"):
|
|
98
|
+
if not (1 <= max_samples <= n_samples):
|
|
99
|
+
msg = "`max_samples` must be in range 1 to {} but got value {}"
|
|
100
|
+
raise ValueError(msg.format(n_samples, max_samples))
|
|
101
|
+
else:
|
|
102
|
+
if max_samples > n_samples:
|
|
103
|
+
msg = "`max_samples` must be <= n_samples={} but got value {}"
|
|
104
|
+
raise ValueError(msg.format(n_samples, max_samples))
|
|
105
|
+
return max(float(max_samples / n_samples), 1 / n_samples)
|
|
106
|
+
|
|
107
|
+
if isinstance(max_samples, numbers.Real):
|
|
108
|
+
if sklearn_check_version("1.2"):
|
|
109
|
+
pass
|
|
110
|
+
else:
|
|
111
|
+
if not (0 < float(max_samples) <= 1):
|
|
112
|
+
msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
|
|
113
|
+
raise ValueError(msg.format(max_samples))
|
|
114
|
+
return max(float(max_samples), 1 / n_samples)
|
|
115
|
+
|
|
116
|
+
msg = "`max_samples` should be int or float, but got type '{}'"
|
|
117
|
+
raise TypeError(msg.format(type(max_samples)))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def check_sample_weight(sample_weight, X, dtype=None):
    """Validate ``sample_weight`` against ``X`` and return it as a 1-D array.

    Parameters
    ----------
    sample_weight : None, number or array-like
        ``None`` yields an array of ones and a scalar is broadcast to every
        sample. Anything else is validated with ``check_array``.
    X : array-like
        Data whose number of samples fixes the expected weight length.
    dtype : dtype, default=None
        Requested dtype. Any value other than ``None``, ``np.float32`` or
        ``np.float64`` is replaced by ``np.float64``.

    Returns
    -------
    ndarray
        The weights, one per sample of ``X``.

    Raises
    ------
    ValueError
        If an array-like ``sample_weight`` is not one-dimensional or its
        length does not match the number of samples.
    """
    expected_len = _num_samples(X)

    # Only the two floating point widths are accepted; everything else
    # falls back to double precision.
    if dtype is not None and dtype not in [np.float32, np.float64]:
        dtype = np.float64

    if sample_weight is None:
        return np.ones(expected_len, dtype=dtype)

    if isinstance(sample_weight, numbers.Number):
        return np.full(expected_len, sample_weight, dtype=dtype)

    accepted_dtype = [np.float64, np.float32] if dtype is None else dtype
    weights = check_array(
        sample_weight,
        accept_sparse=False,
        ensure_2d=False,
        dtype=accepted_dtype,
        order="C",
    )
    if weights.ndim != 1:
        raise ValueError("Sample weights must be 1D array or scalar")

    expected_shape = (expected_len,)
    if weights.shape != expected_shape:
        raise ValueError(
            "sample_weight.shape == {}, expected {}!".format(
                weights.shape, expected_shape
            )
        )
    return weights
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class RandomForestBase:
    """Mixin holding the manual hyper-parameter validation for the forests.

    ``fit`` and ``predict`` are placeholders here. ``_check_parameters`` is the
    hand-written validation that ``RandomForestClassifier.fit`` in this module
    calls when ``sklearn_check_version("1.2")`` is false.
    """

    def fit(self, X, y, sample_weight=None): ...

    def predict(self, X): ...

    def _check_parameters(self) -> None:
        """Validate the estimator's hyper-parameters.

        Reads ``bootstrap``, ``max_samples``, ``min_samples_leaf``,
        ``min_samples_split``, ``min_weight_fraction_leaf``,
        ``min_impurity_split``, ``min_impurity_decrease``, ``max_leaf_nodes``,
        ``maxBins`` and ``minBinSize`` from ``self``.

        Raises
        ------
        ValueError
            If any of the parameters is outside its accepted range or type.

        Warns
        -----
        FutureWarning
            If the deprecated ``min_impurity_split`` is set.
        """
        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )

        # min_samples_leaf: absolute count (>= 1) or fraction in (0, 0.5].
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    f"or in (0, 0.5], got {self.min_samples_leaf}"
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    f"or in (0, 0.5], got {self.min_samples_leaf}"
                )

        # min_samples_split: absolute count (>= 2) or fraction in (0, 1].
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    f"got the integer {self.min_samples_split}"
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                # The last fragment needs the f-prefix, otherwise the
                # placeholder is printed verbatim instead of the value.
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    f"got the float {self.min_samples_split}"
                )

        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")

        if self.min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            if self.min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than or equal to 0"
                )

        if self.min_impurity_decrease < 0.0:
            raise ValueError("min_impurity_decrease must be greater than or equal to 0")

        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    f"{self.max_leaf_nodes}"
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    f"max_leaf_nodes {self.max_leaf_nodes} must be either None "
                    "or larger than 1"
                )

        # Binning parameters must be integers.
        if isinstance(self.maxBins, numbers.Integral):
            if not 2 <= self.maxBins:
                raise ValueError(f"maxBins must be at least 2, got {self.maxBins}")
        else:
            raise ValueError(f"maxBins must be integral number but was {self.maxBins}")

        if isinstance(self.minBinSize, numbers.Integral):
            if not 1 <= self.minBinSize:
                raise ValueError(f"minBinSize must be at least 1, got {self.minBinSize}")
        else:
            raise ValueError(
                f"minBinSize must be integral number but was {self.minBinSize}"
            )
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
class RandomForestClassifier(RandomForestClassifier_original, RandomForestBase):
    # Random forest classifier that trains and predicts through daal4py when
    # every patching condition holds, and otherwise defers to the stock
    # scikit-learn implementation via ``super()``.
    # NOTE(review): block nesting in this class was reconstructed from a
    # flattened listing — confirm against the packaged file.
    __doc__ = RandomForestClassifier_original.__doc__

    if sklearn_check_version("1.2"):
        # Stock constraints extended with the three binning parameters added
        # by this class.
        # NOTE(review): ``maxBins`` is allowed down to 0 here, whereas
        # ``RandomForestBase._check_parameters`` requires at least 2 — confirm
        # which bound is intended.
        _parameter_constraints: dict = {
            **RandomForestClassifier_original._parameter_constraints,
            "maxBins": [Interval(numbers.Integral, 0, None, closed="left")],
            "minBinSize": [Interval(numbers.Integral, 1, None, closed="left")],
            "binningStrategy": [StrOptions({"quantiles", "averages"})],
        }

    # Two constructors: this variant additionally accepts ``monotonic_cst``
    # and forwards it to the parent constructor.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            maxBins=256,
            minBinSize=1,
            binningStrategy="quantiles",
        ):
            super().__init__(
                n_estimators=n_estimators,
                criterion=criterion,
                max_depth=max_depth,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples_leaf,
                min_weight_fraction_leaf=min_weight_fraction_leaf,
                max_features=max_features,
                max_leaf_nodes=max_leaf_nodes,
                min_impurity_decrease=min_impurity_decrease,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                monotonic_cst=monotonic_cst,
            )
            # ccp_alpha and max_samples are not forwarded above; they are
            # stored directly on the instance together with the binning
            # options specific to this class.
            self.ccp_alpha = ccp_alpha
            self.max_samples = max_samples
            self.monotonic_cst = monotonic_cst
            self.maxBins = maxBins
            self.minBinSize = minBinSize
            # Not a constructor argument; always initialised to None.
            self.min_impurity_split = None
            self.binningStrategy = binningStrategy

    else:

        def __init__(
            self,
            n_estimators=100,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # Default depends on the detected scikit-learn version.
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            maxBins=256,
            minBinSize=1,
            binningStrategy="quantiles",
        ):
            super().__init__(
                n_estimators=n_estimators,
                criterion=criterion,
                max_depth=max_depth,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples_leaf,
                min_weight_fraction_leaf=min_weight_fraction_leaf,
                max_features=max_features,
                max_leaf_nodes=max_leaf_nodes,
                min_impurity_decrease=min_impurity_decrease,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
            )
            # Stored directly on the instance rather than forwarded above.
            self.ccp_alpha = ccp_alpha
            self.max_samples = max_samples
            self.maxBins = maxBins
            self.minBinSize = minBinSize
            # Not a constructor argument; always initialised to None.
            self.min_impurity_split = None
            self.binningStrategy = binningStrategy

    def fit(self, X, y, sample_weight=None):
        """
        Build a forest of trees from the training set (X, y).

        Training goes through oneDAL when all patching conditions checked
        below hold; otherwise the call is forwarded to the parent ``fit``.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. Internally, its dtype will be converted
            to ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csc_matrix``.

        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The target values (class labels in classification, real numbers in
            regression).

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted. Splits
            that would create child nodes with net zero or negative weight are
            ignored while searching for a split in each node. In the case of
            classification, splits are also ignored if they would result in any
            single class carrying a negative weight in either child node.

        Returns
        -------
        self : object

        Raises
        ------
        ValueError
            If ``y`` is sparse.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")
        # Parameter validation: scikit-learn's own validator when the 1.2
        # check passes, the hand-written checks otherwise.
        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()
        if sample_weight is not None:
            sample_weight = check_sample_weight(sample_weight, X)

        _patching_status = PatchingConditionsChain(
            "sklearn.ensemble.RandomForestClassifier.fit"
        )
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    # True when OOB scoring is off, or when it is on and the
                    # oneDAL version check passes.
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion == "gini",
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'gini' criterion is supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            ]
        )
        if _dal_ready and sklearn_check_version("1.4"):
            # Probe for non-finite values without letting the error escape.
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            # NOTE(review): the return value is not assigned to _dal_ready
            # here; presumably the chain keeps the outcome internally so the
            # later and_conditions call reflects it — confirm against
            # PatchingConditionsChain.
            _patching_status.and_conditions(
                [
                    (
                        input_is_finite,
                        "Non-finite input is not supported.",
                    ),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if _dal_ready:
            check_feature_names(self, X, reset=True)
            # The keyword controlling the finiteness check differs between
            # the two branches (ensure_all_finite vs force_all_finite).
            if sklearn_check_version("1.6"):
                X = check_array(
                    X,
                    dtype=[np.float32, np.float64],
                    ensure_all_finite=False,
                )
            else:
                X = check_array(
                    X,
                    dtype=[np.float32, np.float64],
                    force_all_finite=not sklearn_check_version("1.4"),
                )
            y = np.asarray(y)
            y = np.atleast_1d(y)

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            check_consistent_length(X, y)

            if y.ndim == 1:
                # reshape is necessary to preserve the data contiguity against vs
                # [:, np.newaxis] that does not.
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]
            # Only single-output problems are sent to oneDAL.
            _dal_ready = _patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

        _patching_status.write_log()
        if _dal_ready:
            self._daal_fit_classifier(X, y, sample_weight=sample_weight)

            if sklearn_check_version("1.2"):
                self._estimator = DecisionTreeClassifier()
            # Materialise the per-tree estimators while classes_ and
            # n_classes_ are still in their list form, which _estimators_
            # indexes with [0].
            self.estimators_ = self._estimators_

            # Decapsulate classes_ attributes
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]
            return self
        # Fallback: stock scikit-learn training.
        return super().fit(X, y, sample_weight=sample_weight)

    def predict(self, X):
        """
        Predict class for X.

        The predicted class of an input sample is a vote by the trees in
        the forest, weighted by their probability estimates. That is,
        the predicted class is the one with highest mean probability
        estimate across the trees.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            The predicted classes.
        """
        _patching_status = PatchingConditionsChain(
            "sklearn.ensemble.RandomForestClassifier.predict"
        )
        _dal_ready = _patching_status.and_conditions(
            [
                (hasattr(self, "daal_model_"), "oneDAL model was not trained."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            ]
        )
        if hasattr(self, "n_outputs_"):
            _dal_ready = _patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

        _patching_status.write_log()
        if not _dal_ready:
            # Fallback: stock scikit-learn prediction.
            return super().predict(X)

        check_feature_names(self, X, reset=False)
        X = check_array(
            X, accept_sparse=["csr", "csc", "coo"], dtype=[np.float64, np.float32]
        )
        return self._daal_predict_classifier(X)

    def predict_proba(self, X):
        """
        Predict class probabilities for X.

        The predicted class probabilities of an input sample are computed as
        the mean predicted class probabilities of the trees in the forest.
        The class probability of a single tree is the fraction of samples of
        the same class in a leaf.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        p : ndarray of shape (n_samples, n_classes), or a list of n_outputs
            such arrays if n_outputs > 1.
            The class probabilities of the input samples. The order of the
            classes corresponds to that in the attribute :term:`classes_`.

        Raises
        ------
        ValueError
            If the estimator has ``n_features_in_`` and the feature count
            derived from ``X`` differs from it.
        """
        check_feature_names(self, X, reset=False)
        if hasattr(self, "n_features_in_"):
            try:
                num_features = _daal_num_features(X)
            except TypeError:
                # NOTE(review): falls back to the sample count when the
                # feature count cannot be determined — confirm this is the
                # intended handling.
                num_features = _num_samples(X)
            if num_features != self.n_features_in_:
                raise ValueError(
                    (
                        f"X has {num_features} features, "
                        f"but RandomForestClassifier is expecting "
                        f"{self.n_features_in_} features as input"
                    )
                )

        _patching_status = PatchingConditionsChain(
            "sklearn.ensemble.RandomForestClassifier.predict_proba"
        )
        _dal_ready = _patching_status.and_conditions(
            [
                (hasattr(self, "daal_model_"), "oneDAL model was not trained."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    daal_check_version((2021, "P", 400)),
                    "oneDAL version is lower than 2021.4.",
                ),
            ]
        )
        if hasattr(self, "n_outputs_"):
            _dal_ready = _patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )
        _patching_status.write_log()

        if not _dal_ready:
            # Fallback: stock scikit-learn probabilities.
            return super().predict_proba(X)
        X = check_array(X, dtype=[np.float64, np.float32])
        check_is_fitted(self)
        check_n_features(self, X, reset=False)
        return self._daal_predict_proba(X)

    if not sklearn_check_version("1.2"):

        @deprecated(
            "Attribute `n_features_` was deprecated in version 1.0 and will be "
            "removed in 1.2. Use `n_features_in_` instead."
        )
        @property
        def n_features_(self):
            """Deprecated alias returning ``n_features_in_``."""
            return self.n_features_in_

    @property
    def _estimators_(self):
        """Build, cache and return one ``DecisionTreeClassifier`` per tree.

        Each estimator receives a ``Tree`` whose state is filled from
        ``daal4py.getTreeState`` for the corresponding tree of
        ``daal_model_``. A non-empty ``_cached_estimators_`` is returned
        as-is.
        """
        if hasattr(self, "_cached_estimators_"):
            if self._cached_estimators_:
                return self._cached_estimators_

        check_is_fitted(self)
        # classes_ / n_classes_ are indexed with [0] here, i.e. they are
        # expected to still be in list form at this point.
        classes_ = self.classes_[0]
        n_classes_ = self.n_classes_[0]
        # convert model to estimators
        params = {
            "criterion": self.criterion,
            "max_depth": self.max_depth,
            "min_samples_split": self.min_samples_split,
            "min_samples_leaf": self.min_samples_leaf,
            "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
            "max_features": self.max_features,
            "max_leaf_nodes": self.max_leaf_nodes,
            "min_impurity_decrease": self.min_impurity_decrease,
            "random_state": None,
        }
        est = DecisionTreeClassifier(**params)
        # we need to set est.tree_ field with Trees constructed from
        # oneAPI Data Analytics Library solution
        estimators_ = []
        random_state_checked = check_random_state(self.random_state)
        for i in range(self.n_estimators):
            est_i = clone(est)
            # Each clone gets its own seed drawn from the forest's generator.
            est_i.set_params(
                random_state=random_state_checked.randint(np.iinfo(np.int32).max)
            )
            est_i.n_features_in_ = self.n_features_in_
            est_i.n_outputs_ = self.n_outputs_
            est_i.classes_ = classes_
            est_i.n_classes_ = n_classes_
            # treeState members: 'class_count', 'leaf_count', 'max_depth',
            # 'node_ar', 'node_count', 'value_ar'
            tree_i_state_class = daal4py.getTreeState(self.daal_model_, i, n_classes_)

            # node_ndarray = tree_i_state_class.node_ar
            # value_ndarray = tree_i_state_class.value_ar
            # value_shape = (node_ndarray.shape[0], self.n_outputs_,
            #                n_classes_)
            # assert np.allclose(
            #     value_ndarray, value_ndarray.astype(np.intc, casting='unsafe')
            # ), "Value array is non-integer"
            tree_i_state_dict = {
                "max_depth": tree_i_state_class.max_depth,
                "node_count": tree_i_state_class.node_count,
                "nodes": check_tree_nodes(tree_i_state_class.node_ar),
                "values": tree_i_state_class.value_ar,
            }
            est_i.tree_ = Tree(
                self.n_features_in_,
                np.array([n_classes_], dtype=np.intp),
                self.n_outputs_,
            )
            est_i.tree_.__setstate__(tree_i_state_dict)
            estimators_.append(est_i)

        self._cached_estimators_ = estimators_
        return estimators_

    def _daal_predict_proba(self, X):
        """Compute class probabilities for ``X`` with the trained oneDAL model.

        The probabilities returned by daal4py are clipped to ``[0, 1]``.
        """
        X_fptype = getFPType(X)
        dfc_algorithm = daal4py.decision_forest_classification_prediction(
            nClasses=int(self.n_classes_),
            fptype=X_fptype,
            resultsToEvaluate="computeClassProbabilities",
        )
        dfc_predictionResult = dfc_algorithm.compute(X, self.daal_model_)

        pred = dfc_predictionResult.probabilities
        # TODO: fix probabilities out of [0, 1] interval on oneDAL side
        return pred.clip(0.0, 1.0)

    def _daal_fit_classifier(self, X, y, sample_weight=None):
        """Train the forest with daal4py and store the model on ``self``.

        Sets ``n_features_in_``, ``daal_model_`` and, when ``oob_score`` is
        enabled, ``oob_score_`` and ``oob_decision_function_``. Resets
        ``_cached_estimators_`` to ``None``.

        Raises
        ------
        ValueError
            If fewer than two classes are present, if ``max_samples`` is set
            while ``bootstrap`` is false, or if ``oob_score`` is requested
            while ``bootstrap`` is false.
        """
        y = check_array(y, ensure_2d=False, dtype=None)
        y, expanded_class_weight = self._validate_y_class_weight(y)
        n_classes = self.n_classes_[0]
        self.n_features_in_ = X.shape[1]

        # Combine class weights with user-supplied sample weights.
        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight
        if sample_weight is not None:
            # NOTE(review): the weights are wrapped in a list before being
            # handed to daal4py — presumably to match the layout it expects;
            # confirm.
            sample_weight = [sample_weight]

        rs_ = check_random_state(self.random_state)
        seed_ = rs_.randint(0, np.iinfo("i").max)

        if n_classes < 2:
            raise ValueError("Training data only contain information about one class.")

        daal_engine = daal4py.engines_mt19937(seed=seed_, fptype=getFPType(X))

        features_per_node = _to_absolute_max_features(
            self.max_features, X.shape[1], is_classification=True
        )

        n_samples_bootstrap = _get_n_samples_bootstrap(
            n_samples=X.shape[0], max_samples=self.max_samples
        )

        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        # Falsy values (None or 0) of min_impurity_split, max_leaf_nodes and
        # max_depth are passed to daal4py as 0.
        parameters = {
            "bootstrap": bool(self.bootstrap),
            "engine": daal_engine,
            "featuresPerNode": features_per_node,
            "fptype": getFPType(X),
            "impurityThreshold": self.min_impurity_split or 0.0,
            "maxBins": self.maxBins,
            "maxLeafNodes": self.max_leaf_nodes or 0,
            "maxTreeDepth": self.max_depth or 0,
            "memorySavingMode": False,
            "method": "hist",
            "minBinSize": self.minBinSize,
            "minImpurityDecreaseInSplitNode": self.min_impurity_decrease,
            "minWeightFractionInLeafNode": self.min_weight_fraction_leaf,
            "nClasses": int(n_classes),
            "nTrees": self.n_estimators,
            "observationsPerTreeFraction": 1.0,
            "resultsToCompute": "",
            "varImportance": "MDI",
        }

        # Fractional settings are converted to absolute row counts.
        if isinstance(self.min_samples_split, numbers.Integral):
            parameters["minObservationsInSplitNode"] = self.min_samples_split
        else:
            parameters["minObservationsInSplitNode"] = ceil(
                self.min_samples_split * X.shape[0]
            )

        if isinstance(self.min_samples_leaf, numbers.Integral):
            parameters["minObservationsInLeafNode"] = self.min_samples_leaf
        else:
            parameters["minObservationsInLeafNode"] = ceil(
                self.min_samples_leaf * X.shape[0]
            )

        if self.bootstrap:
            parameters["observationsPerTreeFraction"] = n_samples_bootstrap
        if self.oob_score:
            parameters["resultsToCompute"] = (
                "computeOutOfBagErrorAccuracy|computeOutOfBagErrorDecisionFunction"
            )

        # binningStrategy is only passed when the oneDAL version check passes.
        if daal_check_version((2023, "P", 200)):
            parameters["binningStrategy"] = self.binningStrategy

        # create algorithm
        dfc_algorithm = daal4py.decision_forest_classification_training(**parameters)
        # Invalidate estimators built from a previous model.
        self._cached_estimators_ = None
        # compute
        dfc_trainingResult = dfc_algorithm.compute(X, y, sample_weight)

        # get resulting model
        model = dfc_trainingResult.model
        self.daal_model_ = model

        if self.oob_score:
            self.oob_score_ = dfc_trainingResult.outOfBagErrorAccuracy[0][0]
            self.oob_decision_function_ = dfc_trainingResult.outOfBagErrorDecisionFunction
            # Drop a trailing axis of length one.
            if self.oob_decision_function_.shape[-1] == 1:
                self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1)

        return self

    def _daal_predict_classifier(self, X):
        """Predict class labels for ``X`` with the trained oneDAL model.

        Raises
        ------
        ValueError
            If ``X.shape[1]`` differs from ``n_features_in_``.
        """
        X_fptype = getFPType(X)
        dfc_algorithm = daal4py.decision_forest_classification_prediction(
            nClasses=int(self.n_classes_),
            fptype=X_fptype,
            resultsToEvaluate="computeClassLabels",
        )
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                (
                    f"X has {X.shape[1]} features, "
                    f"but RandomForestClassifier is expecting "
                    f"{self.n_features_in_} features as input"
                )
            )
        dfc_predictionResult = dfc_algorithm.compute(X, self.daal_model_)

        pred = dfc_predictionResult.prediction

        # The predictions are used as integer indices into classes_.
        return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe"))
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
@control_n_jobs(decorated_methods=["fit", "predict"])
|
|
811
|
+
class RandomForestRegressor(RandomForestRegressor_original, RandomForestBase):
|
|
812
|
+
__doc__ = RandomForestRegressor_original.__doc__
|
|
813
|
+
|
|
814
|
+
if sklearn_check_version("1.2"):
|
|
815
|
+
_parameter_constraints: dict = {
|
|
816
|
+
**RandomForestRegressor_original._parameter_constraints,
|
|
817
|
+
"maxBins": [Interval(numbers.Integral, 0, None, closed="left")],
|
|
818
|
+
"minBinSize": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
819
|
+
"binningStrategy": [StrOptions({"quantiles", "averages"})],
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
# The constructor is defined twice and the variant matching the installed
# scikit-learn version is selected when the class body executes. The only
# differences between the two are the ``monotonic_cst`` argument (present
# only in the 1.4+ variant) and the ``max_features`` default.
if sklearn_check_version("1.4"):

    def __init__(
        self,
        n_estimators=100,
        *,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        bootstrap=True,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        ccp_alpha=0.0,
        max_samples=None,
        monotonic_cst=None,
        maxBins=256,
        minBinSize=1,
        binningStrategy="quantiles",
    ):
        # Forward the standard forest hyper-parameters to the parent class.
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            monotonic_cst=monotonic_cst,
        )
        # These are not forwarded to the parent constructor above; they are
        # stored on the instance directly.
        self.ccp_alpha = ccp_alpha
        self.max_samples = max_samples
        self.monotonic_cst = monotonic_cst
        # Histogram-binning options; `_daal_fit_regressor` passes them to the
        # oneDAL training algorithm.
        self.maxBins = maxBins
        self.minBinSize = minBinSize
        # Always None here; `_daal_fit_regressor` reads it to derive the
        # "impurityThreshold" training parameter (None maps to 0.0).
        self.min_impurity_split = None
        self.binningStrategy = binningStrategy

else:

    def __init__(
        self,
        n_estimators=100,
        *,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        # The default depends on the installed scikit-learn version.
        max_features=1.0 if sklearn_check_version("1.1") else "auto",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        bootstrap=True,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        ccp_alpha=0.0,
        max_samples=None,
        maxBins=256,
        minBinSize=1,
        binningStrategy="quantiles",
    ):
        # Forward the standard forest hyper-parameters to the parent class
        # (no `monotonic_cst` in this variant).
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
        )
        # These are not forwarded to the parent constructor above; they are
        # stored on the instance directly.
        self.ccp_alpha = ccp_alpha
        self.max_samples = max_samples
        # Histogram-binning options; `_daal_fit_regressor` passes them to the
        # oneDAL training algorithm.
        self.maxBins = maxBins
        self.minBinSize = minBinSize
        # Always None here; `_daal_fit_regressor` reads it to derive the
        # "impurityThreshold" training parameter (None maps to 0.0).
        self.min_impurity_split = None
        self.binningStrategy = binningStrategy
|
|
924
|
+
|
|
925
|
+
def fit(self, X, y, sample_weight=None):
    """
    Build a forest of trees from the training set (X, y).

    Training is dispatched to oneDAL when every patching condition checked
    below holds; otherwise the call is forwarded to the parent class
    implementation.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. On the oneDAL path the data is
        validated as ``np.float64`` or ``np.float32``. Sparse input is not
        supported by the oneDAL path and is forwarded to the parent class.

    y : array-like of shape (n_samples,) or (n_samples, n_outputs)
        The target values (real numbers in regression). The oneDAL path
        only handles a single output.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. If None, then samples are equally weighted.

    Returns
    -------
    self : object
        The fitted estimator.

    Raises
    ------
    ValueError
        If ``y`` is a sparse matrix.
    """
    if sp.issparse(y):
        raise ValueError("sparse multilabel-indicator for y is not supported.")
    if sklearn_check_version("1.2"):
        self._validate_params()
    else:
        self._check_parameters()
    if sample_weight is not None:
        sample_weight = check_sample_weight(sample_weight, X)

    if not sklearn_check_version("1.2") and self.criterion == "mse":
        warnings.warn(
            "Criterion 'mse' was deprecated in v1.0 and will be "
            "removed in version 1.2. Use `criterion='squared_error'` "
            "which is equivalent.",
            FutureWarning,
        )

    _patching_status = PatchingConditionsChain(
        "sklearn.ensemble.RandomForestRegressor.fit"
    )
    _dal_ready = _patching_status.and_conditions(
        [
            (
                # The oneDAL version only matters when an OOB score is requested.
                not self.oob_score or daal_check_version((2021, "P", 500)),
                "OOB score is only supported starting from 2021.5 version of oneDAL.",
            ),
            (self.warm_start is False, "Warm start is not supported."),
            (
                self.criterion in ["mse", "squared_error"],
                f"'{self.criterion}' criterion is not supported. "
                "Only 'mse' and 'squared_error' criteria are supported.",
            ),
            (
                self.ccp_alpha == 0.0,
                f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
            ),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
        ]
    )
    if _dal_ready and sklearn_check_version("1.4"):
        try:
            _assert_all_finite(X)
            input_is_finite = True
        except ValueError:
            input_is_finite = False
        # Keep the result, like the other `and_conditions` calls in this
        # method do. The return value used to be discarded here, so
        # `_dal_ready` stayed true and the oneDAL-only preprocessing below
        # still ran on input that had just been rejected.
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    input_is_finite,
                    "Non-finite input is not supported.",
                ),
                (
                    self.monotonic_cst is None,
                    "Monotonicity constraints are not supported.",
                ),
            ]
        )

    if _dal_ready:
        check_feature_names(self, X, reset=True)
        # For scikit-learn >= 1.4 finiteness was already verified above, so
        # the check is switched off here. `check_array` takes the flag as
        # `ensure_all_finite` on 1.6+ and as `force_all_finite` before that.
        if sklearn_check_version("1.6"):
            X = check_array(
                X,
                dtype=[np.float64, np.float32],
                ensure_all_finite=False,
            )
        else:
            X = check_array(
                X,
                dtype=[np.float64, np.float32],
                force_all_finite=not sklearn_check_version("1.4"),
            )
        y = np.asarray(y)
        y = np.atleast_1d(y)

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        y = check_array(y, ensure_2d=False, dtype=X.dtype)
        check_consistent_length(X, y)

        if y.ndim == 1:
            # reshape is necessary to preserve the data contiguity against vs
            # [:, np.newaxis] that does not.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    self.n_outputs_ == 1,
                    f"Number of outputs ({self.n_outputs_}) is not 1.",
                )
            ]
        )

    _patching_status.write_log()
    if _dal_ready:
        self._daal_fit_regressor(X, y, sample_weight=sample_weight)

        if sklearn_check_version("1.2"):
            self._estimator = DecisionTreeRegressor()
        # Expose the oneDAL trees as scikit-learn estimators.
        self.estimators_ = self._estimators_
        return self
    return super().fit(X, y, sample_weight=sample_weight)
|
|
1064
|
+
|
|
1065
|
+
def predict(self, X):
    """
    Predict regression target for X.

    The oneDAL model is used when one has been trained, X is not sparse and
    the forest has a single output; in every other case the call is
    forwarded to the parent class implementation.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. On the oneDAL path the data is validated as
        ``np.float64`` or ``np.float32``.

    Returns
    -------
    y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
        The predicted values.
    """
    status = PatchingConditionsChain(
        "sklearn.ensemble.RandomForestRegressor.predict"
    )
    model_trained = hasattr(self, "daal_model_")
    input_is_dense = not sp.issparse(X)
    use_onedal = status.and_conditions(
        [
            (model_trained, "oneDAL model was not trained."),
            (input_is_dense, "X is sparse. Sparse input is not supported."),
        ]
    )
    # The output-count condition is only added once the attribute exists.
    if hasattr(self, "n_outputs_"):
        single_output_condition = (
            self.n_outputs_ == 1,
            f"Number of outputs ({self.n_outputs_}) is not 1.",
        )
        use_onedal = status.and_conditions([single_output_condition])

    status.write_log()
    if use_onedal:
        check_feature_names(self, X, reset=False)
        validated = check_array(
            X, accept_sparse=["csr", "csc", "coo"], dtype=[np.float64, np.float32]
        )
        return self._daal_predict_regressor(validated)

    return super().predict(X)
|
|
1114
|
+
|
|
1115
|
+
# The `n_features_` alias is only defined when the installed scikit-learn is
# older than 1.2, the release named in the deprecation message below.
if not sklearn_check_version("1.2"):

    @deprecated(
        "Attribute `n_features_` was deprecated in version 1.0 and will be "
        "removed in 1.2. Use `n_features_in_` instead."
    )
    @property
    def n_features_(self):
        # Read-only alias of `n_features_in_`.
        return self.n_features_in_
|
|
1124
|
+
|
|
1125
|
+
@property
def _estimators_(self):
    """
    Return the trained oneDAL forest as a list of ``DecisionTreeRegressor``.

    Each tree state is read from ``self.daal_model_`` and loaded into a
    fresh scikit-learn ``Tree``. The resulting list is stored in
    ``self._cached_estimators_`` so that later accesses reuse it;
    ``_daal_fit_regressor`` resets that attribute to ``None`` on every
    training run.

    Returns
    -------
    list of DecisionTreeRegressor
        One estimator per tree, ``self.n_estimators`` in total.
    """
    cached = getattr(self, "_cached_estimators_", None)
    if cached:
        return cached

    check_is_fitted(self)
    # convert model to estimators
    params = {
        "criterion": self.criterion,
        "max_depth": self.max_depth,
        "min_samples_split": self.min_samples_split,
        "min_samples_leaf": self.min_samples_leaf,
        "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
        "max_features": self.max_features,
        "max_leaf_nodes": self.max_leaf_nodes,
        "min_impurity_decrease": self.min_impurity_decrease,
        "random_state": None,
    }
    est = DecisionTreeRegressor(**params)

    # we need to set est.tree_ field with Trees constructed from
    # oneAPI Data Analytics Library solution
    estimators_ = []
    random_state_checked = check_random_state(self.random_state)
    for i in range(self.n_estimators):
        est_i = clone(est)
        est_i.set_params(
            random_state=random_state_checked.randint(np.iinfo(np.int32).max)
        )
        est_i.n_features_in_ = self.n_features_in_
        est_i.n_outputs_ = self.n_outputs_

        tree_i_state_class = daal4py.getTreeState(self.daal_model_, i)
        tree_i_state_dict = {
            "max_depth": tree_i_state_class.max_depth,
            "node_count": tree_i_state_class.node_count,
            "nodes": check_tree_nodes(tree_i_state_class.node_ar),
            "values": tree_i_state_class.value_ar,
        }

        est_i.tree_ = Tree(
            self.n_features_in_, np.array([1], dtype=np.intp), self.n_outputs_
        )
        est_i.tree_.__setstate__(tree_i_state_dict)
        estimators_.append(est_i)

    # Populate the cache consulted at the top of this property. It was never
    # filled before, so every access rebuilt all the trees.
    self._cached_estimators_ = estimators_
    return estimators_
|
|
1172
|
+
|
|
1173
|
+
def _daal_fit_regressor(self, X, y, sample_weight=None):
    """
    Train the forest with the oneDAL decision-forest regression algorithm.

    Sets ``n_features_in_`` and ``daal_model_``, resets
    ``_cached_estimators_`` to ``None`` and, when ``oob_score`` is enabled,
    sets ``oob_score_`` and ``oob_prediction_``.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Training data; `fit` validates it before calling this method.

    y : array of shape (n_samples, 1)
        Target values; `fit` reshapes them to a single column.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. Zero weights are replaced by 1.0 on a private copy;
        the array passed by the caller is left untouched.

    Returns
    -------
    self : object

    Raises
    ------
    ValueError
        If ``bootstrap`` is false while ``max_samples`` is set or
        ``oob_score`` is enabled.
    """
    self.n_features_in_ = X.shape[1]

    rs_ = check_random_state(self.random_state)

    if not self.bootstrap and self.max_samples is not None:
        raise ValueError(
            "`max_sample` cannot be set if `bootstrap=False`. "
            "Either switch to `bootstrap=True` or set "
            "`max_sample=None`."
        )

    if not self.bootstrap and self.oob_score:
        raise ValueError("Out of bag estimation only available if bootstrap=True")

    seed_ = rs_.randint(0, np.iinfo("i").max)

    fptype = getFPType(X)
    daal_engine = daal4py.engines_mt19937(seed=seed_, fptype=fptype)

    features_per_node = _to_absolute_max_features(
        self.max_features, X.shape[1], is_classification=False
    )

    n_samples_bootstrap = _get_n_samples_bootstrap(
        n_samples=X.shape[0], max_samples=self.max_samples
    )

    if sample_weight is not None:
        if hasattr(sample_weight, "__array__"):
            # Work on a copy: the weights may be the very array the caller
            # passed to `fit`, and it must not be modified in place.
            sample_weight = np.array(sample_weight, copy=True)
            sample_weight[sample_weight == 0.0] = 1.0
        # The weights are handed to oneDAL wrapped in a one-element list.
        sample_weight = [sample_weight]

    parameters = {
        "bootstrap": bool(self.bootstrap),
        "engine": daal_engine,
        "featuresPerNode": features_per_node,
        "fptype": fptype,
        "impurityThreshold": float(self.min_impurity_split or 0.0),
        "maxBins": self.maxBins,
        # None is passed to oneDAL as 0 for the two limits below.
        "maxLeafNodes": self.max_leaf_nodes or 0,
        "maxTreeDepth": self.max_depth or 0,
        "memorySavingMode": False,
        "method": "hist",
        "minBinSize": self.minBinSize,
        "minImpurityDecreaseInSplitNode": self.min_impurity_decrease,
        "minWeightFractionInLeafNode": self.min_weight_fraction_leaf,
        "nTrees": int(self.n_estimators),
        "observationsPerTreeFraction": 1.0,
        "resultsToCompute": "",
        "varImportance": "MDI",
    }

    # Integral values are absolute sample counts; anything else is treated
    # as a fraction of the number of training samples.
    if isinstance(self.min_samples_split, numbers.Integral):
        parameters["minObservationsInSplitNode"] = self.min_samples_split
    else:
        parameters["minObservationsInSplitNode"] = ceil(
            self.min_samples_split * X.shape[0]
        )

    if isinstance(self.min_samples_leaf, numbers.Integral):
        parameters["minObservationsInLeafNode"] = self.min_samples_leaf
    else:
        parameters["minObservationsInLeafNode"] = ceil(
            self.min_samples_leaf * X.shape[0]
        )

    if self.bootstrap:
        parameters["observationsPerTreeFraction"] = n_samples_bootstrap
    if self.oob_score:
        parameters["resultsToCompute"] = (
            "computeOutOfBagErrorR2|computeOutOfBagErrorPrediction"
        )

    # `binningStrategy` is only passed when the oneDAL version check succeeds.
    if daal_check_version((2023, "P", 200)):
        parameters["binningStrategy"] = self.binningStrategy

    # create algorithm
    dfr_algorithm = daal4py.decision_forest_regression_training(**parameters)

    # Drop estimators converted from a previous model.
    self._cached_estimators_ = None

    dfr_trainingResult = dfr_algorithm.compute(X, y, sample_weight)

    # get resulting model
    model = dfr_trainingResult.model
    self.daal_model_ = model

    if self.oob_score:
        self.oob_score_ = dfr_trainingResult.outOfBagErrorR2[0][0]
        self.oob_prediction_ = dfr_trainingResult.outOfBagErrorPrediction.squeeze(
            axis=1
        )
        if self.oob_prediction_.shape[-1] == 1:
            self.oob_prediction_ = self.oob_prediction_.squeeze(axis=-1)

    return self
|
|
1269
|
+
|
|
1270
|
+
def _daal_predict_regressor(self, X):
    """
    Run oneDAL decision-forest regression prediction on X.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Input samples; the column count must equal ``self.n_features_in_``.

    Returns
    -------
    ndarray
        The oneDAL prediction flattened to one dimension.

    Raises
    ------
    ValueError
        If X has a different number of features than the fitted model.
    """
    feature_count_matches = X.shape[1] == self.n_features_in_
    if not feature_count_matches:
        message = (
            f"X has {X.shape[1]} features, "
            f"but RandomForestRegressor is expecting "
            f"{self.n_features_in_} features as input"
        )
        raise ValueError(message)

    predictor = daal4py.decision_forest_regression_prediction(fptype=getFPType(X))
    result = predictor.compute(X, self.daal_model_)
    return result.prediction.ravel()
|