scikit-learn-intelex 2025.10.0__py313-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (267)
  1. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
  2. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
  3. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp313-win_amd64.pyd +0 -0
  4. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
  5. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
  6. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
  7. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
  8. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp313-win_amd64.pyd +0 -0
  9. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
  10. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
  11. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
  12. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
  13. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
  14. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
  15. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  16. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/__init__.py +19 -0
  17. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
  18. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  19. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
  20. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
  21. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
  22. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  23. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
  24. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
  25. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
  26. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
  27. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  28. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
  29. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  30. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
  31. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
  32. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
  33. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
  34. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  35. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
  36. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
  37. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
  38. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
  39. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
  40. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
  41. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
  42. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  43. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
  44. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  45. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
  46. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  47. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  48. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  49. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
  50. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
  51. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
  52. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
  53. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  54. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  55. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
  56. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
  57. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
  58. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
  59. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
  60. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
  61. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
  62. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
  63. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
  64. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp313-win_amd64.pyd +0 -0
  65. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp313-win_amd64.pyd +0 -0
  66. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
  67. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
  68. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
  69. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
  70. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
  71. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
  72. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
  73. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
  74. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
  75. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
  76. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
  77. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
  78. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
  79. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
  80. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
  81. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
  82. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
  83. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
  84. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
  85. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
  86. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
  87. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
  88. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
  89. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
  90. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
  91. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
  92. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
  93. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
  94. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
  95. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/__init__.py +20 -0
  96. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
  97. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
  98. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
  99. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
  100. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
  101. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
  102. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
  103. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
  104. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
  105. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
  106. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
  107. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
  108. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
  109. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
  110. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
  111. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
  112. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
  113. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
  114. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
  115. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
  116. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
  117. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
  118. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
  119. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
  120. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
  121. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
  122. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
  123. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
  124. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
  125. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
  126. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
  127. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
  128. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
  129. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
  130. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
  131. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
  132. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
  133. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
  134. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
  135. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__init__.py +69 -0
  136. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__main__.py +58 -0
  137. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
  138. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
  139. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
  140. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
  141. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  142. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
  143. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +338 -0
  144. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
  145. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
  146. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +20 -0
  147. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +199 -0
  148. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
  149. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +38 -0
  150. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
  151. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
  152. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
  153. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
  154. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
  155. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +19 -0
  156. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
  157. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
  158. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dispatcher.py +572 -0
  159. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +629 -0
  160. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -0
  161. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
  162. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
  163. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +29 -0
  164. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1799 -0
  165. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
  166. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/__main__.py +72 -0
  167. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +101 -0
  168. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
  169. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
  170. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
  171. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
  172. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
  173. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
  174. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
  175. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
  176. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
  177. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
  178. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
  179. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
  180. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +19 -0
  181. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +28 -0
  182. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
  183. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +23 -0
  184. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +20 -0
  185. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +20 -0
  186. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +39 -0
  187. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +21 -0
  188. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/split.py +20 -0
  189. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +34 -0
  190. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +27 -0
  191. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +189 -0
  192. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/common.py +313 -0
  193. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +189 -0
  194. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +167 -0
  195. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +170 -0
  196. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +82 -0
  197. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/__init__.py +17 -0
  198. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +19 -0
  199. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
  200. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +112 -0
  201. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  202. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
  203. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
  204. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +25 -0
  205. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
  206. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
  207. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  208. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
  209. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
  210. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
  211. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +26 -0
  212. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
  213. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
  214. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
  215. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
  216. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
  217. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
  218. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
  219. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
  220. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
  221. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
  222. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +23 -0
  223. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
  224. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
  225. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
  226. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
  227. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
  228. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +21 -0
  229. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
  230. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +24 -0
  231. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  232. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
  233. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
  234. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
  235. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +23 -0
  236. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
  237. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/__init__.py +29 -0
  238. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
  239. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +278 -0
  240. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +158 -0
  241. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svc.py +306 -0
  242. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svr.py +155 -0
  243. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +124 -0
  244. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
  245. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
  246. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
  247. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
  248. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +269 -0
  249. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
  250. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +48 -0
  251. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +418 -0
  252. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
  253. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
  254. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
  255. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
  256. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
  257. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
  258. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
  259. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
  260. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
  261. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
  262. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
  263. scikit_learn_intelex-2025.10.0.dist-info/LICENSE.txt +202 -0
  264. scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
  265. scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
  266. scikit_learn_intelex-2025.10.0.dist-info/WHEEL +5 -0
  267. scikit_learn_intelex-2025.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1199 @@
1
+ # ===============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import json
18
+ import warnings
19
+ from collections import deque
20
+ from copy import deepcopy
21
+ from tempfile import NamedTemporaryFile
22
+ from typing import Any, Deque, Dict, List, Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+
26
+ from .. import gbt_clf_model_builder, gbt_reg_model_builder
27
+
28
+
29
class CatBoostNode:
    """Mutable node of an explicit (non-symmetric) CatBoost tree.

    A node acts as a split while ``value`` is ``None`` (``split`` then
    describes the split), and as a leaf once ``value`` holds its leaf values.
    Children are linked through ``left`` and ``right``.
    """

    def __init__(
        self,
        split: Optional[Dict] = None,
        value: Optional[List[float]] = None,
        right: "Optional[CatBoostNode]" = None,
        left: "Optional[CatBoostNode]" = None,
        cover: Optional[float] = None,
    ) -> None:
        # split description; the converter reads its "feature_index" and
        # "value" (border) keys
        self.split = split
        # leaf values (one entry per output); None for split nodes
        self.value = value
        # child nodes (CatBoostNode instances), None for leaves
        self.right = right
        self.left = left
        # NOTE(review): not populated by the visible converter code
        self.cover = cover
43
+
44
+
45
class CatBoostModelData:
    """Read-only accessor over a parsed CatBoost JSON model dump.

    Each property resolves one nested path of the dump so that conversion
    code does not need to spell out dictionary lookups.
    """

    def __init__(self, data):
        # the parsed JSON dump of the model
        self.__data = data

    @property
    def _model_info(self):
        return self.__data["model_info"]

    @property
    def _features_info(self):
        return self.__data["features_info"]

    @property
    def float_features(self):
        return self._features_info["float_features"]

    @property
    def n_features(self):
        return len(self.float_features)

    @property
    def grow_policy(self):
        learner_options = self._model_info["params"]["tree_learner_options"]
        return learner_options["grow_policy"]

    @property
    def is_symmetric_tree(self):
        return self.grow_policy == "SymmetricTree"

    @property
    def oblivious_trees(self):
        return self.__data["oblivious_trees"]

    @property
    def trees(self):
        return self.__data["trees"]

    @property
    def n_iterations(self):
        tree_collection = self.oblivious_trees if self.is_symmetric_tree else self.trees
        return len(tree_collection)

    @property
    def is_classification(self):
        return "class_params" in self._model_info

    @property
    def n_classes(self):
        """Number of classes, returns -1 if it's not a classification model"""
        if not self.is_classification:
            return -1
        return len(self._model_info["class_params"]["class_to_label"])

    @property
    def has_categorical_features(self):
        return "categorical_features" in self._features_info

    @property
    def scale(self):
        scale_and_bias = self.__data["scale_and_bias"]
        return scale_and_bias[0]

    @property
    def default_left(self):
        # 1 when NaNs are binarized as the minimum value, 0 otherwise
        processing = self._model_info["params"]["data_processing_options"]
        nan_mode = processing["float_features_binarization"]["nan_mode"]
        return int(nan_mode.lower() == "min")
