scikit-learn-intelex 2024.1.0__py310-none-win_amd64.whl → 2025.1.0__py310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/_daal4py.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/doc/third-party-programs.txt +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/model_builders.py +377 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +248 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +597 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition}/__init__.py +3 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +524 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1397 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn}/linear_model/__init__.py +29 -29
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +272 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +325 -0
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +4 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +405 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +236 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/_models_info.py +13 -22
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/test_patching.py +10 -42
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/utils/_launch_algorithms.py +4 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +503 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +139 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +74 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +734 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +75 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +693 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/__init__.py +83 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_config.py +54 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_device_offload.py +222 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp310-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +110 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +564 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +115 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_base.py +38 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_policy.py +59 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_spmd_policy.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/tests/test_policy.py +76 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +125 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +146 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +122 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +154 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +126 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +414 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +186 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +727 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +258 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +249 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +250 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +767 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +25 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/svm.py +556 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +351 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +176 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/test_common.py +57 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +102 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/__init__.py +49 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +81 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_dpep_helpers.py +56 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/validation.py +440 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__init__.py +10 -7
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_config.py +22 -16
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +126 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_utils.py +27 -4
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +230 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +19 -10
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +395 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +159 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +398 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +425 -0
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +25 -9
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +241 -60
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +250 -188
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +39 -21
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +482 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +425 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +341 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex}/linear_model/logistic_regression.py +194 -133
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +7 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +134 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +4 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +5 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +1 -1
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +236 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +53 -6
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +51 -155
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +46 -149
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +55 -100
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +16 -18
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +1 -3
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +138 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +233 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model}/__init__.py +19 -19
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +424 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +1 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition}/__init__.py +3 -2
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +11 -12
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -1
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +14 -18
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +339 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +172 -78
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +74 -70
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +170 -77
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +66 -66
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +390 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +123 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +379 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +276 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +108 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +385 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +321 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +44 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +371 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +82 -0
- scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/METADATA +231 -230
- scikit_learn_intelex-2025.1.0.dist-info/RECORD +257 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/WHEEL +1 -1
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -223
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -388
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -17
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -82
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -28
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -436
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -376
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -98
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -376
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -188
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -225
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -227
- scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
- scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/top_level.txt +0 -0
|
@@ -20,13 +20,16 @@ from abc import ABC
|
|
|
20
20
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
from scipy import sparse as sp
|
|
23
|
-
from sklearn.base import clone
|
|
24
|
-
from sklearn.ensemble import ExtraTreesClassifier as
|
|
25
|
-
from sklearn.ensemble import ExtraTreesRegressor as
|
|
26
|
-
from sklearn.ensemble import RandomForestClassifier as
|
|
27
|
-
from sklearn.ensemble import RandomForestRegressor as
|
|
23
|
+
from sklearn.base import BaseEstimator, clone
|
|
24
|
+
from sklearn.ensemble import ExtraTreesClassifier as _sklearn_ExtraTreesClassifier
|
|
25
|
+
from sklearn.ensemble import ExtraTreesRegressor as _sklearn_ExtraTreesRegressor
|
|
26
|
+
from sklearn.ensemble import RandomForestClassifier as _sklearn_RandomForestClassifier
|
|
27
|
+
from sklearn.ensemble import RandomForestRegressor as _sklearn_RandomForestRegressor
|
|
28
|
+
from sklearn.ensemble._forest import ForestClassifier as _sklearn_ForestClassifier
|
|
29
|
+
from sklearn.ensemble._forest import ForestRegressor as _sklearn_ForestRegressor
|
|
28
30
|
from sklearn.ensemble._forest import _get_n_samples_bootstrap
|
|
29
31
|
from sklearn.exceptions import DataConversionWarning
|
|
32
|
+
from sklearn.metrics import accuracy_score, r2_score
|
|
30
33
|
from sklearn.tree import (
|
|
31
34
|
DecisionTreeClassifier,
|
|
32
35
|
DecisionTreeRegressor,
|
|
@@ -36,71 +39,59 @@ from sklearn.tree import (
|
|
|
36
39
|
from sklearn.tree._tree import Tree
|
|
37
40
|
from sklearn.utils import check_random_state, deprecated
|
|
38
41
|
from sklearn.utils.validation import (
|
|
42
|
+
_check_sample_weight,
|
|
39
43
|
check_array,
|
|
40
|
-
check_consistent_length,
|
|
41
44
|
check_is_fitted,
|
|
42
45
|
check_X_y,
|
|
43
46
|
)
|
|
44
47
|
|
|
48
|
+
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
45
49
|
from daal4py.sklearn._utils import (
|
|
46
50
|
check_tree_nodes,
|
|
47
|
-
control_n_jobs,
|
|
48
51
|
daal_check_version,
|
|
49
|
-
run_with_n_jobs,
|
|
50
52
|
sklearn_check_version,
|
|
51
53
|
)
|
|
52
54
|
from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
|
|
53
55
|
from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
|
|
54
56
|
from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
|
|
55
57
|
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
|
|
56
|
-
|
|
57
|
-
# try catch needed for changes in structures observed in Scikit-learn around v0.22
|
|
58
|
-
try:
|
|
59
|
-
from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
|
|
60
|
-
from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
|
|
61
|
-
except ModuleNotFoundError:
|
|
62
|
-
from sklearn.ensemble.forest import ForestClassifier as sklearn_ForestClassifier
|
|
63
|
-
from sklearn.ensemble.forest import ForestRegressor as sklearn_ForestRegressor
|
|
64
|
-
|
|
65
58
|
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
66
59
|
from onedal.utils import _num_features, _num_samples
|
|
60
|
+
from sklearnex import get_hyperparameters
|
|
61
|
+
from sklearnex._utils import register_hyperparameters
|
|
67
62
|
|
|
68
|
-
from .._config import get_config
|
|
69
63
|
from .._device_offload import dispatch, wrap_output_data
|
|
70
64
|
from .._utils import PatchingConditionsChain
|
|
65
|
+
from ..utils._array_api import get_namespace
|
|
71
66
|
|
|
72
67
|
if sklearn_check_version("1.2"):
|
|
73
68
|
from sklearn.utils._param_validation import Interval
|
|
74
69
|
if sklearn_check_version("1.4"):
|
|
75
70
|
from daal4py.sklearn.utils import _assert_all_finite
|
|
76
71
|
|
|
72
|
+
if sklearn_check_version("1.6"):
|
|
73
|
+
from sklearn.utils.validation import validate_data
|
|
74
|
+
else:
|
|
75
|
+
validate_data = BaseEstimator._validate_data
|
|
76
|
+
|
|
77
77
|
|
|
78
78
|
class BaseForest(ABC):
|
|
79
79
|
_onedal_factory = None
|
|
80
80
|
|
|
81
|
-
@run_with_n_jobs
|
|
82
81
|
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
X, y = check_X_y(
|
|
94
|
-
X,
|
|
95
|
-
y,
|
|
96
|
-
accept_sparse=False,
|
|
97
|
-
dtype=[np.float64, np.float32],
|
|
98
|
-
multi_output=False,
|
|
99
|
-
force_all_finite=False,
|
|
100
|
-
)
|
|
82
|
+
X, y = validate_data(
|
|
83
|
+
self,
|
|
84
|
+
X,
|
|
85
|
+
y,
|
|
86
|
+
multi_output=True,
|
|
87
|
+
accept_sparse=False,
|
|
88
|
+
dtype=[np.float64, np.float32],
|
|
89
|
+
force_all_finite=False,
|
|
90
|
+
ensure_2d=True,
|
|
91
|
+
)
|
|
101
92
|
|
|
102
93
|
if sample_weight is not None:
|
|
103
|
-
sample_weight =
|
|
94
|
+
sample_weight = _check_sample_weight(sample_weight, X)
|
|
104
95
|
|
|
105
96
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
106
97
|
warnings.warn(
|
|
@@ -120,8 +111,6 @@ class BaseForest(ABC):
|
|
|
120
111
|
|
|
121
112
|
y, expanded_class_weight = self._validate_y_class_weight(y)
|
|
122
113
|
|
|
123
|
-
self.n_features_in_ = X.shape[1]
|
|
124
|
-
|
|
125
114
|
if expanded_class_weight is not None:
|
|
126
115
|
if sample_weight is not None:
|
|
127
116
|
sample_weight = sample_weight * expanded_class_weight
|
|
@@ -137,7 +126,9 @@ class BaseForest(ABC):
|
|
|
137
126
|
"min_samples_split": self.min_samples_split,
|
|
138
127
|
"min_samples_leaf": self.min_samples_leaf,
|
|
139
128
|
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
140
|
-
"max_features": self.
|
|
129
|
+
"max_features": self._to_absolute_max_features(
|
|
130
|
+
self.max_features, self.n_features_in_
|
|
131
|
+
),
|
|
141
132
|
"max_leaf_nodes": self.max_leaf_nodes,
|
|
142
133
|
"min_impurity_decrease": self.min_impurity_decrease,
|
|
143
134
|
"bootstrap": self.bootstrap,
|
|
@@ -175,15 +166,6 @@ class BaseForest(ABC):
|
|
|
175
166
|
|
|
176
167
|
return self
|
|
177
168
|
|
|
178
|
-
def _fit_proba(self, X, y, sample_weight=None, queue=None):
|
|
179
|
-
params = self.get_params()
|
|
180
|
-
self.__class__(**params)
|
|
181
|
-
|
|
182
|
-
# We use stock metaestimators below, so the only way
|
|
183
|
-
# to pass a queue is using config_context.
|
|
184
|
-
cfg = get_config()
|
|
185
|
-
cfg["target_offload"] = queue
|
|
186
|
-
|
|
187
169
|
def _save_attributes(self):
|
|
188
170
|
if self.oob_score:
|
|
189
171
|
self.oob_score_ = self._onedal_estimator.oob_score_
|
|
@@ -206,8 +188,45 @@ class BaseForest(ABC):
|
|
|
206
188
|
self._validate_estimator()
|
|
207
189
|
return self
|
|
208
190
|
|
|
209
|
-
|
|
210
|
-
|
|
191
|
+
def _to_absolute_max_features(self, max_features, n_features):
|
|
192
|
+
if max_features is None:
|
|
193
|
+
return n_features
|
|
194
|
+
if isinstance(max_features, str):
|
|
195
|
+
if max_features == "auto":
|
|
196
|
+
if not sklearn_check_version("1.3"):
|
|
197
|
+
if sklearn_check_version("1.1"):
|
|
198
|
+
warnings.warn(
|
|
199
|
+
"`max_features='auto'` has been deprecated in 1.1 "
|
|
200
|
+
"and will be removed in 1.3. To keep the past behaviour, "
|
|
201
|
+
"explicitly set `max_features=1.0` or remove this "
|
|
202
|
+
"parameter as it is also the default value for "
|
|
203
|
+
"RandomForestRegressors and ExtraTreesRegressors.",
|
|
204
|
+
FutureWarning,
|
|
205
|
+
)
|
|
206
|
+
return (
|
|
207
|
+
max(1, int(np.sqrt(n_features)))
|
|
208
|
+
if isinstance(self, ForestClassifier)
|
|
209
|
+
else n_features
|
|
210
|
+
)
|
|
211
|
+
if max_features == "sqrt":
|
|
212
|
+
return max(1, int(np.sqrt(n_features)))
|
|
213
|
+
if max_features == "log2":
|
|
214
|
+
return max(1, int(np.log2(n_features)))
|
|
215
|
+
allowed_string_values = (
|
|
216
|
+
'"sqrt" or "log2"'
|
|
217
|
+
if sklearn_check_version("1.3")
|
|
218
|
+
else '"auto", "sqrt" or "log2"'
|
|
219
|
+
)
|
|
220
|
+
raise ValueError(
|
|
221
|
+
"Invalid value for max_features. Allowed string "
|
|
222
|
+
f"values are {allowed_string_values}."
|
|
223
|
+
)
|
|
224
|
+
if isinstance(max_features, (numbers.Integral, np.integer)):
|
|
225
|
+
return max_features
|
|
226
|
+
if max_features > 0.0:
|
|
227
|
+
return max(1, int(max_features * n_features))
|
|
228
|
+
return 0
|
|
229
|
+
|
|
211
230
|
def _check_parameters(self):
|
|
212
231
|
if isinstance(self.min_samples_leaf, numbers.Integral):
|
|
213
232
|
if not 1 <= self.min_samples_leaf:
|
|
@@ -283,38 +302,6 @@ class BaseForest(ABC):
|
|
|
283
302
|
"min_bin_size must be integral number but was " "%r" % self.min_bin_size
|
|
284
303
|
)
|
|
285
304
|
|
|
286
|
-
def check_sample_weight(self, sample_weight, X, dtype=None):
|
|
287
|
-
n_samples = _num_samples(X)
|
|
288
|
-
|
|
289
|
-
if dtype is not None and dtype not in [np.float32, np.float64]:
|
|
290
|
-
dtype = np.float64
|
|
291
|
-
|
|
292
|
-
if sample_weight is None:
|
|
293
|
-
sample_weight = np.ones(n_samples, dtype=dtype)
|
|
294
|
-
elif isinstance(sample_weight, numbers.Number):
|
|
295
|
-
sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
|
|
296
|
-
else:
|
|
297
|
-
if dtype is None:
|
|
298
|
-
dtype = [np.float64, np.float32]
|
|
299
|
-
sample_weight = check_array(
|
|
300
|
-
sample_weight,
|
|
301
|
-
accept_sparse=False,
|
|
302
|
-
ensure_2d=False,
|
|
303
|
-
dtype=dtype,
|
|
304
|
-
order="C",
|
|
305
|
-
force_all_finite=False,
|
|
306
|
-
)
|
|
307
|
-
if sample_weight.ndim != 1:
|
|
308
|
-
raise ValueError("Sample weights must be 1D array or scalar")
|
|
309
|
-
|
|
310
|
-
if sample_weight.shape != (n_samples,):
|
|
311
|
-
raise ValueError(
|
|
312
|
-
"sample_weight.shape == {}, expected {}!".format(
|
|
313
|
-
sample_weight.shape, (n_samples,)
|
|
314
|
-
)
|
|
315
|
-
)
|
|
316
|
-
return sample_weight
|
|
317
|
-
|
|
318
305
|
@property
|
|
319
306
|
def estimators_(self):
|
|
320
307
|
if hasattr(self, "_cached_estimators_"):
|
|
@@ -415,7 +402,7 @@ class BaseForest(ABC):
|
|
|
415
402
|
self.estimator = estimator
|
|
416
403
|
|
|
417
404
|
|
|
418
|
-
class ForestClassifier(
|
|
405
|
+
class ForestClassifier(_sklearn_ForestClassifier, BaseForest):
|
|
419
406
|
# Surprisingly, even though scikit-learn warns against using
|
|
420
407
|
# their ForestClassifier directly, it actually has a more stable
|
|
421
408
|
# API than the user-facing objects (over time). If they change it
|
|
@@ -455,14 +442,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
455
442
|
|
|
456
443
|
# The estimator is checked against the class attribute for conformance.
|
|
457
444
|
# This should only trigger if the user uses this class directly.
|
|
458
|
-
if (
|
|
459
|
-
self.
|
|
460
|
-
and self._onedal_factory != onedal_RandomForestClassifier
|
|
445
|
+
if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
|
|
446
|
+
self._onedal_factory, onedal_RandomForestClassifier
|
|
461
447
|
):
|
|
462
448
|
self._onedal_factory = onedal_RandomForestClassifier
|
|
463
|
-
elif (
|
|
464
|
-
self.
|
|
465
|
-
and self._onedal_factory != onedal_ExtraTreesClassifier
|
|
449
|
+
elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
|
|
450
|
+
self._onedal_factory, onedal_ExtraTreesClassifier
|
|
466
451
|
):
|
|
467
452
|
self._onedal_factory = onedal_ExtraTreesClassifier
|
|
468
453
|
|
|
@@ -481,7 +466,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
481
466
|
"fit",
|
|
482
467
|
{
|
|
483
468
|
"onedal": self.__class__._onedal_fit,
|
|
484
|
-
"sklearn":
|
|
469
|
+
"sklearn": _sklearn_ForestClassifier.fit,
|
|
485
470
|
},
|
|
486
471
|
X,
|
|
487
472
|
y,
|
|
@@ -554,18 +539,14 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
554
539
|
)
|
|
555
540
|
|
|
556
541
|
if patching_status.get_status():
|
|
557
|
-
|
|
558
|
-
X,
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
)
|
|
566
|
-
else:
|
|
567
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
568
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
542
|
+
X, y = check_X_y(
|
|
543
|
+
X,
|
|
544
|
+
y,
|
|
545
|
+
multi_output=True,
|
|
546
|
+
accept_sparse=True,
|
|
547
|
+
dtype=[np.float64, np.float32],
|
|
548
|
+
force_all_finite=False,
|
|
549
|
+
)
|
|
569
550
|
|
|
570
551
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
571
552
|
warnings.warn(
|
|
@@ -619,12 +600,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
619
600
|
|
|
620
601
|
@wrap_output_data
|
|
621
602
|
def predict(self, X):
|
|
603
|
+
check_is_fitted(self)
|
|
622
604
|
return dispatch(
|
|
623
605
|
self,
|
|
624
606
|
"predict",
|
|
625
607
|
{
|
|
626
608
|
"onedal": self.__class__._onedal_predict,
|
|
627
|
-
"sklearn":
|
|
609
|
+
"sklearn": _sklearn_ForestClassifier.predict,
|
|
628
610
|
},
|
|
629
611
|
X,
|
|
630
612
|
)
|
|
@@ -634,34 +616,50 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
634
616
|
# TODO:
|
|
635
617
|
# _check_proba()
|
|
636
618
|
# self._check_proba()
|
|
637
|
-
|
|
638
|
-
self._check_feature_names(X, reset=False)
|
|
639
|
-
if hasattr(self, "n_features_in_"):
|
|
640
|
-
try:
|
|
641
|
-
num_features = _num_features(X)
|
|
642
|
-
except TypeError:
|
|
643
|
-
num_features = _num_samples(X)
|
|
644
|
-
if num_features != self.n_features_in_:
|
|
645
|
-
raise ValueError(
|
|
646
|
-
(
|
|
647
|
-
f"X has {num_features} features, "
|
|
648
|
-
f"but {self.__class__.__name__} is expecting "
|
|
649
|
-
f"{self.n_features_in_} features as input"
|
|
650
|
-
)
|
|
651
|
-
)
|
|
619
|
+
check_is_fitted(self)
|
|
652
620
|
return dispatch(
|
|
653
621
|
self,
|
|
654
622
|
"predict_proba",
|
|
655
623
|
{
|
|
656
624
|
"onedal": self.__class__._onedal_predict_proba,
|
|
657
|
-
"sklearn":
|
|
625
|
+
"sklearn": _sklearn_ForestClassifier.predict_proba,
|
|
658
626
|
},
|
|
659
627
|
X,
|
|
660
628
|
)
|
|
661
629
|
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
630
|
+
def predict_log_proba(self, X):
|
|
631
|
+
xp, _ = get_namespace(X)
|
|
632
|
+
proba = self.predict_proba(X)
|
|
633
|
+
|
|
634
|
+
if self.n_outputs_ == 1:
|
|
635
|
+
return xp.log(proba)
|
|
636
|
+
|
|
637
|
+
else:
|
|
638
|
+
for k in range(self.n_outputs_):
|
|
639
|
+
proba[k] = xp.log(proba[k])
|
|
640
|
+
|
|
641
|
+
return proba
|
|
642
|
+
|
|
643
|
+
@wrap_output_data
|
|
644
|
+
def score(self, X, y, sample_weight=None):
|
|
645
|
+
check_is_fitted(self)
|
|
646
|
+
return dispatch(
|
|
647
|
+
self,
|
|
648
|
+
"score",
|
|
649
|
+
{
|
|
650
|
+
"onedal": self.__class__._onedal_score,
|
|
651
|
+
"sklearn": _sklearn_ForestClassifier.score,
|
|
652
|
+
},
|
|
653
|
+
X,
|
|
654
|
+
y,
|
|
655
|
+
sample_weight=sample_weight,
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
fit.__doc__ = _sklearn_ForestClassifier.fit.__doc__
|
|
659
|
+
predict.__doc__ = _sklearn_ForestClassifier.predict.__doc__
|
|
660
|
+
predict_proba.__doc__ = _sklearn_ForestClassifier.predict_proba.__doc__
|
|
661
|
+
predict_log_proba.__doc__ = _sklearn_ForestClassifier.predict_log_proba.__doc__
|
|
662
|
+
score.__doc__ = _sklearn_ForestClassifier.score.__doc__
|
|
665
663
|
|
|
666
664
|
def _onedal_cpu_supported(self, method_name, *data):
|
|
667
665
|
class_name = self.__class__.__name__
|
|
@@ -688,7 +686,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
688
686
|
]
|
|
689
687
|
)
|
|
690
688
|
|
|
691
|
-
elif method_name in ["predict", "predict_proba"]:
|
|
689
|
+
elif method_name in ["predict", "predict_proba", "score"]:
|
|
692
690
|
X = data[0]
|
|
693
691
|
|
|
694
692
|
patching_status.and_conditions(
|
|
@@ -749,11 +747,15 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
749
747
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
750
748
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
751
749
|
),
|
|
752
|
-
(
|
|
750
|
+
(
|
|
751
|
+
not self.oob_score,
|
|
752
|
+
"oob_scores using r2 or accuracy not implemented.",
|
|
753
|
+
),
|
|
754
|
+
(sample_weight is None, "sample_weight is not supported."),
|
|
753
755
|
]
|
|
754
756
|
)
|
|
755
757
|
|
|
756
|
-
elif method_name in ["predict", "predict_proba"]:
|
|
758
|
+
elif method_name in ["predict", "predict_proba", "score"]:
|
|
757
759
|
X = data[0]
|
|
758
760
|
|
|
759
761
|
patching_status.and_conditions(
|
|
@@ -787,34 +789,69 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
|
787
789
|
|
|
788
790
|
return patching_status
|
|
789
791
|
|
|
790
|
-
@run_with_n_jobs
|
|
791
792
|
def _onedal_predict(self, X, queue=None):
|
|
792
|
-
X = check_array(
|
|
793
|
-
X,
|
|
794
|
-
dtype=[np.float64, np.float32],
|
|
795
|
-
force_all_finite=False,
|
|
796
|
-
) # Warning, order of dtype matters
|
|
797
|
-
check_is_fitted(self, "_onedal_estimator")
|
|
798
793
|
|
|
799
794
|
if sklearn_check_version("1.0"):
|
|
800
|
-
|
|
795
|
+
X = validate_data(
|
|
796
|
+
self,
|
|
797
|
+
X,
|
|
798
|
+
dtype=[np.float64, np.float32],
|
|
799
|
+
force_all_finite=False,
|
|
800
|
+
reset=False,
|
|
801
|
+
ensure_2d=True,
|
|
802
|
+
)
|
|
803
|
+
else:
|
|
804
|
+
X = check_array(
|
|
805
|
+
X,
|
|
806
|
+
dtype=[np.float64, np.float32],
|
|
807
|
+
force_all_finite=False,
|
|
808
|
+
) # Warning, order of dtype matters
|
|
809
|
+
if hasattr(self, "n_features_in_"):
|
|
810
|
+
try:
|
|
811
|
+
num_features = _num_features(X)
|
|
812
|
+
except TypeError:
|
|
813
|
+
num_features = _num_samples(X)
|
|
814
|
+
if num_features != self.n_features_in_:
|
|
815
|
+
raise ValueError(
|
|
816
|
+
(
|
|
817
|
+
f"X has {num_features} features, "
|
|
818
|
+
f"but {self.__class__.__name__} is expecting "
|
|
819
|
+
f"{self.n_features_in_} features as input"
|
|
820
|
+
)
|
|
821
|
+
)
|
|
822
|
+
self._check_n_features(X, reset=False)
|
|
801
823
|
|
|
802
824
|
res = self._onedal_estimator.predict(X, queue=queue)
|
|
803
825
|
return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
|
|
804
826
|
|
|
805
|
-
@run_with_n_jobs
|
|
806
827
|
def _onedal_predict_proba(self, X, queue=None):
|
|
807
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
808
|
-
check_is_fitted(self, "_onedal_estimator")
|
|
809
828
|
|
|
810
|
-
if sklearn_check_version("0.23"):
|
|
811
|
-
self._check_n_features(X, reset=False)
|
|
812
829
|
if sklearn_check_version("1.0"):
|
|
813
|
-
|
|
830
|
+
X = validate_data(
|
|
831
|
+
self,
|
|
832
|
+
X,
|
|
833
|
+
dtype=[np.float64, np.float32],
|
|
834
|
+
force_all_finite=False,
|
|
835
|
+
reset=False,
|
|
836
|
+
ensure_2d=True,
|
|
837
|
+
)
|
|
838
|
+
else:
|
|
839
|
+
X = check_array(
|
|
840
|
+
X,
|
|
841
|
+
dtype=[np.float64, np.float32],
|
|
842
|
+
force_all_finite=False,
|
|
843
|
+
) # Warning, order of dtype matters
|
|
844
|
+
self._check_n_features(X, reset=False)
|
|
845
|
+
|
|
814
846
|
return self._onedal_estimator.predict_proba(X, queue=queue)
|
|
815
847
|
|
|
848
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
849
|
+
return accuracy_score(
|
|
850
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
851
|
+
)
|
|
852
|
+
|
|
816
853
|
|
|
817
|
-
class ForestRegressor(
|
|
854
|
+
class ForestRegressor(_sklearn_ForestRegressor, BaseForest):
|
|
818
855
|
_err = "out_of_bag_error_r2|out_of_bag_error_prediction"
|
|
819
856
|
_get_tree_state = staticmethod(get_tree_state_reg)
|
|
820
857
|
|
|
@@ -847,14 +884,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
847
884
|
|
|
848
885
|
# The splitter is checked against the class attribute for conformance
|
|
849
886
|
# This should only trigger if the user uses this class directly.
|
|
850
|
-
if (
|
|
851
|
-
self.
|
|
852
|
-
and self._onedal_factory != onedal_RandomForestRegressor
|
|
887
|
+
if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
|
|
888
|
+
self._onedal_factory, onedal_RandomForestRegressor
|
|
853
889
|
):
|
|
854
890
|
self._onedal_factory = onedal_RandomForestRegressor
|
|
855
|
-
elif (
|
|
856
|
-
self.
|
|
857
|
-
and self._onedal_factory != onedal_ExtraTreesRegressor
|
|
891
|
+
elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
|
|
892
|
+
self._onedal_factory, onedal_ExtraTreesRegressor
|
|
858
893
|
):
|
|
859
894
|
self._onedal_factory = onedal_ExtraTreesRegressor
|
|
860
895
|
|
|
@@ -924,18 +959,14 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
924
959
|
)
|
|
925
960
|
|
|
926
961
|
if patching_status.get_status():
|
|
927
|
-
|
|
928
|
-
X,
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
)
|
|
936
|
-
else:
|
|
937
|
-
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
938
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
962
|
+
X, y = check_X_y(
|
|
963
|
+
X,
|
|
964
|
+
y,
|
|
965
|
+
multi_output=True,
|
|
966
|
+
accept_sparse=True,
|
|
967
|
+
dtype=[np.float64, np.float32],
|
|
968
|
+
force_all_finite=False,
|
|
969
|
+
)
|
|
939
970
|
|
|
940
971
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
941
972
|
warnings.warn(
|
|
@@ -1010,7 +1041,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1010
1041
|
]
|
|
1011
1042
|
)
|
|
1012
1043
|
|
|
1013
|
-
elif method_name
|
|
1044
|
+
elif method_name in ["predict", "score"]:
|
|
1014
1045
|
X = data[0]
|
|
1015
1046
|
|
|
1016
1047
|
patching_status.and_conditions(
|
|
@@ -1060,11 +1091,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1060
1091
|
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
1061
1092
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
1062
1093
|
),
|
|
1063
|
-
(
|
|
1094
|
+
(not self.oob_score, "oob_score value is not sklearn conformant."),
|
|
1095
|
+
(sample_weight is None, "sample_weight is not supported."),
|
|
1064
1096
|
]
|
|
1065
1097
|
)
|
|
1066
1098
|
|
|
1067
|
-
elif method_name
|
|
1099
|
+
elif method_name in ["predict", "score"]:
|
|
1068
1100
|
X = data[0]
|
|
1069
1101
|
|
|
1070
1102
|
patching_status.and_conditions(
|
|
@@ -1096,25 +1128,37 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1096
1128
|
|
|
1097
1129
|
return patching_status
|
|
1098
1130
|
|
|
1099
|
-
@run_with_n_jobs
|
|
1100
1131
|
def _onedal_predict(self, X, queue=None):
|
|
1101
|
-
X = check_array(
|
|
1102
|
-
X, dtype=[np.float64, np.float32], force_all_finite=False
|
|
1103
|
-
) # Warning, order of dtype matters
|
|
1104
1132
|
check_is_fitted(self, "_onedal_estimator")
|
|
1105
1133
|
|
|
1106
1134
|
if sklearn_check_version("1.0"):
|
|
1107
|
-
|
|
1135
|
+
X = validate_data(
|
|
1136
|
+
self,
|
|
1137
|
+
X,
|
|
1138
|
+
dtype=[np.float64, np.float32],
|
|
1139
|
+
force_all_finite=False,
|
|
1140
|
+
reset=False,
|
|
1141
|
+
ensure_2d=True,
|
|
1142
|
+
) # Warning, order of dtype matters
|
|
1143
|
+
else:
|
|
1144
|
+
X = check_array(
|
|
1145
|
+
X, dtype=[np.float64, np.float32], force_all_finite=False
|
|
1146
|
+
) # Warning, order of dtype matters
|
|
1108
1147
|
|
|
1109
1148
|
return self._onedal_estimator.predict(X, queue=queue)
|
|
1110
1149
|
|
|
1150
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
|
|
1151
|
+
return r2_score(
|
|
1152
|
+
y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
|
|
1153
|
+
)
|
|
1154
|
+
|
|
1111
1155
|
def fit(self, X, y, sample_weight=None):
|
|
1112
1156
|
dispatch(
|
|
1113
1157
|
self,
|
|
1114
1158
|
"fit",
|
|
1115
1159
|
{
|
|
1116
1160
|
"onedal": self.__class__._onedal_fit,
|
|
1117
|
-
"sklearn":
|
|
1161
|
+
"sklearn": _sklearn_ForestRegressor.fit,
|
|
1118
1162
|
},
|
|
1119
1163
|
X,
|
|
1120
1164
|
y,
|
|
@@ -1124,28 +1168,46 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
|
1124
1168
|
|
|
1125
1169
|
@wrap_output_data
|
|
1126
1170
|
def predict(self, X):
|
|
1171
|
+
check_is_fitted(self)
|
|
1127
1172
|
return dispatch(
|
|
1128
1173
|
self,
|
|
1129
1174
|
"predict",
|
|
1130
1175
|
{
|
|
1131
1176
|
"onedal": self.__class__._onedal_predict,
|
|
1132
|
-
"sklearn":
|
|
1177
|
+
"sklearn": _sklearn_ForestRegressor.predict,
|
|
1133
1178
|
},
|
|
1134
1179
|
X,
|
|
1135
1180
|
)
|
|
1136
1181
|
|
|
1137
|
-
|
|
1138
|
-
|
|
1182
|
+
@wrap_output_data
|
|
1183
|
+
def score(self, X, y, sample_weight=None):
|
|
1184
|
+
check_is_fitted(self)
|
|
1185
|
+
return dispatch(
|
|
1186
|
+
self,
|
|
1187
|
+
"score",
|
|
1188
|
+
{
|
|
1189
|
+
"onedal": self.__class__._onedal_score,
|
|
1190
|
+
"sklearn": _sklearn_ForestRegressor.score,
|
|
1191
|
+
},
|
|
1192
|
+
X,
|
|
1193
|
+
y,
|
|
1194
|
+
sample_weight=sample_weight,
|
|
1195
|
+
)
|
|
1196
|
+
|
|
1197
|
+
fit.__doc__ = _sklearn_ForestRegressor.fit.__doc__
|
|
1198
|
+
predict.__doc__ = _sklearn_ForestRegressor.predict.__doc__
|
|
1199
|
+
score.__doc__ = _sklearn_ForestRegressor.score.__doc__
|
|
1139
1200
|
|
|
1140
1201
|
|
|
1141
|
-
@
|
|
1202
|
+
@register_hyperparameters({"infer": get_hyperparameters("decision_forest", "infer")})
|
|
1203
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1142
1204
|
class RandomForestClassifier(ForestClassifier):
|
|
1143
|
-
__doc__ =
|
|
1205
|
+
__doc__ = _sklearn_RandomForestClassifier.__doc__
|
|
1144
1206
|
_onedal_factory = onedal_RandomForestClassifier
|
|
1145
1207
|
|
|
1146
1208
|
if sklearn_check_version("1.2"):
|
|
1147
1209
|
_parameter_constraints: dict = {
|
|
1148
|
-
**
|
|
1210
|
+
**_sklearn_RandomForestClassifier._parameter_constraints,
|
|
1149
1211
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1150
1212
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1151
1213
|
}
|
|
@@ -1348,14 +1410,14 @@ class RandomForestClassifier(ForestClassifier):
|
|
|
1348
1410
|
self.min_bin_size = min_bin_size
|
|
1349
1411
|
|
|
1350
1412
|
|
|
1351
|
-
@control_n_jobs
|
|
1413
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
|
|
1352
1414
|
class RandomForestRegressor(ForestRegressor):
|
|
1353
|
-
__doc__ =
|
|
1415
|
+
__doc__ = _sklearn_RandomForestRegressor.__doc__
|
|
1354
1416
|
_onedal_factory = onedal_RandomForestRegressor
|
|
1355
1417
|
|
|
1356
1418
|
if sklearn_check_version("1.2"):
|
|
1357
1419
|
_parameter_constraints: dict = {
|
|
1358
|
-
**
|
|
1420
|
+
**_sklearn_RandomForestRegressor._parameter_constraints,
|
|
1359
1421
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1360
1422
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1361
1423
|
}
|
|
@@ -1549,14 +1611,14 @@ class RandomForestRegressor(ForestRegressor):
|
|
|
1549
1611
|
self.min_bin_size = min_bin_size
|
|
1550
1612
|
|
|
1551
1613
|
|
|
1552
|
-
@control_n_jobs
|
|
1614
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
|
|
1553
1615
|
class ExtraTreesClassifier(ForestClassifier):
|
|
1554
|
-
__doc__ =
|
|
1616
|
+
__doc__ = _sklearn_ExtraTreesClassifier.__doc__
|
|
1555
1617
|
_onedal_factory = onedal_ExtraTreesClassifier
|
|
1556
1618
|
|
|
1557
1619
|
if sklearn_check_version("1.2"):
|
|
1558
1620
|
_parameter_constraints: dict = {
|
|
1559
|
-
**
|
|
1621
|
+
**_sklearn_ExtraTreesClassifier._parameter_constraints,
|
|
1560
1622
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1561
1623
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1562
1624
|
}
|
|
@@ -1759,14 +1821,14 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1759
1821
|
self.min_bin_size = min_bin_size
|
|
1760
1822
|
|
|
1761
1823
|
|
|
1762
|
-
@control_n_jobs
|
|
1824
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
|
|
1763
1825
|
class ExtraTreesRegressor(ForestRegressor):
|
|
1764
|
-
__doc__ =
|
|
1826
|
+
__doc__ = _sklearn_ExtraTreesRegressor.__doc__
|
|
1765
1827
|
_onedal_factory = onedal_ExtraTreesRegressor
|
|
1766
1828
|
|
|
1767
1829
|
if sklearn_check_version("1.2"):
|
|
1768
1830
|
_parameter_constraints: dict = {
|
|
1769
|
-
**
|
|
1831
|
+
**_sklearn_ExtraTreesRegressor._parameter_constraints,
|
|
1770
1832
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1771
1833
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1772
1834
|
}
|
|
@@ -1961,7 +2023,7 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1961
2023
|
|
|
1962
2024
|
|
|
1963
2025
|
# Allow for isinstance calls without inheritance changes using ABCMeta
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
|
|
1967
|
-
|
|
2026
|
+
_sklearn_RandomForestClassifier.register(RandomForestClassifier)
|
|
2027
|
+
_sklearn_RandomForestRegressor.register(RandomForestRegressor)
|
|
2028
|
+
_sklearn_ExtraTreesClassifier.register(ExtraTreesClassifier)
|
|
2029
|
+
_sklearn_ExtraTreesRegressor.register(ExtraTreesRegressor)
|