scikit-learn-intelex 2025.10.0__py313-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp313-win_amd64.pyd +0 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__init__.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__main__.py +58 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +338 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +199 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +38 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dispatcher.py +572 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +629 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1799 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/__main__.py +72 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +101 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +39 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/split.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +34 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +27 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/common.py +313 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +189 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +167 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +170 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +82 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/__init__.py +17 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +112 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +25 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +26 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +24 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +23 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/__init__.py +29 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +278 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +158 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svc.py +306 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svr.py +155 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +124 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +269 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +418 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
- scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
- scikit_learn_intelex-2025.10.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
- scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
- scikit_learn_intelex-2025.10.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1199 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import warnings
|
|
19
|
+
from collections import deque
|
|
20
|
+
from copy import deepcopy
|
|
21
|
+
from tempfile import NamedTemporaryFile
|
|
22
|
+
from typing import Any, Deque, Dict, List, Optional, Tuple, Union
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
from .. import gbt_clf_model_builder, gbt_reg_model_builder
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class CatBoostNode:
    """Mutable record holding the data attached to one CatBoost tree node.

    The constructor stores its five arguments unchanged as same-named
    attributes; every field defaults to ``None`` so a node can be created
    empty and filled in later.
    """

    def __init__(
        self,
        split: Optional[Dict] = None,
        value: Optional[List[float]] = None,
        right: Optional[int] = None,
        left: Optional[float] = None,
        cover: Optional[float] = None,
    ) -> None:
        # NOTE(review): `left` is annotated Optional[float] while `right` is
        # Optional[int]; the two presumably should share one type - confirm
        # against the code that populates them.
        # Attributes are bound in the order split, value, right, left, cover.
        self.split, self.value = split, value
        self.right, self.left, self.cover = right, left, cover
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class CatBoostModelData:
    """Thin read-only facade over a CatBoost model dump.

    The dump is kept as given (a nested mapping) and each property simply
    looks up one piece of it, so callers do not have to repeat the key paths.
    """

    def __init__(self, data):
        # Stored under a name-mangled attribute; nothing is copied or validated.
        self.__data = data

    @property
    def n_features(self):
        """Length of the ``features_info -> float_features`` entry."""
        features_info = self.__data["features_info"]
        return len(features_info["float_features"])

    @property
    def grow_policy(self):
        """Value stored at ``model_info.params.tree_learner_options.grow_policy``."""
        params = self.__data["model_info"]["params"]
        learner_options = params["tree_learner_options"]
        return learner_options["grow_policy"]

    @property
    def oblivious_trees(self):
        """The dump's ``oblivious_trees`` entry, returned as stored."""
        dump = self.__data
        return dump["oblivious_trees"]

    @property
    def trees(self):
        """The dump's ``trees`` entry, returned as stored."""
        dump = self.__data
        return dump["trees"]

    @property
    def n_classes(self):
        """Number of classes, returns -1 if it's not a classification model"""
        model_info = self.__data["model_info"]
        # Guard clause: without "class_params" there is nothing to count.
        if "class_params" not in model_info:
            return -1
        class_params = model_info["class_params"]
        return len(class_params["class_to_label"])

    @property
    def is_classification(self):
        """True when ``model_info`` contains a ``class_params`` key."""
        model_info = self.__data["model_info"]
        return "class_params" in model_info

    @property
    def has_categorical_features(self):
        """True when ``features_info`` contains a ``categorical_features`` key."""
        features_info = self.__data["features_info"]
        return "categorical_features" in features_info

    @property
    def is_symmetric_tree(self):
        """True when the grow policy equals ``"SymmetricTree"``."""
        policy = self.grow_policy
        return policy == "SymmetricTree"

    @property
    def float_features(self):
        """The ``features_info -> float_features`` entry, returned as stored."""
        features_info = self.__data["features_info"]
        return features_info["float_features"]

    @property
    def n_iterations(self):
        """Length of ``oblivious_trees`` for symmetric trees, else of ``trees``."""
        tree_list = self.oblivious_trees if self.is_symmetric_tree else self.trees
        return len(tree_list)

    @property
    def scale(self):
        """First element of the dump's ``scale_and_bias`` entry."""
        scale_and_bias = self.__data["scale_and_bias"]
        return scale_and_bias[0]

    @property
    def default_left(self):
        """1 if the float-feature ``nan_mode`` is "min" (case-insensitive), else 0."""
        params = self.__data["model_info"]["params"]
        binarization = params["data_processing_options"]["float_features_binarization"]
        treats_nan_as_min = binarization["nan_mode"].lower() == "min"
        return int(treats_nan_as_min)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class Node:
    """Helper class holding Tree Node information

    A node stores its cover, whether it is a leaf, the default direction for
    missing values, the split feature, its value (leaf response or split
    threshold), the total number of descendants (``n_children``) and its two
    children. ``parent_id`` and ``position`` default to -1 and are meant to be
    assigned later by the code that feeds the nodes to a model builder.
    """

    def __init__(
        self,
        cover: float,
        is_leaf: bool,
        default_left: bool,
        feature: int,
        value: float,
        n_children: int = 0,
        left_child: "Optional[Node]" = None,
        right_child: "Optional[Node]" = None,
        parent_id: Optional[int] = -1,
        position: Optional[int] = -1,
    ) -> None:
        self.cover = cover
        self.is_leaf = is_leaf
        self.default_left = default_left
        # Kept private: read it through the validating ``feature`` property
        self.__feature = feature
        self.value = value
        self.n_children = n_children
        self.left_child = left_child
        self.right_child = right_child
        self.parent_id = parent_id
        self.position = position

    @staticmethod
    def from_xgb_dict(
        input_dict: Dict[str, Any], feature_names_to_indices: dict[str, int]
    ) -> "Node":
        """Recursively build a node (and its subtree) from a parsed XGBoost tree dump

        ``feature_names_to_indices`` translates the feature name found under
        the "split" key into the index stored on the node.
        """
        if "children" in input_dict:
            left_child = Node.from_xgb_dict(
                input_dict["children"][0], feature_names_to_indices
            )
            right_child = Node.from_xgb_dict(
                input_dict["children"][1], feature_names_to_indices
            )
            # two direct children plus all of their descendants
            n_children = 2 + left_child.n_children + right_child.n_children
        else:
            left_child = None
            right_child = None
            n_children = 0

        is_leaf = "leaf" in input_dict
        # missing values go left when the "yes" branch is also the "missing" branch
        default_left = "yes" in input_dict and input_dict["yes"] == input_dict["missing"]
        feature = input_dict.get("split")
        if feature:
            feature = feature_names_to_indices[feature]
        return Node(
            cover=input_dict["cover"],
            is_leaf=is_leaf,
            default_left=default_left,
            feature=feature,
            value=input_dict["leaf"] if is_leaf else input_dict["split_condition"],
            n_children=n_children,
            left_child=left_child,
            right_child=right_child,
        )

    @staticmethod
    def from_lightgbm_dict(input_dict: Dict[str, Any]) -> "Node":
        """Recursively build a node (and its subtree) from a LightGBM dump entry

        Accepts either a tree entry holding a "tree_structure" key or a bare
        node dictionary.
        """
        if "tree_structure" in input_dict:
            tree = input_dict["tree_structure"]
        else:
            tree = input_dict

        n_children = 0
        if "left_child" in tree:
            left_child = Node.from_lightgbm_dict(tree["left_child"])
            n_children += 1 + left_child.n_children
        else:
            left_child = None
        if "right_child" in tree:
            right_child = Node.from_lightgbm_dict(tree["right_child"])
            n_children += 1 + right_child.n_children
        else:
            right_child = None

        is_leaf = "leaf_value" in tree
        # get cover and value for leaf nodes or internal nodes
        cover = tree.get("leaf_count", 0) or tree.get("internal_count", 0)
        value = tree.get("leaf_value", 0) or tree.get("threshold", 0)
        return Node(
            cover=cover,
            is_leaf=is_leaf,
            default_left=tree.get("default_left", 0),
            feature=tree.get("split_feature"),
            value=value,
            n_children=n_children,
            left_child=left_child,
            right_child=right_child,
        )

    @staticmethod
    def from_treelite_dict(dict_all_nodes: list[dict[str, Any]], node_id: int) -> "Node":
        """Recursively build the node at ``node_id`` (and its subtree) from treelite nodes

        Child references in a node are used as indices into ``dict_all_nodes``.

        Raises
        ------
        TypeError
            If a split uses a comparison operator other than
            "<", "<=", ">" or ">=".
        """
        this_node = dict_all_nodes[node_id]
        is_leaf = "leaf_value" in this_node
        default_left = this_node.get("default_left", False)

        n_children = 0
        if "left_child" in this_node:
            left_child = Node.from_treelite_dict(dict_all_nodes, this_node["left_child"])
            n_children += 1 + left_child.n_children
        else:
            left_child = None
        if "right_child" in this_node:
            right_child = Node.from_treelite_dict(
                dict_all_nodes, this_node["right_child"]
            )
            n_children += 1 + right_child.n_children
        else:
            right_child = None

        value = this_node["leaf_value"] if is_leaf else this_node["threshold"]
        if not is_leaf:
            # Every supported operator is rewritten into a "<"-style split:
            # inclusive bounds shift the threshold by one representable
            # float, and ">"-type operators additionally swap the branches.
            comp = this_node["comparison_op"]
            if comp == "<=":
                value = float(np.nextafter(value, np.inf))
            elif comp in [">", ">="]:
                left_child, right_child = right_child, left_child
                default_left = not default_left
                if comp == ">":
                    value = float(np.nextafter(value, -np.inf))
            elif comp != "<":
                raise TypeError(
                    f"Model to convert contains unsupported split type: {comp}."
                )

        return Node(
            cover=this_node.get("sum_hess", 0.0),
            is_leaf=is_leaf,
            default_left=default_left,
            feature=this_node.get("split_feature_id"),
            value=value,
            n_children=n_children,
            left_child=left_child,
            right_child=right_child,
        )

    def get_value_closest_float_downward(self) -> np.float32:
        """Get the closest exact fp value smaller than self.value

        The value is first cast to single precision, so the result is the
        next ``np.float32`` toward negative infinity.
        """
        return np.nextafter(np.single(self.value), np.single(-np.inf))

    def get_children(self) -> "Optional[Tuple[Node, Node]]":
        """Return ``(left_child, right_child)``, or ``None`` when a child is missing

        A node lacking a child is expected to be a leaf; this is checked with
        an assertion.
        """
        if not self.left_child or not self.right_child:
            assert self.is_leaf
            return None
        return (self.left_child, self.right_child)

    @property
    def feature(self) -> int:
        """Split feature index as an ``int``

        Numeric strings are converted; anything else (including ``None``)
        raises ``AttributeError``.
        """
        if isinstance(self.__feature, int):
            return self.__feature
        if isinstance(self.__feature, str) and self.__feature.isnumeric():
            return int(self.__feature)
        raise AttributeError(
            f"Feature names must be integers (got ({type(self.__feature)}){self.__feature})"
        )
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class TreeView:
    """Pairs a tree id with the root node of that tree"""

    def __init__(self, tree_id: int, root_node: Node) -> None:
        self.tree_id = tree_id
        self.root_node = root_node

    @property
    def is_leaf(self) -> bool:
        """True when the root node itself is a leaf (single-node tree)"""
        return self.root_node.is_leaf

    @property
    def n_nodes(self) -> int:
        """Total number of nodes: the root plus all of its descendants"""
        return 1 + self.root_node.n_children

    @property
    def cover(self) -> float:
        """Cover of the root node; only available for leaf-only trees"""
        if self.is_leaf:
            return self.root_node.cover
        raise AttributeError("Tree is not a leaf-only tree")

    @property
    def value(self) -> float:
        """Value of the root node; only available for leaf-only trees"""
        if not self.is_leaf:
            raise AttributeError("Tree is not a leaf-only tree")
        leaf_value = self.root_node.value
        if leaf_value is None:
            raise AttributeError("Tree is leaf-only but leaf node has no value")
        return leaf_value
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class TreeList(list):
    """Helper class that is able to extract all information required by the
    model builders from various objects

    Instances are lists of ``TreeView`` objects and are meant to be created
    through the ``from_*`` static methods; item assignment is rejected.
    """

    @staticmethod
    def from_xgb_booster(
        booster, max_trees: int, feature_names_to_indices: dict[str, int]
    ) -> "TreeList":
        """
        Load a TreeList from an xgb.Booster object

        At most ``max_trees`` trees are loaded when ``max_trees`` is positive;
        otherwise all trees of the dump are loaded.

        Note: We cannot type-hint the xgb.Booster without loading xgb as dependency in pyx code,
        therefore no type hint is added.
        """

        # Note: in XGBoost, it's possible to use 'int' type for features that contain
        # non-integer floating points. In such case, the training procedure and JSON
        # export from XGBoost will not treat them any differently from 'q'-type
        # (numeric) features, but the per-tree JSON text dumps used here will output
        # a split threshold rounded to the nearest integer for those 'int' features,
        # even if the booster internally has thresholds with decimal points and outputs
        # them as such in the full-model JSON dumps. Hence the need for this override
        # mechanism. If this behavior changes in XGBoost, then this conversion and
        # override can be removed.
        orig_feature_types = None
        try:
            if hasattr(booster, "feature_types"):
                feature_types = booster.feature_types
                orig_feature_types = deepcopy(feature_types)
                if feature_types:
                    # Temporarily present 'int' features as 'float' for the dump
                    booster.feature_types = [
                        "float" if ftype == "int" else ftype for ftype in feature_types
                    ]

            dump = booster.get_dump(dump_format="json", with_stats=True)
        finally:
            # Always put the caller's feature types back, even if the dump failed
            if orig_feature_types:
                booster.feature_types = orig_feature_types

        tl = TreeList()
        for tree_id, raw_tree in enumerate(dump):
            if max_trees > 0 and tree_id == max_trees:
                break
            raw_tree_parsed = json.loads(raw_tree)
            root_node = Node.from_xgb_dict(raw_tree_parsed, feature_names_to_indices)
            tl.append(TreeView(tree_id=tree_id, root_node=root_node))

        return tl

    @staticmethod
    def from_lightgbm_booster_dump(dump: Dict[str, Any]) -> "TreeList":
        """
        Load a TreeList from a lgbm Booster dump

        One ``TreeView`` is created per entry of ``dump["tree_info"]``.

        Note: We cannot type-hint the Model without loading lightgbm as dependency in pyx code,
        therefore no type hint is added.
        """
        tl = TreeList()
        for tree_id, tree_dict in enumerate(dump["tree_info"]):
            root_node = Node.from_lightgbm_dict(tree_dict)
            tl.append(TreeView(tree_id=tree_id, root_node=root_node))

        return tl

    @staticmethod
    def from_treelite_dict(tl_json: Dict[str, Any]) -> "TreeList":
        """
        Load a TreeList from a treelite model dictionary

        One ``TreeView`` is created per entry of ``tl_json["trees"]``, using
        node 0 of each tree's "nodes" list as the root.
        """
        tl = TreeList()
        for tree_id, tree_dict in enumerate(tl_json["trees"]):
            root_node = Node.from_treelite_dict(tree_dict["nodes"], 0)
            tl.append(TreeView(tree_id=tree_id, root_node=root_node))
        return tl

    def __setitem__(self, key, value):
        """Reject item assignment (``tl[key] = value``)

        The parameters match the signature Python uses for item assignment,
        so the intended ``NotImplementedError`` is actually raised instead of
        a ``TypeError`` about the number of arguments.
        """
        raise NotImplementedError(
            "Use TreeList.from_*() methods to initialize a TreeList"
        )
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def get_lightgbm_params(booster):
    """Return whatever the booster's ``dump_model()`` method produces"""
    model_dump = booster.dump_model()
    return model_dump
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def get_xgboost_params(booster):
    """Return the booster configuration parsed from the JSON text of ``save_config()``"""
    config_text = booster.save_config()
    return json.loads(config_text)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def get_catboost_params(booster):
    """Return the CatBoost model as a dictionary parsed from its JSON export

    The model is written with ``booster.save_model(path, "json")`` to a file
    inside a temporary directory, which is read back and removed afterwards.

    A fresh path inside a ``TemporaryDirectory`` is used instead of the name
    of an already open ``NamedTemporaryFile``: whether such a file can be
    opened a second time by name while it is still open is platform-dependent
    (it cannot on Windows).
    """
    # Function-scope imports: these names are only needed here
    import os
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "catboost_model.json")
        booster.save_model(model_path, "json")
        with open(model_path, encoding="utf-8") as fp:
            model_data = json.load(fp)
    return model_data
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def get_gbt_model_from_tree_list(
    tree_list: TreeList,
    n_iterations: int,
    is_regression: bool,
    n_features: int,
    n_classes: int,
    base_score: Optional[Union[float, List[float]]] = None,
):
    """Return a GBT Model from TreeList

    Every tree is created in the model builder and its nodes are added in
    breadth-first order. A scalar ``base_score`` is handed to the builder's
    ``model()`` call; a list-valued one is instead added to the leaf
    responses of the first ``n_classes`` trees, one entry per tree.
    """

    if is_regression:
        mb = gbt_reg_model_builder(n_features=n_features, n_iterations=n_iterations)
    else:
        mb = gbt_clf_model_builder(
            n_features=n_features, n_iterations=n_iterations, n_classes=n_classes
        )

    has_vector_intercept = isinstance(base_score, list)

    def enqueue_children(queue, split_node, split_id):
        # Tag both children with the builder id of their parent split and
        # their side (0 = left, 1 = right) before scheduling them.
        children = split_node.get_children()
        assert children is not None
        for position, child in enumerate(children):
            child.parent_id = split_id
            child.position = position
            queue.append(child)

    class_label = 0
    for counter, tree in enumerate(tree_list, start=1):
        # the builder needs the number of nodes of the tree up front
        if is_regression:
            tree_id = mb.create_tree(tree.n_nodes)
        else:
            tree_id = mb.create_tree(n_nodes=tree.n_nodes, class_label=class_label)

        # Note: starting from xgboost>=3.1.0, multi-class classification models have
        # vector-valued intercepts. Since oneDAL doesn't support these, the scores
        # are instead added to all of the terminal leafs of the first trees.
        intercept_add = (
            base_score[counter - 1]
            if has_vector_intercept and counter <= n_classes
            else 0.0
        )

        if counter % n_iterations == 0:
            class_label += 1

        # single-node tree: just one leaf, nothing to traverse
        if tree.is_leaf:
            mb.add_leaf(
                tree_id=tree_id, response=tree.value + intercept_add, cover=tree.cover
            )
            continue

        root_node = tree.root_node
        root_id = mb.add_split(
            tree_id=tree_id,
            feature_index=root_node.feature,
            feature_value=root_node.get_value_closest_float_downward(),
            cover=root_node.cover,
            default_left=root_node.default_left,
        )

        pending: Deque[Node] = deque()
        enqueue_children(pending, root_node, root_id)

        while pending:
            node = pending.popleft()
            assert node.parent_id != -1, "node.parent_id must not be -1"
            assert node.position != -1, "node.position must not be -1"

            if node.is_leaf:
                mb.add_leaf(
                    tree_id=tree_id,
                    response=node.value + intercept_add,
                    cover=node.cover,
                    parent_id=node.parent_id,
                    position=node.position,
                )
                continue

            split_id = mb.add_split(
                tree_id=tree_id,
                feature_index=node.feature,
                feature_value=node.get_value_closest_float_downward(),
                cover=node.cover,
                default_left=node.default_left,
                parent_id=node.parent_id,
                position=node.position,
            )
            enqueue_children(pending, node, split_id)

    return mb.model(base_score=base_score if isinstance(base_score, float) else None)
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def get_gbt_model_from_lightgbm(model: Any, booster=None) -> Any:
    """Return a GBT model built from a LightGBM model

    Parameters
    ----------
    model
        Object providing ``model_to_string()`` and ``dump_model()``.
    booster : dict, optional
        Already computed result of ``model.dump_model()``; it is computed
        here when omitted.

    Raises
    ------
    TypeError
        For linear trees, 'dart' or 'rf' boosting, ranking objectives, and
        one-vs-all ("ova" / "ovr") multiclass objectives.
    """
    model_str = model.model_to_string()
    if "is_linear=1" in model_str:
        raise TypeError("Linear trees are not supported.")
    if "[boosting: dart]" in model_str:
        raise TypeError(
            "'Dart' booster is not supported. Try converting to 'treelite' first."
        )
    if "[boosting: rf]" in model_str:
        raise TypeError("Random forest boosters are not supported.")
    if ("[objective: lambdarank]" in model_str) or (
        "[objective: rank_xendcg]" in model_str
    ):
        raise TypeError("Ranking objectives are not supported.")

    if booster is None:
        booster = model.dump_model()

    n_features = booster["max_feature_idx"] + 1
    # Floor division keeps the iteration count an int, as declared by the
    # consumer of this value; true division would hand over a float.
    n_iterations = len(booster["tree_info"]) // booster["num_tree_per_iteration"]
    n_classes = booster["num_tree_per_iteration"]

    is_regression = False
    objective_fun = booster["objective"]
    if n_classes > 2:
        if ("ova" in objective_fun) or ("ovr" in objective_fun):
            raise TypeError(
                "Only multiclass (softmax) objective is supported for multiclass classification"
            )
    elif "binary" in objective_fun:  # nClasses == 1
        n_classes = 2
    else:
        is_regression = True

    tree_list = TreeList.from_lightgbm_booster_dump(booster)

    return get_gbt_model_from_tree_list(
        tree_list,
        n_iterations=n_iterations,
        is_regression=is_regression,
        n_features=n_features,
        n_classes=n_classes,
    )
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def get_gbt_model_from_xgboost(booster: Any, xgb_config=None) -> Any:
    """Return a GBT model built from an XGBoost booster

    ``xgb_config`` is the parsed booster configuration; it is obtained via
    ``get_xgboost_params(booster)`` when omitted. ``TypeError`` is raised for
    non-'gbtree' boosters, multi-target models, ranking objectives and
    unsupported classification objectives.
    """
    # Note: in the absence of any feature names, XGBoost will generate
    # tree json dumps where features are named 'f0..N'. While the JSONs
    # of the whole model will have feature indices, the per-tree JSONs
    # used here always use string names instead, hence the need for this.
    names = booster.feature_names or [
        f"f{ind}" for ind in range(booster.num_features())
    ]
    feature_names_to_indices = {fname: ind for ind, fname in enumerate(names)}

    if xgb_config is None:
        xgb_config = get_xgboost_params(booster)

    learner = xgb_config["learner"]
    if learner["learner_train_param"]["booster"] != "gbtree":
        raise TypeError(
            "Only 'gbtree' booster type is supported. For DART, try converting to 'treelite' first."
        )

    model_param = learner["learner_model_param"]
    n_targets = model_param.get("num_target")
    if n_targets is not None and int(n_targets) > 1:
        raise TypeError("Multi-target boosters are not supported.")

    n_features = int(model_param["num_feature"])
    n_classes = int(model_param["num_class"])

    # Note: base scores in XGBoost might be vector-valued starting from version 3.1.0.
    # When this is the case, the 'base_score' attribute will be a JSON list, otherwise
    # it will be a scalar. Note that in either case, it will be in the response scale.
    base_score_str: str = model_param["base_score"]
    if not base_score_str.startswith("["):
        base_score = float(base_score_str)
    else:
        parsed_scores = json.loads(base_score_str)
        if len(parsed_scores) == 0:
            base_score = 0.5
        elif len(parsed_scores) == 1:
            base_score = parsed_scores[0]
        else:
            base_score = parsed_scores

    objective_fun = learner["learner_train_param"]["objective"]

    # Note: the base score from XGBoost is in the response scale, but the predictions
    # are calculated in the link scale, so when there is a non-identity link function,
    # it needs to be converted to the link scale.
    if objective_fun in ("count:poisson", "reg:gamma", "reg:tweedie", "survival:aft"):
        base_score = float(np.log(base_score))
    elif objective_fun == "reg:logistic":
        base_score = float(np.log(base_score / (1 - base_score)))
    elif objective_fun.startswith("rank"):
        raise TypeError("Ranking objectives are not supported.")

    is_regression = False
    if n_classes > 2:
        if objective_fun not in ("multi:softprob", "multi:softmax"):
            raise TypeError(
                "multi:softprob and multi:softmax are only supported for multiclass classification"
            )
    elif objective_fun.startswith("binary:"):
        if objective_fun not in ("binary:logistic", "binary:logitraw"):
            raise TypeError(
                "only binary:logistic and binary:logitraw are supported for binary classification"
            )
        n_classes = 2
        if objective_fun == "binary:logitraw":
            # daal4py always applies a sigmoid for pred_proba, whereas XGBoost
            # returns raw predictions with logitraw
            base_score = float(1 / (1 + np.exp(-base_score)))
    else:
        is_regression = True

    trees_per_iteration = n_classes if n_classes > 2 else 1

    # max_trees=0 if best_iteration does not exist
    max_trees = getattr(booster, "best_iteration", -1) + 1
    if n_classes > 2:
        max_trees *= n_classes
    tree_list = TreeList.from_xgb_booster(booster, max_trees, feature_names_to_indices)

    if hasattr(booster, "best_iteration"):
        n_iterations = booster.best_iteration + 1
    else:
        n_iterations = len(tree_list) // trees_per_iteration

    return get_gbt_model_from_tree_list(
        tree_list,
        n_iterations=n_iterations,
        is_regression=is_regression,
        n_features=n_features,
        n_classes=n_classes,
        base_score=base_score,
    )
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def __get_value_as_list(node):
    """Return ``node["value"]`` unchanged if it is a list or tuple, else wrapped in a list"""
    raw_value = node["value"]
    return raw_value if isinstance(raw_value, (list, tuple)) else [raw_value]
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
def __calc_node_weights_from_leaf_weights(weights):
    """Aggregate leaf weights into per-level node weights

    Adjacent pairs are summed repeatedly until a level has at most one entry.
    The returned list of levels is ordered from the last level computed (the
    most aggregated one) back to the sums of the leaf pairs.
    """

    def pair_sums(level):
        assert len(level) % 2 == 0, "Length of values must be even"
        return [level[idx] + level[idx + 1] for idx in range(0, len(level), 2)]

    levels = []
    current = weights
    while True:
        current = pair_sums(current)
        levels.append(current)
        if len(current) <= 1:
            break
    levels.reverse()
    return levels
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
def get_gbt_model_from_catboost(booster: Any) -> Any:
|
|
644
|
+
if not booster.is_fitted():
|
|
645
|
+
raise RuntimeError("Model should be fitted before exporting to daal4py.")
|
|
646
|
+
|
|
647
|
+
model = CatBoostModelData(get_catboost_params(booster))
|
|
648
|
+
|
|
649
|
+
if model.has_categorical_features:
|
|
650
|
+
raise NotImplementedError(
|
|
651
|
+
"Categorical features are not supported in daal4py Gradient Boosting Trees"
|
|
652
|
+
)
|
|
653
|
+
|
|
654
|
+
objective = booster.get_params().get("objective", "")
|
|
655
|
+
if (
|
|
656
|
+
"Rank" in objective
|
|
657
|
+
or "Query" in objective
|
|
658
|
+
or "Pair" in objective
|
|
659
|
+
or objective in ["LambdaMart", "StochasticFilter", "GroupQuantile"]
|
|
660
|
+
):
|
|
661
|
+
raise TypeError("Ranking objectives are not supported.")
|
|
662
|
+
if "Multi" in objective and objective != "MultiClass":
|
|
663
|
+
if model.is_classification:
|
|
664
|
+
raise TypeError(
|
|
665
|
+
"Only 'MultiClass' loss is supported for multi-class classification."
|
|
666
|
+
)
|
|
667
|
+
else:
|
|
668
|
+
raise TypeError("Multi-output models are not supported.")
|
|
669
|
+
|
|
670
|
+
if model.is_classification:
|
|
671
|
+
mb = gbt_clf_model_builder(
|
|
672
|
+
n_features=model.n_features,
|
|
673
|
+
n_iterations=model.n_iterations,
|
|
674
|
+
n_classes=model.n_classes,
|
|
675
|
+
)
|
|
676
|
+
else:
|
|
677
|
+
mb = gbt_reg_model_builder(
|
|
678
|
+
n_features=model.n_features, n_iterations=model.n_iterations
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
# Create splits array (all splits are placed sequentially)
|
|
682
|
+
splits = []
|
|
683
|
+
for feature in model.float_features:
|
|
684
|
+
if feature["borders"]:
|
|
685
|
+
for feature_border in feature["borders"]:
|
|
686
|
+
splits.append(
|
|
687
|
+
{"feature_index": feature["feature_index"], "value": feature_border}
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
# Note: catboost models might have a 'bias' (intercept) which gets added
|
|
691
|
+
# to all predictions. In the case of single-output models, this is a scalar,
|
|
692
|
+
# but in the case of multi-output models such as multinomial logistic, it
|
|
693
|
+
# is a vector. Since daal4py doesn't support vector-valued intercepts, this
|
|
694
|
+
# adds the intercept to every terminal node instead, by dividing it equally
|
|
695
|
+
# among all trees. Usually, catboost would anyway set them to zero, but it
|
|
696
|
+
# still allows setting custom intercepts.
|
|
697
|
+
cb_bias = booster.get_scale_and_bias()[1]
|
|
698
|
+
add_intercept_to_each_node = isinstance(cb_bias, list)
|
|
699
|
+
if add_intercept_to_each_node:
|
|
700
|
+
cb_bias = np.array(cb_bias) / model.n_iterations
|
|
701
|
+
if not model.is_classification:
|
|
702
|
+
raise TypeError("Multi-output regression models are not supported.")
|
|
703
|
+
|
|
704
|
+
def add_vector_bias(values: list[float]) -> list[float]:
|
|
705
|
+
return list(np.array(values) + cb_bias)
|
|
706
|
+
|
|
707
|
+
trees_explicit = []
|
|
708
|
+
tree_symmetric = []
|
|
709
|
+
|
|
710
|
+
all_trees_are_empty = True
|
|
711
|
+
|
|
712
|
+
if model.is_symmetric_tree:
|
|
713
|
+
for tree in model.oblivious_trees:
|
|
714
|
+
tree_splits = tree.get("splits", [])
|
|
715
|
+
cur_tree_depth = len(tree_splits) if tree_splits is not None else 0
|
|
716
|
+
tree_symmetric.append((tree, cur_tree_depth))
|
|
717
|
+
else:
|
|
718
|
+
for tree in model.trees:
|
|
719
|
+
n_nodes = 1
|
|
720
|
+
|
|
721
|
+
# Check if node is a leaf (in case of stump)
|
|
722
|
+
if "split" in tree:
|
|
723
|
+
# Get number of trees and splits info via BFS
|
|
724
|
+
# Create queue
|
|
725
|
+
nodes_queue = []
|
|
726
|
+
root_node = CatBoostNode(split=splits[tree["split"]["split_index"]])
|
|
727
|
+
nodes_queue.append((tree, root_node))
|
|
728
|
+
while nodes_queue:
|
|
729
|
+
cur_node_data, cur_node = nodes_queue.pop(0)
|
|
730
|
+
if "value" in cur_node_data:
|
|
731
|
+
cur_node.value = __get_value_as_list(cur_node_data)
|
|
732
|
+
else:
|
|
733
|
+
cur_node.split = splits[cur_node_data["split"]["split_index"]]
|
|
734
|
+
left_node = CatBoostNode()
|
|
735
|
+
right_node = CatBoostNode()
|
|
736
|
+
cur_node.left = left_node
|
|
737
|
+
cur_node.right = right_node
|
|
738
|
+
nodes_queue.append((cur_node_data["left"], left_node))
|
|
739
|
+
nodes_queue.append((cur_node_data["right"], right_node))
|
|
740
|
+
n_nodes += 2
|
|
741
|
+
all_trees_are_empty = False
|
|
742
|
+
else:
|
|
743
|
+
root_node = CatBoostNode()
|
|
744
|
+
if model.is_classification and model.n_classes > 2:
|
|
745
|
+
root_node.value = [value * model.scale for value in tree["value"]]
|
|
746
|
+
if add_intercept_to_each_node:
|
|
747
|
+
root_node.value = add_vector_bias(root_node.value)
|
|
748
|
+
else:
|
|
749
|
+
root_node.value = [tree["value"] * model.scale]
|
|
750
|
+
trees_explicit.append((root_node, n_nodes))
|
|
751
|
+
|
|
752
|
+
tree_id = []
|
|
753
|
+
class_label = 0
|
|
754
|
+
count = 0
|
|
755
|
+
|
|
756
|
+
# Only 1 tree for each iteration in case of regression or binary classification
|
|
757
|
+
if not model.is_classification or model.n_classes == 2:
|
|
758
|
+
n_tree_each_iter = 1
|
|
759
|
+
else:
|
|
760
|
+
n_tree_each_iter = model.n_classes
|
|
761
|
+
|
|
762
|
+
shap_ready = False
|
|
763
|
+
|
|
764
|
+
# Create id for trees (for the right order in model builder)
|
|
765
|
+
for i in range(model.n_iterations):
|
|
766
|
+
for _ in range(n_tree_each_iter):
|
|
767
|
+
if model.is_symmetric_tree:
|
|
768
|
+
if not len(tree_symmetric):
|
|
769
|
+
n_nodes = 1
|
|
770
|
+
else:
|
|
771
|
+
n_nodes = 2 ** (tree_symmetric[i][1] + 1) - 1
|
|
772
|
+
else:
|
|
773
|
+
if not len(trees_explicit):
|
|
774
|
+
n_nodes = 1
|
|
775
|
+
else:
|
|
776
|
+
n_nodes = trees_explicit[i][1]
|
|
777
|
+
|
|
778
|
+
if model.is_classification and model.n_classes > 2:
|
|
779
|
+
tree_id.append(mb.create_tree(n_nodes, class_label))
|
|
780
|
+
count += 1
|
|
781
|
+
if count == model.n_iterations:
|
|
782
|
+
class_label += 1
|
|
783
|
+
count = 0
|
|
784
|
+
|
|
785
|
+
elif model.is_classification:
|
|
786
|
+
tree_id.append(mb.create_tree(n_nodes, 0))
|
|
787
|
+
else:
|
|
788
|
+
tree_id.append(mb.create_tree(n_nodes))
|
|
789
|
+
|
|
790
|
+
if model.is_symmetric_tree:
|
|
791
|
+
shap_ready = True # this code branch provides all info for SHAP values
|
|
792
|
+
for class_label in range(n_tree_each_iter):
|
|
793
|
+
for i in range(model.n_iterations):
|
|
794
|
+
cur_tree_info = tree_symmetric[i][0]
|
|
795
|
+
cur_tree_id = tree_id[i * n_tree_each_iter + class_label]
|
|
796
|
+
cur_tree_leaf_val = cur_tree_info["leaf_values"]
|
|
797
|
+
cur_tree_leaf_weights = cur_tree_info["leaf_weights"]
|
|
798
|
+
cur_tree_depth = tree_symmetric[i][1]
|
|
799
|
+
if cur_tree_depth == 0:
|
|
800
|
+
mb.add_leaf(
|
|
801
|
+
tree_id=cur_tree_id,
|
|
802
|
+
response=cur_tree_leaf_val[class_label] * model.scale
|
|
803
|
+
+ (cb_bias[class_label] if add_intercept_to_each_node else 0),
|
|
804
|
+
cover=cur_tree_leaf_weights[0],
|
|
805
|
+
)
|
|
806
|
+
else:
|
|
807
|
+
# One split used for the whole level
|
|
808
|
+
cur_level_split = splits[
|
|
809
|
+
cur_tree_info["splits"][cur_tree_depth - 1]["split_index"]
|
|
810
|
+
]
|
|
811
|
+
cur_tree_weights_per_level = __calc_node_weights_from_leaf_weights(
|
|
812
|
+
cur_tree_leaf_weights
|
|
813
|
+
)
|
|
814
|
+
root_weight = cur_tree_weights_per_level[0][0]
|
|
815
|
+
|
|
816
|
+
root_id = mb.add_split(
|
|
817
|
+
tree_id=cur_tree_id,
|
|
818
|
+
feature_index=cur_level_split["feature_index"],
|
|
819
|
+
feature_value=cur_level_split["value"],
|
|
820
|
+
default_left=model.default_left,
|
|
821
|
+
cover=root_weight,
|
|
822
|
+
)
|
|
823
|
+
prev_level_nodes = [root_id]
|
|
824
|
+
|
|
825
|
+
# Iterate over levels, splits in json are reversed (root split is the last)
|
|
826
|
+
for cur_level in range(cur_tree_depth - 2, -1, -1):
|
|
827
|
+
cur_level_nodes = []
|
|
828
|
+
next_level_weights = cur_tree_weights_per_level[cur_level + 1]
|
|
829
|
+
cur_level_node_index = 0
|
|
830
|
+
for cur_parent in prev_level_nodes:
|
|
831
|
+
cur_level_split = splits[
|
|
832
|
+
cur_tree_info["splits"][cur_level]["split_index"]
|
|
833
|
+
]
|
|
834
|
+
cover_nodes = next_level_weights[cur_level_node_index]
|
|
835
|
+
if cover_nodes == 0:
|
|
836
|
+
shap_ready = False
|
|
837
|
+
cur_left_node = mb.add_split(
|
|
838
|
+
tree_id=cur_tree_id,
|
|
839
|
+
parent_id=cur_parent,
|
|
840
|
+
position=0,
|
|
841
|
+
feature_index=cur_level_split["feature_index"],
|
|
842
|
+
feature_value=cur_level_split["value"],
|
|
843
|
+
default_left=model.default_left,
|
|
844
|
+
cover=cover_nodes,
|
|
845
|
+
)
|
|
846
|
+
# cur_level_node_index += 1
|
|
847
|
+
cur_right_node = mb.add_split(
|
|
848
|
+
tree_id=cur_tree_id,
|
|
849
|
+
parent_id=cur_parent,
|
|
850
|
+
position=1,
|
|
851
|
+
feature_index=cur_level_split["feature_index"],
|
|
852
|
+
feature_value=cur_level_split["value"],
|
|
853
|
+
default_left=model.default_left,
|
|
854
|
+
cover=cover_nodes,
|
|
855
|
+
)
|
|
856
|
+
# cur_level_node_index += 1
|
|
857
|
+
cur_level_nodes.append(cur_left_node)
|
|
858
|
+
cur_level_nodes.append(cur_right_node)
|
|
859
|
+
prev_level_nodes = cur_level_nodes
|
|
860
|
+
|
|
861
|
+
# Different storing format for leaves
|
|
862
|
+
if not model.is_classification or model.n_classes == 2:
|
|
863
|
+
for last_level_node_num in range(len(prev_level_nodes)):
|
|
864
|
+
mb.add_leaf(
|
|
865
|
+
tree_id=cur_tree_id,
|
|
866
|
+
response=cur_tree_leaf_val[2 * last_level_node_num]
|
|
867
|
+
* model.scale,
|
|
868
|
+
parent_id=prev_level_nodes[last_level_node_num],
|
|
869
|
+
position=0,
|
|
870
|
+
cover=cur_tree_leaf_weights[2 * last_level_node_num],
|
|
871
|
+
)
|
|
872
|
+
mb.add_leaf(
|
|
873
|
+
tree_id=cur_tree_id,
|
|
874
|
+
response=cur_tree_leaf_val[2 * last_level_node_num + 1]
|
|
875
|
+
* model.scale,
|
|
876
|
+
parent_id=prev_level_nodes[last_level_node_num],
|
|
877
|
+
position=1,
|
|
878
|
+
cover=cur_tree_leaf_weights[2 * last_level_node_num + 1],
|
|
879
|
+
)
|
|
880
|
+
else:
|
|
881
|
+
shap_ready = False
|
|
882
|
+
for last_level_node_num in range(len(prev_level_nodes)):
|
|
883
|
+
left_index = (
|
|
884
|
+
2 * last_level_node_num * n_tree_each_iter + class_label
|
|
885
|
+
)
|
|
886
|
+
right_index = (
|
|
887
|
+
2 * last_level_node_num + 1
|
|
888
|
+
) * n_tree_each_iter + class_label
|
|
889
|
+
mb.add_leaf(
|
|
890
|
+
tree_id=cur_tree_id,
|
|
891
|
+
response=cur_tree_leaf_val[left_index] * model.scale
|
|
892
|
+
+ (
|
|
893
|
+
cb_bias[class_label]
|
|
894
|
+
if add_intercept_to_each_node
|
|
895
|
+
else 0
|
|
896
|
+
),
|
|
897
|
+
parent_id=prev_level_nodes[last_level_node_num],
|
|
898
|
+
position=0,
|
|
899
|
+
cover=0.0,
|
|
900
|
+
)
|
|
901
|
+
mb.add_leaf(
|
|
902
|
+
tree_id=cur_tree_id,
|
|
903
|
+
response=cur_tree_leaf_val[right_index] * model.scale
|
|
904
|
+
+ (
|
|
905
|
+
cb_bias[class_label]
|
|
906
|
+
if add_intercept_to_each_node
|
|
907
|
+
else 0
|
|
908
|
+
),
|
|
909
|
+
parent_id=prev_level_nodes[last_level_node_num],
|
|
910
|
+
position=1,
|
|
911
|
+
cover=0.0,
|
|
912
|
+
)
|
|
913
|
+
else:
|
|
914
|
+
shap_ready = False
|
|
915
|
+
scale = booster.get_scale_and_bias()[0]
|
|
916
|
+
for class_label in range(n_tree_each_iter):
|
|
917
|
+
for i in range(model.n_iterations):
|
|
918
|
+
root_node = trees_explicit[i][0]
|
|
919
|
+
|
|
920
|
+
cur_tree_id = tree_id[i * n_tree_each_iter + class_label]
|
|
921
|
+
# Traverse tree via BFS and build tree with modelbuilder
|
|
922
|
+
if root_node.value is None:
|
|
923
|
+
root_id = mb.add_split(
|
|
924
|
+
tree_id=cur_tree_id,
|
|
925
|
+
feature_index=root_node.split["feature_index"],
|
|
926
|
+
feature_value=root_node.split["value"],
|
|
927
|
+
default_left=model.default_left,
|
|
928
|
+
cover=0.0,
|
|
929
|
+
)
|
|
930
|
+
nodes_queue = [(root_node, root_id)]
|
|
931
|
+
while nodes_queue:
|
|
932
|
+
cur_node, cur_node_id = nodes_queue.pop(0)
|
|
933
|
+
left_node = cur_node.left
|
|
934
|
+
# Check if node is a leaf
|
|
935
|
+
if left_node.value is None:
|
|
936
|
+
left_node_id = mb.add_split(
|
|
937
|
+
tree_id=cur_tree_id,
|
|
938
|
+
parent_id=cur_node_id,
|
|
939
|
+
position=0,
|
|
940
|
+
feature_index=left_node.split["feature_index"],
|
|
941
|
+
feature_value=left_node.split["value"],
|
|
942
|
+
default_left=model.default_left,
|
|
943
|
+
cover=0.0,
|
|
944
|
+
)
|
|
945
|
+
nodes_queue.append((left_node, left_node_id))
|
|
946
|
+
else:
|
|
947
|
+
mb.add_leaf(
|
|
948
|
+
tree_id=cur_tree_id,
|
|
949
|
+
response=scale * left_node.value[class_label]
|
|
950
|
+
+ (
|
|
951
|
+
cb_bias[class_label]
|
|
952
|
+
if add_intercept_to_each_node
|
|
953
|
+
else 0
|
|
954
|
+
),
|
|
955
|
+
parent_id=cur_node_id,
|
|
956
|
+
position=0,
|
|
957
|
+
cover=0.0,
|
|
958
|
+
)
|
|
959
|
+
right_node = cur_node.right
|
|
960
|
+
# Check if node is a leaf
|
|
961
|
+
if right_node.value is None:
|
|
962
|
+
right_node_id = mb.add_split(
|
|
963
|
+
tree_id=cur_tree_id,
|
|
964
|
+
parent_id=cur_node_id,
|
|
965
|
+
position=1,
|
|
966
|
+
feature_index=right_node.split["feature_index"],
|
|
967
|
+
feature_value=right_node.split["value"],
|
|
968
|
+
default_left=model.default_left,
|
|
969
|
+
cover=0.0,
|
|
970
|
+
)
|
|
971
|
+
nodes_queue.append((right_node, right_node_id))
|
|
972
|
+
else:
|
|
973
|
+
mb.add_leaf(
|
|
974
|
+
tree_id=cur_tree_id,
|
|
975
|
+
response=scale * cur_node.right.value[class_label]
|
|
976
|
+
+ (
|
|
977
|
+
cb_bias[class_label]
|
|
978
|
+
if add_intercept_to_each_node
|
|
979
|
+
else 0
|
|
980
|
+
),
|
|
981
|
+
parent_id=cur_node_id,
|
|
982
|
+
position=1,
|
|
983
|
+
cover=0.0,
|
|
984
|
+
)
|
|
985
|
+
|
|
986
|
+
else:
|
|
987
|
+
# Tree has only one node
|
|
988
|
+
# Note: the root node already has scale and bias added to it,
|
|
989
|
+
# so no need to add them again here like it is done for the leafs.
|
|
990
|
+
mb.add_leaf(
|
|
991
|
+
tree_id=cur_tree_id,
|
|
992
|
+
response=root_node.value[class_label],
|
|
993
|
+
cover=0.0,
|
|
994
|
+
)
|
|
995
|
+
|
|
996
|
+
if all_trees_are_empty and not model.is_symmetric_tree:
|
|
997
|
+
shap_ready = True
|
|
998
|
+
|
|
999
|
+
intercept = 0.0
|
|
1000
|
+
if not add_intercept_to_each_node:
|
|
1001
|
+
intercept = booster.get_scale_and_bias()[1]
|
|
1002
|
+
return mb.model(base_score=intercept), shap_ready
|
|
1003
|
+
|
|
1004
|
+
|
|
1005
|
+
def get_gbt_model_from_treelite(
    tl_model: "treelite.model.Model",
) -> tuple[Any, int, int, bool]:
    """Convert a treelite model into a daal4py gradient-boosted-trees model.

    The model is dumped to JSON, validated, and then normalized in place
    (averaged outputs rescaled, binary random-forest leaves reduced to a
    single value, multi-class trees padded and reordered, multi-class base
    scores folded into the leaves) before being handed over to
    ``get_gbt_model_from_tree_list``.

    Parameters
    ----------
    tl_model : treelite.model.Model
        Model to convert. Only ``dump_as_json()`` and ``num_tree`` are read.

    Returns
    -------
    tuple
        ``(model, num_class, num_feature, shap_ready)``. ``model`` is the
        object returned by ``get_gbt_model_from_tree_list``. ``shap_ready``
        is False when there are more than two classes, when the root node of
        any tree has no (or a zero) ``sum_hess``, or when empty trees had to
        be added to even out the number of trees per class.

    Raises
    ------
    TypeError
        If the task type is unsupported, the model is multi-target, uses
        One-Vs-All, has categorical splits, contains no trees, looks like an
        unsupported random forest (multi-class, multi-output or with base
        scores), or ends up with more than one base score.
    """
    model_json = json.loads(tl_model.dump_as_json())
    task_type = model_json["task_type"]
    if task_type not in ["kBinaryClf", "kRegressor", "kMultiClf", "kIsolationForest"]:
        raise TypeError(f"Model to convert is of unsupported type: {task_type}")
    if model_json["num_target"] > 1:
        raise TypeError("Multi-target models are not supported.")
    if model_json["postprocessor"] == "multiclass_ova":
        raise TypeError(
            "Multi-class classification models that use One-Vs-All are not supported."
        )
    for tree in model_json["trees"]:
        if tree["has_categorical_split"]:
            raise TypeError("Models with categorical features are not supported.")
    num_trees = tl_model.num_tree
    if not num_trees:
        raise TypeError("Model to convert contains no trees.")

    # Note: the daal4py module always adds up the scores, but some models
    # might average them instead. In such case, this turns the trees into
    # additive ones by dividing the predictions by the number of trees beforehand.
    if model_json["average_tree_output"]:
        divide_treelite_leaf_values_by_const(model_json, num_trees)

    base_score = model_json["base_scores"]
    num_class = model_json["num_class"][0]
    num_feature = model_json["num_feature"]

    if task_type == "kBinaryClf":
        num_class = 2
        if base_score:
            # Base scores are passed through a sigmoid here.
            # NOTE(review): presumably treelite stores them on the logit
            # scale while the model builder expects a probability - confirm.
            base_score = list(1 / (1 + np.exp(-np.array(base_score))))

    if num_class > 2:
        shap_ready = False
    else:
        shap_ready = True
        for tree in model_json["trees"]:
            if not tree["nodes"][0].get("sum_hess", False):
                shap_ready = False
                break

    # In the case of random forests for classification, it might work
    # by averaging predictions without any link function, whereas
    # daal4py assumes a logit link. In such case, it's not possible to
    # convert them to daal4py's logic, but the model can still be used
    # as a regressor that always outputs something between 0 and 1.
    is_regression = "Clf" not in task_type
    if not is_regression and model_json["postprocessor"] == "identity_multiclass":
        is_regression = True
        warnings.warn(
            "Attempting to convert classification model which is not"
            " based on gradient boosting. Will output a regression"
            " model instead."
        )

    looks_like_random_forest = (
        model_json["postprocessor"] == "identity_multiclass"
        and len(model_json["base_scores"]) > 1
        and task_type == "kMultiClf"
    )
    if looks_like_random_forest:
        if num_class > 2 or len(base_score) > 2:
            raise TypeError("Multi-class random forests are not supported.")
        if len(model_json["num_class"]) > 1:
            raise TypeError("Multi-output random forests are not supported.")
        if len(base_score) == 2 and base_score[0]:
            raise TypeError("Random forests with base scores are not supported.")

    # In the case of binary random forests, it will always have leaf values
    # for 2 classes, which is redundant as they sum to 1. daal4py requires
    # only values for the positive class, so they need to be converted.
    if looks_like_random_forest:
        leave_only_last_treelite_leaf_value(model_json)
        base_score = base_score[-1]

    # In the case of multi-class classification models, if converted
    # from xgboost, the order of the trees will be the same - i.e.
    # sequences of one tree of each class, followed by another such
    # sequence. But treelite could in theory also support building
    # models where the trees are in a different order, in which case
    # they will need to be reordered to match xgboost, since that's
    # how daal4py handles them. And if there is an uneven number of
    # trees per class, then will need to make up extra trees with
    # zeros to accommodate it.
    if task_type == "kMultiClf" and not looks_like_random_forest:
        # Count trees per class aligned with range(num_class), so that a
        # class without any tree gets a count of zero instead of being
        # dropped from the result (which is what np.unique would do).
        tree_classes = np.asarray(model_json["class_id"])
        num_trees_per_class = np.array(
            [
                np.count_nonzero(tree_classes == class_ind)
                for class_ind in range(num_class)
            ]
        )
        # Pad whenever the counts differ. Checking only divisibility of the
        # total would miss distributions such as 3 + 1 trees for 2 classes,
        # which the equal-sized split further below would then mis-assign.
        if num_trees_per_class.min() != num_trees_per_class.max():
            shap_ready = False
            num_tree_add_per_class = num_trees_per_class.max() - num_trees_per_class
            for class_ind in range(num_class):
                for _ in range(num_tree_add_per_class[class_ind]):
                    add_empty_tree_to_treelite_json(model_json, class_ind)

        tree_class_orders = model_json["class_id"]
        sequential_ids = np.arange(num_class)
        num_trees = len(model_json["trees"])
        assert (num_trees % num_class) == 0
        if not np.array_equal(
            tree_class_orders, np.tile(sequential_ids, num_trees // num_class)
        ):
            # A stable sort keeps the original relative order of the trees
            # that belong to the same class.
            argsorted_class_indices = np.argsort(tree_class_orders, kind="stable")
            per_class_indices = np.split(argsorted_class_indices, num_class)
            # Column-major flattening interleaves the classes:
            # class 0, class 1, ..., class 0, class 1, ...
            correct_order = np.vstack(per_class_indices).reshape(-1, order="F")
            model_json["trees"] = [model_json["trees"][ix] for ix in correct_order]
            model_json["class_id"] = [model_json["class_id"][ix] for ix in correct_order]

    # In the case of multi-class classification with base scores,
    # since daal4py only supports scalar intercepts, this follows the
    # same strategy as in catboost of dividing the intercepts equally
    # among the number of trees
    if task_type == "kMultiClf" and not looks_like_random_forest:
        add_intercept_to_treelite_leafs(model_json, base_score)
        base_score = None

    if isinstance(base_score, list):
        if len(base_score) == 1:
            base_score = base_score[0]
        else:
            raise TypeError("Model to convert is malformed.")

    # Integer division: the iteration count must be a whole number, and in
    # the multi-class branch divisibility was asserted above.
    n_iterations = num_trees // (
        num_class if task_type == "kMultiClf" and not looks_like_random_forest else 1
    )

    tree_list = TreeList.from_treelite_dict(model_json)
    return (
        get_gbt_model_from_tree_list(
            tree_list,
            n_iterations=n_iterations,
            is_regression=is_regression,
            n_features=num_feature,
            n_classes=num_class,
            base_score=base_score,
        ),
        num_class,
        num_feature,
        shap_ready,
    )
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
def divide_treelite_leaf_values_by_const(
    tl_json: dict[str, Any], divisor: "int | float"
) -> None:
    """Divide every leaf value of a treelite JSON dict by ``divisor``, in place.

    Nodes without a ``"leaf_value"`` key are left untouched. Leaf values
    stored as a list or tuple are divided element-wise and written back as
    a list; any other leaf value is divided directly.
    """
    for tree_entry in tl_json["trees"]:
        for node_entry in tree_entry["nodes"]:
            # Only leaf nodes carry a value to rescale.
            if "leaf_value" not in node_entry:
                continue
            if not isinstance(node_entry["leaf_value"], (list, tuple)):
                node_entry["leaf_value"] /= divisor
                continue
            scaled = np.array(node_entry["leaf_value"]) / divisor
            node_entry["leaf_value"] = list(scaled)
|
|
1163
|
+
|
|
1164
|
+
|
|
1165
|
+
def leave_only_last_treelite_leaf_value(tl_json: dict[str, Any]) -> None:
    """Replace each two-element leaf value with its last element, in place.

    Nodes without a ``"leaf_value"`` key are skipped. Every leaf value is
    asserted to hold exactly two entries before being reduced.
    """
    for tree_entry in tl_json["trees"]:
        leaf_nodes = (
            candidate for candidate in tree_entry["nodes"] if "leaf_value" in candidate
        )
        for leaf in leaf_nodes:
            pair = leaf["leaf_value"]
            assert len(pair) == 2
            leaf["leaf_value"] = pair[-1]
|
|
1171
|
+
|
|
1172
|
+
|
|
1173
|
+
def add_intercept_to_treelite_leafs(
    tl_json: dict[str, Any], base_score: list[float]
) -> None:
    """Spread per-class base scores over the leaves of a treelite JSON dict.

    For each tree, the base score of the tree's class (looked up through
    ``tl_json["class_id"]``) is divided by the number of trees per class,
    computed as the total number of trees over ``tl_json["num_class"][0]``,
    and the result is added in place to every leaf value of that tree.
    """
    all_trees = tl_json["trees"]
    trees_per_class = len(all_trees) / tl_json["num_class"][0]
    for position, tree_entry in enumerate(all_trees):
        class_of_tree = tl_json["class_id"][position]
        share = base_score[class_of_tree] / trees_per_class
        for node_entry in tree_entry["nodes"]:
            # Split nodes have no leaf value and are left as they are.
            if "leaf_value" not in node_entry:
                continue
            node_entry["leaf_value"] += share
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
def add_empty_tree_to_treelite_json(tl_json: dict[str, Any], class_add: int) -> None:
    """Append a single-leaf tree with zero output to a treelite JSON dict.

    ``class_add`` is appended to ``tl_json["class_id"]`` and a tree made of
    one leaf node (leaf value, data count and hessian sum all zero) is
    appended to ``tl_json["trees"]``. Both lists are modified in place.
    """
    zero_leaf = {
        "node_id": 0,
        "leaf_value": 0.0,
        "data_count": 0,
        "sum_hess": 0.0,
    }
    placeholder_tree = {
        "num_nodes": 1,
        "has_categorical_split": False,
        "nodes": [zero_leaf],
    }
    tl_json["class_id"].append(class_add)
    tl_json["trees"].append(placeholder_tree)
|