106
+
107
+
108
class Node:
    """Helper class holding tree node information.

    Nodes are built recursively from the per-tree dumps of XGBoost, LightGBM
    and Treelite by the ``from_*`` constructors and are linked into a binary
    tree through ``left_child`` / ``right_child``.
    """

    def __init__(
        self,
        cover: float,
        is_leaf: bool,
        default_left: bool,
        feature: int,
        value: float,
        n_children: int = 0,
        left_child: "Optional[Node]" = None,
        right_child: "Optional[Node]" = None,
        parent_id: Optional[int] = -1,
        position: Optional[int] = -1,
    ) -> None:
        self.cover = cover
        self.is_leaf = is_leaf
        self.default_left = default_left
        # raw feature identifier; validated/converted by the ``feature`` property
        self.__feature = feature
        # leaf response for leaves, split threshold for internal nodes
        self.value = value
        # total number of descendant nodes (not only the direct children)
        self.n_children = n_children
        self.left_child = left_child
        self.right_child = right_child
        # id of the parent in the model builder and slot (0 = left, 1 = right)
        # under it; -1 until assigned while the model is being built
        self.parent_id = parent_id
        self.position = position

    @staticmethod
    def from_xgb_dict(
        input_dict: Dict[str, Any], feature_names_to_indices: dict[str, int]
    ) -> "Node":
        """Recursively build a Node from one node of an XGBoost JSON tree dump.

        ``feature_names_to_indices`` maps the feature names used in the dump's
        "split" entries to integer feature indices.
        """
        if "children" in input_dict:
            left_child = Node.from_xgb_dict(
                input_dict["children"][0], feature_names_to_indices
            )
            right_child = Node.from_xgb_dict(
                input_dict["children"][1], feature_names_to_indices
            )
            n_children = 2 + left_child.n_children + right_child.n_children
        else:
            left_child = None
            right_child = None
            n_children = 0
        is_leaf = "leaf" in input_dict
        # missing values follow the "yes" branch when it is also the "missing" one
        default_left = "yes" in input_dict and input_dict["yes"] == input_dict["missing"]
        feature = input_dict.get("split")
        if feature:
            feature = feature_names_to_indices[feature]
        return Node(
            cover=input_dict["cover"],
            is_leaf=is_leaf,
            default_left=default_left,
            feature=feature,
            value=input_dict["leaf"] if is_leaf else input_dict["split_condition"],
            n_children=n_children,
            left_child=left_child,
            right_child=right_child,
        )

    @staticmethod
    def from_lightgbm_dict(input_dict: Dict[str, Any]) -> "Node":
        """Recursively build a Node from a LightGBM ``dump_model()`` tree entry.

        Accepts either a whole "tree_info" entry (with a "tree_structure" key)
        or a node dictionary nested inside it.
        """
        tree = input_dict.get("tree_structure", input_dict)

        n_children = 0
        if "left_child" in tree:
            left_child = Node.from_lightgbm_dict(tree["left_child"])
            n_children += 1 + left_child.n_children
        else:
            left_child = None
        if "right_child" in tree:
            right_child = Node.from_lightgbm_dict(tree["right_child"])
            n_children += 1 + right_child.n_children
        else:
            right_child = None

        is_leaf = "leaf_value" in tree
        # Select the fields explicitly instead of chaining ``a or b``: a leaf
        # value (or count) of exactly 0 must not fall through to the
        # internal-node keys.
        if is_leaf:
            cover = tree.get("leaf_count", 0)
            value = tree["leaf_value"]
        else:
            cover = tree.get("internal_count", 0)
            value = tree.get("threshold", 0)
        return Node(
            cover=cover,
            is_leaf=is_leaf,
            default_left=tree.get("default_left", 0),
            feature=tree.get("split_feature"),
            value=value,
            n_children=n_children,
            left_child=left_child,
            right_child=right_child,
        )

    @staticmethod
    def from_treelite_dict(dict_all_nodes: list[dict[str, Any]], node_id: int) -> "Node":
        """Recursively build a Node from a Treelite JSON tree.

        ``dict_all_nodes`` is the flat node list of one tree and ``node_id``
        the index of the node to convert; children are referenced by index.

        Raises
        ------
        TypeError
            If a split uses a comparison operator other than <, <=, > or >=.
        """
        this_node = dict_all_nodes[node_id]
        is_leaf = "leaf_value" in this_node
        default_left = this_node.get("default_left", False)

        n_children = 0
        if "left_child" in this_node:
            left_child = Node.from_treelite_dict(dict_all_nodes, this_node["left_child"])
            n_children += 1 + left_child.n_children
        else:
            left_child = None
        if "right_child" in this_node:
            right_child = Node.from_treelite_dict(
                dict_all_nodes, this_node["right_child"]
            )
            n_children += 1 + right_child.n_children
        else:
            right_child = None

        value = this_node["leaf_value"] if is_leaf else this_node["threshold"]
        if not is_leaf:
            # Normalize every split to the strict "<" form: "<=" bumps the
            # threshold up by one ulp, while ">" / ">=" swap the children (and
            # the default direction), with ">" also moving the threshold down.
            comp = this_node["comparison_op"]
            if comp == "<=":
                value = float(np.nextafter(value, np.inf))
            elif comp in [">", ">="]:
                left_child, right_child = right_child, left_child
                default_left = not default_left
                if comp == ">":
                    value = float(np.nextafter(value, -np.inf))
            elif comp != "<":
                raise TypeError(
                    f"Model to convert contains unsupported split type: {comp}."
                )

        return Node(
            cover=this_node.get("sum_hess", 0.0),
            is_leaf=is_leaf,
            default_left=default_left,
            feature=this_node.get("split_feature_id"),
            value=value,
            n_children=n_children,
            left_child=left_child,
            right_child=right_child,
        )

    def get_value_closest_float_downward(self) -> np.float32:
        """Get the closest exact single-precision value smaller than self.value"""
        return np.nextafter(np.single(self.value), np.single(-np.inf))

    def get_children(self) -> "Optional[Tuple[Node, Node]]":
        """Return ``(left, right)`` for split nodes and ``None`` for leaves."""
        if self.left_child and self.right_child:
            return (self.left_child, self.right_child)
        # a node missing a child is only valid when it is a leaf
        assert self.is_leaf
        return None

    @property
    def feature(self) -> int:
        """Integer index of the split feature.

        Raises
        ------
        AttributeError
            If the stored feature is neither an int nor a numeric string.
        """
        if isinstance(self.__feature, int):
            return self.__feature
        if isinstance(self.__feature, str) and self.__feature.isnumeric():
            return int(self.__feature)
        raise AttributeError(
            f"Feature names must be integers (got ({type(self.__feature)}){self.__feature})"
        )
266
+
267
+
268
class TreeView:
    """Treat a root ``Node`` (and everything below it) as a single tree."""

    def __init__(self, tree_id: int, root_node: Node) -> None:
        self.tree_id = tree_id
        self.root_node = root_node

    @property
    def is_leaf(self) -> bool:
        """True when the root itself is a leaf, i.e. the tree is a single node."""
        return self.root_node.is_leaf

    @property
    def value(self) -> float:
        """Response of a leaf-only tree."""
        if self.is_leaf:
            leaf_value = self.root_node.value
            if leaf_value is not None:
                return leaf_value
            raise AttributeError("Tree is leaf-only but leaf node has no value")
        raise AttributeError("Tree is not a leaf-only tree")

    @property
    def cover(self) -> float:
        """Cover of a leaf-only tree."""
        if self.is_leaf:
            return self.root_node.cover
        raise AttributeError("Tree is not a leaf-only tree")

    @property
    def n_nodes(self) -> int:
        """Total number of nodes: the root plus all of its descendants."""
        return 1 + self.root_node.n_children
296
+
297
+
298
class TreeList(list):
    """List of ``TreeView`` objects carrying everything the model builders need.

    Instances are meant to be created through the ``from_*`` constructors;
    item assignment is disabled.
    """

    @staticmethod
    def from_xgb_booster(
        booster, max_trees: int, feature_names_to_indices: dict[str, int]
    ) -> "TreeList":
        """
        Load a TreeList from an xgb.Booster object.

        Only the first ``max_trees`` trees are loaded when ``max_trees > 0``;
        otherwise all trees are loaded.

        Note: We cannot type-hint the xgb.Booster without loading xgb as a
        dependency in pyx code, therefore no type hint is added.
        """

        # Note: in XGBoost, it's possible to use 'int' type for features that contain
        # non-integer floating points. In such case, the training procedure and JSON
        # export from XGBoost will not treat them any differently from 'q'-type
        # (numeric) features, but the per-tree JSON text dumps used here will output
        # a split threshold rounded to the nearest integer for those 'int' features,
        # even if the booster internally has thresholds with decimal points and outputs
        # them as such in the full-model JSON dumps. Hence the need for this override
        # mechanism. If this behavior changes in XGBoost, then this conversion and
        # override can be removed.
        orig_feature_types = None
        try:
            if hasattr(booster, "feature_types"):
                feature_types = booster.feature_types
                orig_feature_types = deepcopy(feature_types)
                if feature_types:
                    # build a new list rather than mutating the booster's own one
                    booster.feature_types = [
                        "float" if ftype == "int" else ftype for ftype in feature_types
                    ]
            dump = booster.get_dump(dump_format="json", with_stats=True)
        finally:
            # restore the caller's feature types even if dumping failed
            if orig_feature_types:
                booster.feature_types = orig_feature_types

        if max_trees > 0:
            dump = dump[:max_trees]

        tl = TreeList()
        for tree_id, raw_tree in enumerate(dump):
            root_node = Node.from_xgb_dict(json.loads(raw_tree), feature_names_to_indices)
            tl.append(TreeView(tree_id=tree_id, root_node=root_node))

        return tl

    @staticmethod
    def from_lightgbm_booster_dump(dump: Dict[str, Any]) -> "TreeList":
        """
        Load a TreeList from a lgbm Booster dump (``Booster.dump_model()`` output).

        Note: We cannot type-hint the Model without loading lightgbm as a
        dependency in pyx code, therefore no type hint is added.
        """
        tl = TreeList()
        for tree_id, tree_dict in enumerate(dump["tree_info"]):
            root_node = Node.from_lightgbm_dict(tree_dict)
            tl.append(TreeView(tree_id=tree_id, root_node=root_node))

        return tl

    @staticmethod
    def from_treelite_dict(tl_json: Dict[str, Any]) -> "TreeList":
        """Load a TreeList from a Treelite JSON dump; node 0 of each tree is its root."""
        tl = TreeList()
        for tree_id, tree_dict in enumerate(tl_json["trees"]):
            root_node = Node.from_treelite_dict(tree_dict["nodes"], 0)
            tl.append(TreeView(tree_id=tree_id, root_node=root_node))
        return tl

    def __setitem__(self, key, value):
        # Python calls __setitem__(self, key, value); the previous one-argument
        # signature made ``tl[i] = x`` fail with a TypeError before this message
        # could ever be raised.
        raise NotImplementedError(
            "Use TreeList.from_*() methods to initialize a TreeList"
        )
372
+
373
+
374
def get_lightgbm_params(booster):
    """Return the LightGBM model description produced by ``booster.dump_model()``."""
    model_dump = booster.dump_model()
    return model_dump
376
+
377
+
378
def get_xgboost_params(booster):
    """Parse the JSON configuration returned by ``booster.save_config()``."""
    config_text = booster.save_config()
    parsed_config = json.loads(config_text)
    return parsed_config
380
+
381
+
382
def get_catboost_params(booster):
    """Return the CatBoost model dump as a parsed JSON dictionary.

    The model is written with ``booster.save_model(path, "json")`` into a
    private temporary directory and read back from there. The previous
    approach asked CatBoost to reopen an already-open ``NamedTemporaryFile``
    by name, which is not possible on Windows; here the file is created by
    CatBoost itself and removed together with the directory.
    """
    import os
    from tempfile import TemporaryDirectory

    with TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "catboost_model.json")
        booster.save_model(model_path, "json")
        # read bytes and let json.load detect the encoding itself
        with open(model_path, "rb") as fp:
            model_data = json.load(fp)
    return model_data
388
+
389
+
390
def get_gbt_model_from_tree_list(
    tree_list: TreeList,
    n_iterations: int,
    is_regression: bool,
    n_features: int,
    n_classes: int,
    base_score: Optional[Union[float, List[float]]] = None,
):
    """Return a GBT Model from TreeList.

    Parameters
    ----------
    tree_list : TreeList
        Trees to add, in the order they are handed to the model builder.
    n_iterations : int
        Number of boosting iterations; for classification, the class label
        given to ``create_tree`` is incremented after every ``n_iterations``
        trees.
    is_regression : bool
        Selects the regression or the classification model builder.
    n_features : int
        Number of input features.
    n_classes : int
        Number of classes (classification builder only).
    base_score : float or list of float, optional
        A scalar is passed to the builder's ``model()`` as the intercept.
        A list is not supported by the builder; instead ``base_score[k]`` is
        added to every leaf of the ``k``-th tree for the first ``n_classes``
        trees.
    """

    if is_regression:
        mb = gbt_reg_model_builder(n_features=n_features, n_iterations=n_iterations)
    else:
        mb = gbt_clf_model_builder(
            n_features=n_features, n_iterations=n_iterations, n_classes=n_classes
        )

    # Scalar intercepts are forwarded to the builder. Integer and numpy scalars
    # are accepted too; checking only for ``float`` would silently drop e.g. an
    # ``int`` or ``np.float32`` intercept.
    if isinstance(base_score, (int, float, np.floating)) and not isinstance(
        base_score, bool
    ):
        scalar_base_score = float(base_score)
    else:
        scalar_base_score = None

    class_label = 0
    for counter, tree in enumerate(tree_list, start=1):
        # find out the number of nodes in the tree
        if is_regression:
            tree_id = mb.create_tree(tree.n_nodes)
        else:
            tree_id = mb.create_tree(n_nodes=tree.n_nodes, class_label=class_label)

        # Note: starting from xgboost>=3.1.0, multi-class classification models have
        # vector-valued intercepts. Since oneDAL doesn't support these, it instead
        # adds the scores to all of the terminal leafs in the first tree.
        if isinstance(base_score, list) and counter <= n_classes:
            intercept_add = base_score[counter - 1]
        else:
            intercept_add = 0.0

        if counter % n_iterations == 0:
            class_label += 1

        if tree.is_leaf:
            mb.add_leaf(
                tree_id=tree_id, response=tree.value + intercept_add, cover=tree.cover
            )
            continue

        root_node = tree.root_node
        parent_id = mb.add_split(
            tree_id=tree_id,
            feature_index=root_node.feature,
            feature_value=root_node.get_value_closest_float_downward(),
            cover=root_node.cover,
            default_left=root_node.default_left,
        )

        # Breadth-first traversal: every queued node knows the builder id of
        # its parent and whether it is the left (0) or right (1) child.
        node_queue: Deque[Node] = deque()
        children = root_node.get_children()
        assert children is not None
        for position, child in enumerate(children):
            child.parent_id = parent_id
            child.position = position
            node_queue.append(child)

        while node_queue:
            node = node_queue.popleft()
            assert node.parent_id != -1, "node.parent_id must not be -1"
            assert node.position != -1, "node.position must not be -1"

            if node.is_leaf:
                mb.add_leaf(
                    tree_id=tree_id,
                    response=node.value + intercept_add,
                    cover=node.cover,
                    parent_id=node.parent_id,
                    position=node.position,
                )
            else:
                parent_id = mb.add_split(
                    tree_id=tree_id,
                    feature_index=node.feature,
                    feature_value=node.get_value_closest_float_downward(),
                    cover=node.cover,
                    default_left=node.default_left,
                    parent_id=node.parent_id,
                    position=node.position,
                )

                children = node.get_children()
                assert children is not None
                for position, child in enumerate(children):
                    child.parent_id = parent_id
                    child.position = position
                    node_queue.append(child)

    return mb.model(base_score=scalar_base_score)
482
+
483
+
484
def get_gbt_model_from_lightgbm(model: Any, booster=None) -> Any:
    """Convert a LightGBM model into a daal4py gradient boosted trees model.

    Parameters
    ----------
    model : LightGBM booster-like object
        Must provide ``model_to_string()``; its text is inspected to reject
        unsupported configurations. ``dump_model()`` is also called when
        ``booster`` is not given.
    booster : dict, optional
        Pre-computed ``model.dump_model()`` output.

    Raises
    ------
    TypeError
        For linear trees, 'dart' or random-forest boosting, ranking objectives,
        and one-vs-all multi-class objectives.
    """
    model_str = model.model_to_string()
    if "is_linear=1" in model_str:
        raise TypeError("Linear trees are not supported.")
    if "[boosting: dart]" in model_str:
        raise TypeError(
            "'Dart' booster is not supported. Try converting to 'treelite' first."
        )
    if "[boosting: rf]" in model_str:
        raise TypeError("Random forest boosters are not supported.")
    if ("[objective: lambdarank]" in model_str) or (
        "[objective: rank_xendcg]" in model_str
    ):
        raise TypeError("Ranking objectives are not supported.")

    if booster is None:
        booster = model.dump_model()

    n_features = booster["max_feature_idx"] + 1
    n_trees_per_iteration = booster["num_tree_per_iteration"]
    # The iteration count must be an integer: it is handed to the model
    # builder and used as a modulus when assigning class labels.
    n_iterations = len(booster["tree_info"]) // n_trees_per_iteration
    n_classes = n_trees_per_iteration

    is_regression = False
    objective_fun = booster["objective"]
    if n_classes > 2:
        if ("ova" in objective_fun) or ("ovr" in objective_fun):
            raise TypeError(
                "Only multiclass (softmax) objective is supported for multiclass classification"
            )
    elif "binary" in objective_fun:  # one tree per iteration for binary models
        n_classes = 2
    else:
        is_regression = True

    tree_list = TreeList.from_lightgbm_booster_dump(booster)

    return get_gbt_model_from_tree_list(
        tree_list,
        n_iterations=n_iterations,
        is_regression=is_regression,
        n_features=n_features,
        n_classes=n_classes,
    )
527
+
528
+
529
def get_gbt_model_from_xgboost(booster: Any, xgb_config=None) -> Any:
    """Convert an XGBoost booster into a daal4py gradient boosted trees model.

    Parameters
    ----------
    booster : XGBoost booster-like object
        Must provide ``feature_names``, ``num_features()`` and, when
        ``xgb_config`` is not given, ``save_config()``.
    xgb_config : dict, optional
        Parsed booster configuration (the JSON from ``save_config()``).

    Raises
    ------
    TypeError
        For non-'gbtree' boosters, multi-target boosters, ranking objectives,
        and unsupported multi-class or binary objectives.
    """
    # Note: in the absence of any feature names, XGBoost will generate
    # tree json dumps where features are named 'f0..N'. While the JSONs
    # of the whole model will have feature indices, the per-tree JSONs
    # used here always use string names instead, hence the need for this.
    feature_names = booster.feature_names
    if feature_names:
        feature_names_to_indices = {fname: ind for ind, fname in enumerate(feature_names)}
    else:
        feature_names_to_indices = {
            f"f{ind}": ind for ind in range(booster.num_features())
        }

    if xgb_config is None:
        xgb_config = get_xgboost_params(booster)

    learner = xgb_config["learner"]
    train_param = learner["learner_train_param"]

    if train_param["booster"] != "gbtree":
        raise TypeError(
            "Only 'gbtree' booster type is supported. For DART, try converting to 'treelite' first."
        )

    model_param = learner["learner_model_param"]
    n_targets = model_param.get("num_target")
    if n_targets is not None and int(n_targets) > 1:
        raise TypeError("Multi-target boosters are not supported.")

    n_features = int(model_param["num_feature"])
    n_classes = int(model_param["num_class"])
    # Note: base scores in XGBoost might be vector-valued starting from version 3.1.0.
    # When this is the case, the 'base_score' attribute will be a JSON list, otherwise
    # it will be a scalar. Note that in either case, it will be in the response scale.
    base_score_str: str = model_param["base_score"]
    if base_score_str.startswith("["):
        base_score = json.loads(base_score_str)
        if len(base_score) == 1:
            base_score = base_score[0]
        elif len(base_score) == 0:
            base_score = 0.5
    else:
        base_score = float(base_score_str)

    is_regression = False
    objective_fun = train_param["objective"]

    # Note: the base score from XGBoost is in the response scale, but the predictions
    # are calculated in the link scale, so when there is a non-identity link function,
    # it needs to be converted to the link scale.
    if objective_fun in ["count:poisson", "reg:gamma", "reg:tweedie", "survival:aft"]:
        base_score = float(np.log(base_score))
    elif objective_fun == "reg:logistic":
        base_score = float(np.log(base_score / (1 - base_score)))
    elif objective_fun.startswith("rank"):
        raise TypeError("Ranking objectives are not supported.")

    if n_classes > 2:
        if objective_fun not in ["multi:softprob", "multi:softmax"]:
            raise TypeError(
                "Only multi:softprob and multi:softmax are supported for multiclass classification"
            )
    elif objective_fun.startswith("binary:"):
        if objective_fun not in ["binary:logistic", "binary:logitraw"]:
            raise TypeError(
                "only binary:logistic and binary:logitraw are supported for binary classification"
            )
        n_classes = 2
        if objective_fun == "binary:logitraw":
            # daal4py always applies a sigmoid for pred_proba, whereas XGBoost
            # returns raw predictions with logitraw
            base_score = float(1 / (1 + np.exp(-base_score)))
    else:
        is_regression = True

    # max_trees=0 if best_iteration does not exist
    max_trees = getattr(booster, "best_iteration", -1) + 1
    if n_classes > 2:
        max_trees *= n_classes
    tree_list = TreeList.from_xgb_booster(booster, max_trees, feature_names_to_indices)

    if hasattr(booster, "best_iteration"):
        n_iterations = booster.best_iteration + 1
    else:
        n_iterations = len(tree_list) // (n_classes if n_classes > 2 else 1)

    return get_gbt_model_from_tree_list(
        tree_list,
        n_iterations=n_iterations,
        is_regression=is_regression,
        n_features=n_features,
        n_classes=n_classes,
        base_score=base_score,
    )
619
+
620
+
621
def __get_value_as_list(node):
    """Return ``node["value"]`` as-is if it is a list/tuple, else wrapped in a list."""
    node_value = node["value"]
    return node_value if isinstance(node_value, (list, tuple)) else [node_value]
628
+
629
+
630
def __calc_node_weights_from_leaf_weights(weights):
    """Aggregate leaf weights into the weights of every level above the leaves.

    Adjacent pairs are summed repeatedly (leaves -> parents -> ... -> root)
    until at most one value remains. The levels are returned root-first, so
    the last entry holds the weights of the leaves' parents. Every level that
    gets paired must have even length (e.g. ``2**depth`` leaf weights).
    """

    def _pairwise_sums(level):
        assert len(level) % 2 == 0, "Length of values must be even"
        return [left + right for left, right in zip(level[::2], level[1::2])]

    levels = []
    current = weights
    while True:
        current = _pairwise_sums(current)
        levels.append(current)
        if len(current) <= 1:
            break
    levels.reverse()
    return levels
641
+
642
+
643
+ def get_gbt_model_from_catboost(booster: Any) -> Any:
644
+ if not booster.is_fitted():
645
+ raise RuntimeError("Model should be fitted before exporting to daal4py.")
646
+
647
+ model = CatBoostModelData(get_catboost_params(booster))
648
+
649
+ if model.has_categorical_features:
650
+ raise NotImplementedError(
651
+ "Categorical features are not supported in daal4py Gradient Boosting Trees"
652
+ )
653
+
654
+ objective = booster.get_params().get("objective", "")
655
+ if (
656
+ "Rank" in objective
657
+ or "Query" in objective
658
+ or "Pair" in objective
659
+ or objective in ["LambdaMart", "StochasticFilter", "GroupQuantile"]
660
+ ):
661
+ raise TypeError("Ranking objectives are not supported.")
662
+ if "Multi" in objective and objective != "MultiClass":
663
+ if model.is_classification:
664
+ raise TypeError(
665
+ "Only 'MultiClass' loss is supported for multi-class classification."
666
+ )
667
+ else:
668
+ raise TypeError("Multi-output models are not supported.")
669
+
670
+ if model.is_classification:
671
+ mb = gbt_clf_model_builder(
672
+ n_features=model.n_features,
673
+ n_iterations=model.n_iterations,
674
+ n_classes=model.n_classes,
675
+ )
676
+ else:
677
+ mb = gbt_reg_model_builder(
678
+ n_features=model.n_features, n_iterations=model.n_iterations
679
+ )
680
+
681
+ # Create splits array (all splits are placed sequentially)
682
+ splits = []
683
+ for feature in model.float_features:
684
+ if feature["borders"]:
685
+ for feature_border in feature["borders"]:
686
+ splits.append(
687
+ {"feature_index": feature["feature_index"], "value": feature_border}
688
+ )
689
+
690
+ # Note: catboost models might have a 'bias' (intercept) which gets added
691
+ # to all predictions. In the case of single-output models, this is a scalar,
692
+ # but in the case of multi-output models such as multinomial logistic, it
693
+ # is a vector. Since daal4py doesn't support vector-valued intercepts, this
694
+ # adds the intercept to every terminal node instead, by dividing it equally
695
+ # among all trees. Usually, catboost would anyway set them to zero, but it
696
+ # still allows setting custom intercepts.
697
+ cb_bias = booster.get_scale_and_bias()[1]
698
+ add_intercept_to_each_node = isinstance(cb_bias, list)
699
+ if add_intercept_to_each_node:
700
+ cb_bias = np.array(cb_bias) / model.n_iterations
701
+ if not model.is_classification:
702
+ raise TypeError("Multi-output regression models are not supported.")
703
+
704
+ def add_vector_bias(values: list[float]) -> list[float]:
705
+ return list(np.array(values) + cb_bias)
706
+
707
+ trees_explicit = []
708
+ tree_symmetric = []
709
+
710
+ all_trees_are_empty = True
711
+
712
+ if model.is_symmetric_tree:
713
+ for tree in model.oblivious_trees:
714
+ tree_splits = tree.get("splits", [])
715
+ cur_tree_depth = len(tree_splits) if tree_splits is not None else 0
716
+ tree_symmetric.append((tree, cur_tree_depth))
717
+ else:
718
+ for tree in model.trees:
719
+ n_nodes = 1
720
+
721
+ # Check if node is a leaf (in case of stump)
722
+ if "split" in tree:
723
+ # Get number of trees and splits info via BFS
724
+ # Create queue
725
+ nodes_queue = []
726
+ root_node = CatBoostNode(split=splits[tree["split"]["split_index"]])
727
+ nodes_queue.append((tree, root_node))
728
+ while nodes_queue:
729
+ cur_node_data, cur_node = nodes_queue.pop(0)
730
+ if "value" in cur_node_data:
731
+ cur_node.value = __get_value_as_list(cur_node_data)
732
+ else:
733
+ cur_node.split = splits[cur_node_data["split"]["split_index"]]
734
+ left_node = CatBoostNode()
735
+ right_node = CatBoostNode()
736
+ cur_node.left = left_node
737
+ cur_node.right = right_node
738
+ nodes_queue.append((cur_node_data["left"], left_node))
739
+ nodes_queue.append((cur_node_data["right"], right_node))
740
+ n_nodes += 2
741
+ all_trees_are_empty = False
742
+ else:
743
+ root_node = CatBoostNode()
744
+ if model.is_classification and model.n_classes > 2:
745
+ root_node.value = [value * model.scale for value in tree["value"]]
746
+ if add_intercept_to_each_node:
747
+ root_node.value = add_vector_bias(root_node.value)
748
+ else:
749
+ root_node.value = [tree["value"] * model.scale]
750
+ trees_explicit.append((root_node, n_nodes))
751
+
752
+ tree_id = []
753
+ class_label = 0
754
+ count = 0
755
+
756
+ # Only 1 tree for each iteration in case of regression or binary classification
757
+ if not model.is_classification or model.n_classes == 2:
758
+ n_tree_each_iter = 1
759
+ else:
760
+ n_tree_each_iter = model.n_classes
761
+
762
+ shap_ready = False
763
+
764
+ # Create id for trees (for the right order in model builder)
765
+ for i in range(model.n_iterations):
766
+ for _ in range(n_tree_each_iter):
767
+ if model.is_symmetric_tree:
768
+ if not len(tree_symmetric):
769
+ n_nodes = 1
770
+ else:
771
+ n_nodes = 2 ** (tree_symmetric[i][1] + 1) - 1
772
+ else:
773
+ if not len(trees_explicit):
774
+ n_nodes = 1
775
+ else:
776
+ n_nodes = trees_explicit[i][1]
777
+
778
+ if model.is_classification and model.n_classes > 2:
779
+ tree_id.append(mb.create_tree(n_nodes, class_label))
780
+ count += 1
781
+ if count == model.n_iterations:
782
+ class_label += 1
783
+ count = 0
784
+
785
+ elif model.is_classification:
786
+ tree_id.append(mb.create_tree(n_nodes, 0))
787
+ else:
788
+ tree_id.append(mb.create_tree(n_nodes))
789
+
790
+ if model.is_symmetric_tree:
791
+ shap_ready = True # this code branch provides all info for SHAP values
792
+ for class_label in range(n_tree_each_iter):
793
+ for i in range(model.n_iterations):
794
+ cur_tree_info = tree_symmetric[i][0]
795
+ cur_tree_id = tree_id[i * n_tree_each_iter + class_label]
796
+ cur_tree_leaf_val = cur_tree_info["leaf_values"]
797
+ cur_tree_leaf_weights = cur_tree_info["leaf_weights"]
798
+ cur_tree_depth = tree_symmetric[i][1]
799
+ if cur_tree_depth == 0:
800
+ mb.add_leaf(
801
+ tree_id=cur_tree_id,
802
+ response=cur_tree_leaf_val[class_label] * model.scale
803
+ + (cb_bias[class_label] if add_intercept_to_each_node else 0),
804
+ cover=cur_tree_leaf_weights[0],
805
+ )
806
+ else:
807
+ # One split used for the whole level
808
+ cur_level_split = splits[
809
+ cur_tree_info["splits"][cur_tree_depth - 1]["split_index"]
810
+ ]
811
+ cur_tree_weights_per_level = __calc_node_weights_from_leaf_weights(
812
+ cur_tree_leaf_weights
813
+ )
814
+ root_weight = cur_tree_weights_per_level[0][0]
815
+
816
+ root_id = mb.add_split(
817
+ tree_id=cur_tree_id,
818
+ feature_index=cur_level_split["feature_index"],
819
+ feature_value=cur_level_split["value"],
820
+ default_left=model.default_left,
821
+ cover=root_weight,
822
+ )
823
+ prev_level_nodes = [root_id]
824
+
825
+ # Iterate over levels, splits in json are reversed (root split is the last)
826
+ for cur_level in range(cur_tree_depth - 2, -1, -1):
827
+ cur_level_nodes = []
828
+ next_level_weights = cur_tree_weights_per_level[cur_level + 1]
829
+ cur_level_node_index = 0
830
+ for cur_parent in prev_level_nodes:
831
+ cur_level_split = splits[
832
+ cur_tree_info["splits"][cur_level]["split_index"]
833
+ ]
834
+ cover_nodes = next_level_weights[cur_level_node_index]
835
+ if cover_nodes == 0:
836
+ shap_ready = False
837
+ cur_left_node = mb.add_split(
838
+ tree_id=cur_tree_id,
839
+ parent_id=cur_parent,
840
+ position=0,
841
+ feature_index=cur_level_split["feature_index"],
842
+ feature_value=cur_level_split["value"],
843
+ default_left=model.default_left,
844
+ cover=cover_nodes,
845
+ )
846
+ # cur_level_node_index += 1
847
+ cur_right_node = mb.add_split(
848
+ tree_id=cur_tree_id,
849
+ parent_id=cur_parent,
850
+ position=1,
851
+ feature_index=cur_level_split["feature_index"],
852
+ feature_value=cur_level_split["value"],
853
+ default_left=model.default_left,
854
+ cover=cover_nodes,
855
+ )
856
+ # cur_level_node_index += 1
857
+ cur_level_nodes.append(cur_left_node)
858
+ cur_level_nodes.append(cur_right_node)
859
+ prev_level_nodes = cur_level_nodes
860
+
861
+ # Different storing format for leaves
862
+ if not model.is_classification or model.n_classes == 2:
863
+ for last_level_node_num in range(len(prev_level_nodes)):
864
+ mb.add_leaf(
865
+ tree_id=cur_tree_id,
866
+ response=cur_tree_leaf_val[2 * last_level_node_num]
867
+ * model.scale,
868
+ parent_id=prev_level_nodes[last_level_node_num],
869
+ position=0,
870
+ cover=cur_tree_leaf_weights[2 * last_level_node_num],
871
+ )
872
+ mb.add_leaf(
873
+ tree_id=cur_tree_id,
874
+ response=cur_tree_leaf_val[2 * last_level_node_num + 1]
875
+ * model.scale,
876
+ parent_id=prev_level_nodes[last_level_node_num],
877
+ position=1,
878
+ cover=cur_tree_leaf_weights[2 * last_level_node_num + 1],
879
+ )
880
+ else:
881
+ shap_ready = False
882
+ for last_level_node_num in range(len(prev_level_nodes)):
883
+ left_index = (
884
+ 2 * last_level_node_num * n_tree_each_iter + class_label
885
+ )
886
+ right_index = (
887
+ 2 * last_level_node_num + 1
888
+ ) * n_tree_each_iter + class_label
889
+ mb.add_leaf(
890
+ tree_id=cur_tree_id,
891
+ response=cur_tree_leaf_val[left_index] * model.scale
892
+ + (
893
+ cb_bias[class_label]
894
+ if add_intercept_to_each_node
895
+ else 0
896
+ ),
897
+ parent_id=prev_level_nodes[last_level_node_num],
898
+ position=0,
899
+ cover=0.0,
900
+ )
901
+ mb.add_leaf(
902
+ tree_id=cur_tree_id,
903
+ response=cur_tree_leaf_val[right_index] * model.scale
904
+ + (
905
+ cb_bias[class_label]
906
+ if add_intercept_to_each_node
907
+ else 0
908
+ ),
909
+ parent_id=prev_level_nodes[last_level_node_num],
910
+ position=1,
911
+ cover=0.0,
912
+ )
913
+ else:
914
+ shap_ready = False
915
+ scale = booster.get_scale_and_bias()[0]
916
+ for class_label in range(n_tree_each_iter):
917
+ for i in range(model.n_iterations):
918
+ root_node = trees_explicit[i][0]
919
+
920
+ cur_tree_id = tree_id[i * n_tree_each_iter + class_label]
921
+ # Traverse tree via BFS and build tree with modelbuilder
922
+ if root_node.value is None:
923
+ root_id = mb.add_split(
924
+ tree_id=cur_tree_id,
925
+ feature_index=root_node.split["feature_index"],
926
+ feature_value=root_node.split["value"],
927
+ default_left=model.default_left,
928
+ cover=0.0,
929
+ )
930
+ nodes_queue = [(root_node, root_id)]
931
+ while nodes_queue:
932
+ cur_node, cur_node_id = nodes_queue.pop(0)
933
+ left_node = cur_node.left
934
+ # Check if node is a leaf
935
+ if left_node.value is None:
936
+ left_node_id = mb.add_split(
937
+ tree_id=cur_tree_id,
938
+ parent_id=cur_node_id,
939
+ position=0,
940
+ feature_index=left_node.split["feature_index"],
941
+ feature_value=left_node.split["value"],
942
+ default_left=model.default_left,
943
+ cover=0.0,
944
+ )
945
+ nodes_queue.append((left_node, left_node_id))
946
+ else:
947
+ mb.add_leaf(
948
+ tree_id=cur_tree_id,
949
+ response=scale * left_node.value[class_label]
950
+ + (
951
+ cb_bias[class_label]
952
+ if add_intercept_to_each_node
953
+ else 0
954
+ ),
955
+ parent_id=cur_node_id,
956
+ position=0,
957
+ cover=0.0,
958
+ )
959
+ right_node = cur_node.right
960
+ # Check if node is a leaf
961
+ if right_node.value is None:
962
+ right_node_id = mb.add_split(
963
+ tree_id=cur_tree_id,
964
+ parent_id=cur_node_id,
965
+ position=1,
966
+ feature_index=right_node.split["feature_index"],
967
+ feature_value=right_node.split["value"],
968
+ default_left=model.default_left,
969
+ cover=0.0,
970
+ )
971
+ nodes_queue.append((right_node, right_node_id))
972
+ else:
973
+ mb.add_leaf(
974
+ tree_id=cur_tree_id,
975
+ response=scale * cur_node.right.value[class_label]
976
+ + (
977
+ cb_bias[class_label]
978
+ if add_intercept_to_each_node
979
+ else 0
980
+ ),
981
+ parent_id=cur_node_id,
982
+ position=1,
983
+ cover=0.0,
984
+ )
985
+
986
+ else:
987
+ # Tree has only one node
988
+ # Note: the root node already has scale and bias added to it,
989
+ # so no need to add them again here like it is done for the leafs.
990
+ mb.add_leaf(
991
+ tree_id=cur_tree_id,
992
+ response=root_node.value[class_label],
993
+ cover=0.0,
994
+ )
995
+
996
+ if all_trees_are_empty and not model.is_symmetric_tree:
997
+ shap_ready = True
998
+
999
+ intercept = 0.0
1000
+ if not add_intercept_to_each_node:
1001
+ intercept = booster.get_scale_and_bias()[1]
1002
+ return mb.model(base_score=intercept), shap_ready
1003
+
1004
+
1005
def get_gbt_model_from_treelite(
    tl_model: "treelite.model.Model",
) -> tuple[Any, int, int, bool]:
    """Convert a treelite model into a daal4py gradient boosting model.

    The treelite model is dumped to JSON and the parsed dictionary is
    adjusted locally (averaging ensembles made additive, binary random-forest
    leaves collapsed, multi-class trees padded/reordered and intercepts folded
    into the leaves) before being handed to ``get_gbt_model_from_tree_list``.

    Parameters
    ----------
    tl_model : treelite.model.Model
        Model to convert. Only its JSON dump and ``num_tree`` are read.

    Returns
    -------
    tuple
        ``(gbt_model, num_class, num_feature, shap_ready)``. ``num_class`` is
        forced to 2 for binary classifiers. ``shap_ready`` is True only for
        models with at most 2 classes whose trees all have a truthy
        ``sum_hess`` at the root node, and which needed no padding trees.

    Raises
    ------
    TypeError
        If the model is of an unsupported type or structure: multi-target,
        one-vs-all multi-class, categorical splits, no trees, multi-class or
        multi-output random forests, random forests with base scores, or
        malformed base scores / class ids.
    """
    model_json = json.loads(tl_model.dump_as_json())
    task_type = model_json["task_type"]
    if task_type not in ("kBinaryClf", "kRegressor", "kMultiClf", "kIsolationForest"):
        raise TypeError(f"Model to convert is of unsupported type: {task_type}")
    if model_json["num_target"] > 1:
        raise TypeError("Multi-target models are not supported.")
    if model_json["postprocessor"] == "multiclass_ova":
        raise TypeError(
            "Multi-class classification models that use One-Vs-All are not supported."
        )
    if any(tree["has_categorical_split"] for tree in model_json["trees"]):
        raise TypeError("Models with categorical features are not supported.")
    num_trees = tl_model.num_tree
    if not num_trees:
        raise TypeError("Model to convert contains no trees.")

    # daal4py always adds up the per-tree scores, but some models (e.g. random
    # forests) average them instead. Dividing every leaf value by the number of
    # trees beforehand turns such an averaging ensemble into an additive one.
    if model_json["average_tree_output"]:
        divide_treelite_leaf_values_by_const(model_json, num_trees)

    base_score = model_json["base_scores"]
    num_class = model_json["num_class"][0]
    num_feature = model_json["num_feature"]

    if task_type == "kBinaryClf":
        num_class = 2
        if base_score:
            # Map the base score(s) through the logistic sigmoid
            # (raw margin -> probability space).
            base_score = list(1 / (1 + np.exp(-np.array(base_score))))

    # SHAP support needs at most 2 classes and a truthy hessian sum on every
    # tree's root node.
    shap_ready = num_class <= 2 and all(
        tree["nodes"][0].get("sum_hess", False) for tree in model_json["trees"]
    )

    # In the case of random forests for classification, it might work
    # by averaging predictions without any link function, whereas
    # daal4py assumes a logit link. In such case, it's not possible to
    # convert them to daal4py's logic, but the model can still be used
    # as a regressor that always outputs something between 0 and 1.
    is_regression = "Clf" not in task_type
    if not is_regression and model_json["postprocessor"] == "identity_multiclass":
        is_regression = True
        warnings.warn(
            "Attempting to convert classification model which is not"
            " based on gradient boosting. Will output a regression"
            " model instead."
        )

    looks_like_random_forest = (
        model_json["postprocessor"] == "identity_multiclass"
        and len(model_json["base_scores"]) > 1
        and task_type == "kMultiClf"
    )
    if looks_like_random_forest:
        if num_class > 2 or len(base_score) > 2:
            raise TypeError("Multi-class random forests are not supported.")
        if len(model_json["num_class"]) > 1:
            raise TypeError("Multi-output random forests are not supported.")
        if len(base_score) == 2 and base_score[0]:
            raise TypeError("Random forests with base scores are not supported.")

        # Binary random forests store leaf values for both classes, which is
        # redundant since they sum to 1. daal4py only needs the positive class.
        leave_only_last_treelite_leaf_value(model_json)
        base_score = base_score[-1]

    is_boosted_multiclass = task_type == "kMultiClf" and not looks_like_random_forest
    if is_boosted_multiclass:
        # daal4py expects xgboost's layout: one tree per class for each
        # boosting round. Treelite may store trees in a different order or
        # with an uneven number of trees per class, so pad and reorder.
        if _align_treelite_multiclass_trees(model_json, num_class):
            shap_ready = False
        num_trees = len(model_json["trees"])

        # daal4py only supports a scalar intercept, so (as done for catboost)
        # each class's base score is split evenly among that class's trees.
        add_intercept_to_treelite_leafs(model_json, base_score)
        base_score = None

    if isinstance(base_score, list):
        if len(base_score) == 1:
            base_score = base_score[0]
        else:
            raise TypeError("Model to convert is malformed.")

    tree_list = TreeList.from_treelite_dict(model_json)
    n_iterations = num_trees // num_class if is_boosted_multiclass else num_trees
    return (
        get_gbt_model_from_tree_list(
            tree_list,
            n_iterations=n_iterations,
            is_regression=is_regression,
            n_features=num_feature,
            n_classes=num_class,
            base_score=base_score,
        ),
        num_class,
        num_feature,
        shap_ready,
    )


def _align_treelite_multiclass_trees(model_json: dict[str, Any], num_class: int) -> bool:
    """Pad and reorder multi-class trees so tree ``i`` belongs to class ``i % num_class``.

    Classes with fewer trees than the largest class are padded with empty
    trees (via ``add_empty_tree_to_treelite_json``). Trees are then reordered
    so class ids cycle ``0, 1, ..., num_class - 1`` for every round. A stable
    sort keeps the relative order of trees within each class.
    ``model_json["trees"]`` and ``model_json["class_id"]`` are modified in place.

    Returns True when padding trees had to be added.

    Raises TypeError if any class id lies outside ``[0, num_class)``.
    """
    class_ids = np.asarray(model_json["class_id"], dtype=np.intp)
    if class_ids.size and (class_ids.min() < 0 or class_ids.max() >= num_class):
        raise TypeError("Model to convert is malformed.")

    # Unlike np.unique, bincount also counts classes that have no trees at
    # all, so position ``c`` always holds the tree count of class ``c``.
    trees_per_class = np.bincount(class_ids, minlength=num_class)
    max_trees = trees_per_class.max() if trees_per_class.size else 0
    padded = bool((trees_per_class != max_trees).any())
    for class_ind, n_missing in enumerate(max_trees - trees_per_class):
        for _ in range(n_missing):
            add_empty_tree_to_treelite_json(model_json, class_ind)

    tree_class_orders = np.asarray(model_json["class_id"])
    num_trees = len(model_json["trees"])
    assert num_trees % num_class == 0  # guaranteed by the padding above
    expected_order = np.tile(np.arange(num_class), num_trees // num_class)
    if not np.array_equal(tree_class_orders, expected_order):
        argsorted_class_indices = np.argsort(tree_class_orders, kind="stable")
        per_class_indices = np.split(argsorted_class_indices, num_class)
        correct_order = np.vstack(per_class_indices).reshape(-1, order="F")
        model_json["trees"] = [model_json["trees"][ix] for ix in correct_order]
        model_json["class_id"] = [model_json["class_id"][ix] for ix in correct_order]
    return padded
1151
+
1152
+
1153
def divide_treelite_leaf_values_by_const(
    tl_json: dict[str, Any], divisor: "int | float"
) -> None:
    """Divide every leaf value of a treelite JSON dump by ``divisor``, in place.

    A vector leaf (list or tuple) is replaced by a new list of element-wise
    quotients. A scalar leaf is divided with ``/=``. Nodes without a
    ``leaf_value`` key (split nodes) are left untouched.
    """
    every_node = (node for tree in tl_json["trees"] for node in tree["nodes"])
    for node in every_node:
        if "leaf_value" not in node:
            continue
        leaf = node["leaf_value"]
        if not isinstance(leaf, (list, tuple)):
            node["leaf_value"] /= divisor
            continue
        node["leaf_value"] = list(np.array(leaf) / divisor)
1163
+
1164
+
1165
def leave_only_last_treelite_leaf_value(tl_json: dict[str, Any]) -> None:
    """Replace every two-element leaf vector by its last element, in place.

    Binary random forests keep one leaf value per class, and the two values
    are redundant. Only the last (positive-class) value is kept.

    Raises
    ------
    TypeError
        If a leaf does not hold exactly two values. This is raised explicitly
        rather than through ``assert`` so the check still runs under
        ``python -O``.
    """
    for tree in tl_json["trees"]:
        for node in tree["nodes"]:
            if "leaf_value" not in node:
                continue
            leaf_value = node["leaf_value"]
            if np.ndim(leaf_value) != 1 or len(leaf_value) != 2:
                raise TypeError(
                    "Expected leaf values for exactly 2 classes in binary "
                    f"random forest, got: {leaf_value!r}"
                )
            node["leaf_value"] = leaf_value[-1]
1171
+
1172
+
1173
def add_intercept_to_treelite_leafs(
    tl_json: dict[str, Any], base_score: list[float]
) -> None:
    """Fold per-class base scores into the leaves of a treelite JSON dump.

    Every tree's share is ``base_score[class_id] / trees_per_class``, where
    ``trees_per_class = len(trees) / num_class[0]``. That share is added to
    each leaf of the tree, in place. Split nodes (no ``leaf_value``) are
    skipped.
    """
    trees = tl_json["trees"]
    trees_per_class = len(trees) / tl_json["num_class"][0]
    class_of_tree = tl_json["class_id"]
    for position, tree in enumerate(trees):
        share = base_score[class_of_tree[position]] / trees_per_class
        for node in tree["nodes"]:
            if "leaf_value" not in node:
                continue
            node["leaf_value"] += share
1182
+
1183
+
1184
def add_empty_tree_to_treelite_json(tl_json: dict[str, Any], class_add: int) -> None:
    """Append a single-leaf tree predicting 0.0 and assign it to ``class_add``.

    The leaf has zero data count and zero hessian sum. ``class_add`` is
    appended to ``tl_json["class_id"]`` and the new tree to
    ``tl_json["trees"]``, both in place. A fresh dict is built on every call,
    so added trees never share state.
    """
    only_leaf = dict(node_id=0, leaf_value=0.0, data_count=0, sum_hess=0.0)
    empty_tree = dict(num_nodes=1, has_categorical_split=False, nodes=[only_leaf])
    tl_json["class_id"].append(class_add)
    tl_json["trees"].append(empty_tree)