scikit-learn-intelex 2025.10.0__py313-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (267):
  1. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
  2. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
  3. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp313-win_amd64.pyd +0 -0
  4. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
  5. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
  6. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
  7. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
  8. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp313-win_amd64.pyd +0 -0
  9. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
  10. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
  11. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
  12. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
  13. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
  14. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
  15. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  16. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/__init__.py +19 -0
  17. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
  18. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  19. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
  20. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
  21. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
  22. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  23. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
  24. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
  25. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
  26. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
  27. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  28. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
  29. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  30. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
  31. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
  32. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
  33. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
  34. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  35. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
  36. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
  37. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
  38. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
  39. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
  40. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
  41. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
  42. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  43. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
  44. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  45. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
  46. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  47. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  48. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  49. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
  50. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
  51. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
  52. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
  53. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  54. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  55. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
  56. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
  57. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
  58. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
  59. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
  60. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
  61. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
  62. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
  63. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
  64. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp313-win_amd64.pyd +0 -0
  65. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp313-win_amd64.pyd +0 -0
  66. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
  67. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
  68. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
  69. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
  70. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
  71. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
  72. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
  73. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
  74. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
  75. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
  76. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
  77. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
  78. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
  79. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
  80. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
  81. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
  82. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
  83. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
  84. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
  85. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
  86. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
  87. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
  88. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
  89. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
  90. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
  91. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
  92. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
  93. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
  94. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
  95. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/__init__.py +20 -0
  96. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
  97. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
  98. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
  99. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
  100. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
  101. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
  102. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
  103. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
  104. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
  105. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
  106. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
  107. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
  108. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
  109. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
  110. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
  111. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
  112. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
  113. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
  114. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
  115. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
  116. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
  117. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
  118. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
  119. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
  120. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
  121. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
  122. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
  123. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
  124. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
  125. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
  126. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
  127. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
  128. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
  129. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
  130. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
  131. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
  132. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
  133. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
  134. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
  135. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__init__.py +69 -0
  136. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__main__.py +58 -0
  137. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
  138. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
  139. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
  140. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
  141. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  142. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
  143. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +338 -0
  144. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
  145. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
  146. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +20 -0
  147. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +199 -0
  148. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
  149. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +38 -0
  150. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
  151. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
  152. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
  153. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
  154. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
  155. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +19 -0
  156. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
  157. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
  158. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dispatcher.py +572 -0
  159. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +629 -0
  160. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -0
  161. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
  162. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
  163. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +29 -0
  164. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1799 -0
  165. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
  166. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/__main__.py +72 -0
  167. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +101 -0
  168. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
  169. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
  170. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
  171. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
  172. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
  173. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
  174. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
  175. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
  176. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
  177. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
  178. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
  179. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
  180. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +19 -0
  181. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +28 -0
  182. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
  183. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +23 -0
  184. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +20 -0
  185. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +20 -0
  186. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +39 -0
  187. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +21 -0
  188. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/split.py +20 -0
  189. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +34 -0
  190. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +27 -0
  191. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +189 -0
  192. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/common.py +313 -0
  193. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +189 -0
  194. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +167 -0
  195. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +170 -0
  196. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +82 -0
  197. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/__init__.py +17 -0
  198. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +19 -0
  199. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
  200. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +112 -0
  201. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  202. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
  203. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
  204. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +25 -0
  205. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
  206. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
  207. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  208. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
  209. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
  210. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
  211. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +26 -0
  212. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
  213. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
  214. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
  215. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
  216. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
  217. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
  218. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
  219. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
  220. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
  221. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
  222. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +23 -0
  223. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
  224. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
  225. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
  226. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
  227. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
  228. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +21 -0
  229. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
  230. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +24 -0
  231. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  232. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
  233. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
  234. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
  235. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +23 -0
  236. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
  237. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/__init__.py +29 -0
  238. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
  239. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +278 -0
  240. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +158 -0
  241. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svc.py +306 -0
  242. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svr.py +155 -0
  243. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +124 -0
  244. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
  245. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
  246. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
  247. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
  248. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +269 -0
  249. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
  250. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +48 -0
  251. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +418 -0
  252. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
  253. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
  254. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
  255. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
  256. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
  257. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
  258. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
  259. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
  260. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
  261. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
  262. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
  263. scikit_learn_intelex-2025.10.0.dist-info/LICENSE.txt +202 -0
  264. scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
  265. scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
  266. scikit_learn_intelex-2025.10.0.dist-info/WHEEL +5 -0
  267. scikit_learn_intelex-2025.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1799 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+ import warnings
19
+ from abc import ABC
20
+ from collections.abc import Iterable
21
+
22
+ import numpy as np
23
+ from scipy import sparse as sp
24
+ from sklearn.base import BaseEstimator, clone
25
+ from sklearn.ensemble import ExtraTreesClassifier as _sklearn_ExtraTreesClassifier
26
+ from sklearn.ensemble import ExtraTreesRegressor as _sklearn_ExtraTreesRegressor
27
+ from sklearn.ensemble import RandomForestClassifier as _sklearn_RandomForestClassifier
28
+ from sklearn.ensemble import RandomForestRegressor as _sklearn_RandomForestRegressor
29
+ from sklearn.ensemble._forest import ForestClassifier as _sklearn_ForestClassifier
30
+ from sklearn.ensemble._forest import ForestRegressor as _sklearn_ForestRegressor
31
+ from sklearn.ensemble._forest import _get_n_samples_bootstrap
32
+ from sklearn.exceptions import DataConversionWarning
33
+ from sklearn.metrics import accuracy_score, r2_score
34
+ from sklearn.tree import (
35
+ DecisionTreeClassifier,
36
+ DecisionTreeRegressor,
37
+ ExtraTreeClassifier,
38
+ ExtraTreeRegressor,
39
+ )
40
+ from sklearn.tree._tree import Tree
41
+ from sklearn.utils import check_random_state, deprecated
42
+ from sklearn.utils.validation import (
43
+ _check_sample_weight,
44
+ check_array,
45
+ check_is_fitted,
46
+ check_X_y,
47
+ )
48
+
49
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
50
+ from daal4py.sklearn._utils import (
51
+ check_tree_nodes,
52
+ daal_check_version,
53
+ sklearn_check_version,
54
+ )
55
+ from onedal._device_offload import support_input_format
56
+ from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
57
+ from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
58
+ from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
59
+ from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
60
+ from onedal.primitives import get_tree_state_cls, get_tree_state_reg
61
+ from onedal.utils.validation import _num_features, _num_samples
62
+ from sklearnex._utils import register_hyperparameters
63
+
64
+ from .._config import get_config
65
+ from .._device_offload import dispatch, wrap_output_data
66
+ from .._utils import PatchingConditionsChain
67
+ from ..base import oneDALEstimator
68
+ from ..utils._array_api import get_namespace
69
+ from ..utils.validation import check_n_features, validate_data
70
+
71
+ if sklearn_check_version("1.2"):
72
+ from sklearn.utils._param_validation import Interval
73
+ if sklearn_check_version("1.4"):
74
+ from daal4py.sklearn.utils import _assert_all_finite
75
+
76
+
77
+ class BaseForest(oneDALEstimator, ABC):
78
+ _onedal_factory = None
79
+
80
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
    """Fit the forest through the oneDAL estimator and store fitted state.

    Parameters
    ----------
    X : array-like
        Training data. Validated as a 2-D float64/float32 array unless the
        ``use_raw_input`` config flag is exactly ``True``.
    y : array-like
        Targets. A 1-D ``y`` is reshaped to a single column before the
        number of outputs is recorded.
    sample_weight : array-like, default=None
        Optional per-sample weights.
    queue : object, default=None
        Forwarded unchanged to the oneDAL estimator's ``fit``.

    Returns
    -------
    self
    """
    # NOTE(review): block nesting in this method was restored from a listing
    # that had lost its indentation -- confirm against the packaged file.
    use_raw_input = get_config().get("use_raw_input", False) is True
    xp, _ = get_namespace(X)
    if not use_raw_input:
        # Finiteness is deliberately not enforced here (ensure_all_finite=False).
        X, y = validate_data(
            self,
            X,
            y,
            multi_output=True,
            accept_sparse=False,
            dtype=[np.float64, np.float32],
            ensure_all_finite=False,
            ensure_2d=True,
        )

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

    if y.ndim == 2 and y.shape[1] == 1:
        warnings.warn(
            "A column-vector y was passed when a 1d array was"
            " expected. Please change the shape of y to "
            "(n_samples,), for example using ravel().",
            DataConversionWarning,
            stacklevel=2,
        )

    if y.ndim == 1:
        # reshape is necessary to preserve the data contiguity against vs
        # [:, np.newaxis] that does not.
        y = xp.reshape(y, (-1, 1))

    # y is 2-D at this point, so its shape unpacks to (samples, outputs).
    self._n_samples, self.n_outputs_ = y.shape

    if not use_raw_input:
        y, expanded_class_weight = self._validate_y_class_weight(y)

        # Fold class weights into the per-sample weights when both exist.
        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight
        if sample_weight is not None:
            # The weights are handed to the oneDAL estimator wrapped in a list.
            sample_weight = [sample_weight]
    else:
        # try catch needed for raw_inputs + array_api data where unlike
        # numpy the way to yield unique values is via `unique_values`
        # This should be removed when refactored for gpu zero-copy
        try:
            self.classes_ = xp.unique(y)
        except AttributeError:
            self.classes_ = xp.unique_values(y)
        self.n_classes_ = len(self.classes_)
    self.n_features_in_ = X.shape[1]

    onedal_params = {
        "n_estimators": self.n_estimators,
        "criterion": self.criterion,
        "max_depth": self.max_depth,
        "min_samples_split": self.min_samples_split,
        "min_samples_leaf": self.min_samples_leaf,
        "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
        "max_features": self._to_absolute_max_features(
            self.max_features, self.n_features_in_
        ),
        "max_leaf_nodes": self.max_leaf_nodes,
        "min_impurity_decrease": self.min_impurity_decrease,
        "bootstrap": self.bootstrap,
        "oob_score": self.oob_score,
        "n_jobs": self.n_jobs,
        "random_state": self.random_state,
        "verbose": self.verbose,
        "warm_start": self.warm_start,
        # OOB error metrics are only requested when oob_score is enabled.
        "error_metric_mode": self._err if self.oob_score else "none",
        "variable_importance_mode": "mdi",
        "class_weight": self.class_weight,
        "max_bins": self.max_bins,
        "min_bin_size": self.min_bin_size,
        "max_samples": self.max_samples,
    }

    onedal_params["min_impurity_split"] = None

    # Lazy evaluation of estimators_
    self._cached_estimators_ = None

    # Compute
    self._onedal_estimator = self._onedal_factory(**onedal_params)
    # The oneDAL estimator receives y flattened back to one dimension.
    self._onedal_estimator.fit(X, xp.reshape(y, (-1,)), sample_weight, queue=queue)

    self._save_attributes()

    # Decapsulate classes_ attributes
    if hasattr(self, "classes_") and self.n_outputs_ == 1:
        self.n_classes_ = (
            self.n_classes_[0]
            if isinstance(self.n_classes_, Iterable)
            else self.n_classes_
        )
        self.classes_ = (
            self.classes_[0]
            if isinstance(self.classes_[0], Iterable)
            else self.classes_
        )

    return self
186
+
187
+ def _save_attributes(self):
188
+ if self.oob_score:
189
+ self.oob_score_ = self._onedal_estimator.oob_score_
190
+ if hasattr(self._onedal_estimator, "oob_prediction_"):
191
+ self.oob_prediction_ = self._onedal_estimator.oob_prediction_
192
+ if hasattr(self._onedal_estimator, "oob_decision_function_"):
193
+ self.oob_decision_function_ = (
194
+ self._onedal_estimator.oob_decision_function_
195
+ )
196
+ if self.bootstrap:
197
+ self._n_samples_bootstrap = max(
198
+ round(
199
+ self._onedal_estimator.observations_per_tree_fraction
200
+ * self._n_samples
201
+ ),
202
+ 1,
203
+ )
204
+ else:
205
+ self._n_samples_bootstrap = None
206
+ self._validate_estimator()
207
+ return self
208
+
209
+ def _to_absolute_max_features(self, max_features, n_features):
210
+ if max_features is None:
211
+ return n_features
212
+ if isinstance(max_features, str):
213
+ if max_features == "auto":
214
+ if not sklearn_check_version("1.3"):
215
+ if sklearn_check_version("1.1"):
216
+ warnings.warn(
217
+ "`max_features='auto'` has been deprecated in 1.1 "
218
+ "and will be removed in 1.3. To keep the past behaviour, "
219
+ "explicitly set `max_features=1.0` or remove this "
220
+ "parameter as it is also the default value for "
221
+ "RandomForestRegressors and ExtraTreesRegressors.",
222
+ FutureWarning,
223
+ )
224
+ return (
225
+ max(1, int(np.sqrt(n_features)))
226
+ if isinstance(self, ForestClassifier)
227
+ else n_features
228
+ )
229
+ if max_features == "sqrt":
230
+ return max(1, int(np.sqrt(n_features)))
231
+ if max_features == "log2":
232
+ return max(1, int(np.log2(n_features)))
233
+ allowed_string_values = (
234
+ '"sqrt" or "log2"'
235
+ if sklearn_check_version("1.3")
236
+ else '"auto", "sqrt" or "log2"'
237
+ )
238
+ raise ValueError(
239
+ "Invalid value for max_features. Allowed string "
240
+ f"values are {allowed_string_values}."
241
+ )
242
+ if isinstance(max_features, (numbers.Integral, np.integer)):
243
+ return max_features
244
+ if max_features > 0.0:
245
+ return max(1, int(max_features * n_features))
246
+ return 0
247
+
248
+ def _check_parameters(self):
249
+ if isinstance(self.min_samples_leaf, numbers.Integral):
250
+ if not 1 <= self.min_samples_leaf:
251
+ raise ValueError(
252
+ "min_samples_leaf must be at least 1 "
253
+ "or in (0, 0.5], got %s" % self.min_samples_leaf
254
+ )
255
+ else: # float
256
+ if not 0.0 < self.min_samples_leaf <= 0.5:
257
+ raise ValueError(
258
+ "min_samples_leaf must be at least 1 "
259
+ "or in (0, 0.5], got %s" % self.min_samples_leaf
260
+ )
261
+ if isinstance(self.min_samples_split, numbers.Integral):
262
+ if not 2 <= self.min_samples_split:
263
+ raise ValueError(
264
+ "min_samples_split must be an integer "
265
+ "greater than 1 or a float in (0.0, 1.0]; "
266
+ "got the integer %s" % self.min_samples_split
267
+ )
268
+ else: # float
269
+ if not 0.0 < self.min_samples_split <= 1.0:
270
+ raise ValueError(
271
+ "min_samples_split must be an integer "
272
+ "greater than 1 or a float in (0.0, 1.0]; "
273
+ "got the float %s" % self.min_samples_split
274
+ )
275
+ if not 0 <= self.min_weight_fraction_leaf <= 0.5:
276
+ raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
277
+ if hasattr(self, "min_impurity_split"):
278
+ warnings.warn(
279
+ "The min_impurity_split parameter is deprecated. "
280
+ "Its default value has changed from 1e-7 to 0 in "
281
+ "version 0.23, and it will be removed in 0.25. "
282
+ "Use the min_impurity_decrease parameter instead.",
283
+ FutureWarning,
284
+ )
285
+
286
+ if getattr(self, "min_impurity_split") < 0.0:
287
+ raise ValueError(
288
+ "min_impurity_split must be greater than " "or equal to 0"
289
+ )
290
+ if self.min_impurity_decrease < 0.0:
291
+ raise ValueError(
292
+ "min_impurity_decrease must be greater than " "or equal to 0"
293
+ )
294
+ if self.max_leaf_nodes is not None:
295
+ if not isinstance(self.max_leaf_nodes, numbers.Integral):
296
+ raise ValueError(
297
+ "max_leaf_nodes must be integral number but was "
298
+ "%r" % self.max_leaf_nodes
299
+ )
300
+ if self.max_leaf_nodes < 2:
301
+ raise ValueError(
302
+ ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
303
+ self.max_leaf_nodes
304
+ )
305
+ )
306
+ if isinstance(self.max_bins, numbers.Integral):
307
+ if not 2 <= self.max_bins:
308
+ raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
309
+ else:
310
+ raise ValueError(
311
+ "max_bins must be integral number but was " "%r" % self.max_bins
312
+ )
313
+ if isinstance(self.min_bin_size, numbers.Integral):
314
+ if not 1 <= self.min_bin_size:
315
+ raise ValueError(
316
+ "min_bin_size must be at least 1, got %s" % self.min_bin_size
317
+ )
318
+ else:
319
+ raise ValueError(
320
+ "min_bin_size must be integral number but was " "%r" % self.min_bin_size
321
+ )
322
+
323
@property
def estimators_(self):
    """Per-tree estimators, built on first access after a oneDAL fit.

    Raises
    ------
    AttributeError
        If no cache slot exists yet (``_cached_estimators_`` was never set).
    """
    if not hasattr(self, "_cached_estimators_"):
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute 'estimators_'"
        )
    # A ``None`` cache means "not converted yet": build the trees now.
    if self._cached_estimators_ is None:
        self._estimators_()
    return self._cached_estimators_


@estimators_.setter
def estimators_(self, estimators):
    # Needed to allow for proper sklearn operation in fallback mode
    self._cached_estimators_ = estimators
338
+
339
def _estimators_(self):
    """Convert the fitted oneDAL model into scikit-learn tree estimators.

    One clone of ``self.estimator``'s class is created per tree, given a
    seed drawn from ``self.random_state``, and populated with a ``Tree``
    whose state comes from ``self._get_tree_state``. The resulting list is
    stored in ``self._cached_estimators_``.
    """
    # _estimators_ should only be called if _onedal_estimator exists
    check_is_fitted(self, "_onedal_estimator")
    backend = self._onedal_estimator

    # Regressors have no n_classes_; they use a single "class".
    n_classes_ = 1
    if hasattr(self, "n_classes_"):
        n_classes_ = self.n_classes_
        if not isinstance(n_classes_, int):
            n_classes_ = n_classes_[0]

    # convert model to estimators
    template = self.estimator.__class__(
        criterion=backend.criterion,
        max_depth=backend.max_depth,
        min_samples_split=backend.min_samples_split,
        min_samples_leaf=backend.min_samples_leaf,
        min_weight_fraction_leaf=backend.min_weight_fraction_leaf,
        max_features=backend.max_features,
        max_leaf_nodes=backend.max_leaf_nodes,
        min_impurity_decrease=backend.min_impurity_decrease,
        random_state=None,
    )

    # we need to set est.tree_ field with Trees constructed from
    # oneAPI Data Analytics Library solution
    rng = check_random_state(self.random_state)
    converted = []

    for tree_idx in range(backend.n_estimators):
        member = clone(template)
        member.set_params(random_state=rng.randint(np.iinfo(np.int32).max))
        member.n_features_in_ = self.n_features_in_
        member.n_outputs_ = self.n_outputs_
        member.n_classes_ = n_classes_

        tree_state = self._get_tree_state(backend._onedal_model, tree_idx, n_classes_)
        state_dict = {
            "max_depth": tree_state.max_depth,
            "node_count": tree_state.node_count,
            "nodes": check_tree_nodes(tree_state.node_ar),
            "values": tree_state.value_ar,
        }
        # Note: only on host.
        member.tree_ = Tree(
            self.n_features_in_,
            np.array([n_classes_], dtype=np.intp),
            self.n_outputs_,
        )
        member.tree_.__setstate__(state_dict)
        converted.append(member)

    self._cached_estimators_ = converted
397
+
398
# Only defined when running against scikit-learn older than 1.2.
# NOTE(review): presumably a compatibility alias for the pre-1.2 parameter
# name -- confirm against the scikit-learn changelog.
if not sklearn_check_version("1.2"):

    @property
    def base_estimator(self):
        # Read-through alias for ``self.estimator``.
        return self.estimator

    @base_estimator.setter
    def base_estimator(self, estimator):
        # Write-through alias: the value is stored on ``self.estimator``.
        self.estimator = estimator
407
+
408
+
409
+ class ForestClassifier(BaseForest, _sklearn_ForestClassifier):
410
+ # Surprisingly, even though scikit-learn warns against using
411
+ # their ForestClassifier directly, it actually has a more stable
412
+ # API than the user-facing objects (over time). If they change it
413
+ # significantly at some point then this may need to be versioned.
414
+
415
+ _err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
416
+ _get_tree_state = staticmethod(get_tree_state_cls)
417
+
418
def __init__(
    self,
    estimator,
    n_estimators=100,
    *,
    estimator_params=tuple(),
    bootstrap=False,
    oob_score=False,
    n_jobs=None,
    random_state=None,
    verbose=0,
    warm_start=False,
    class_weight=None,
    max_samples=None,
):
    """Initialise the forest and pick the matching oneDAL factory.

    All arguments are forwarded unchanged to the scikit-learn parent
    constructor. Afterwards ``_onedal_factory`` is aligned with the class of
    ``estimator``.

    Raises
    ------
    TypeError
        If no oneDAL factory is set and none can be derived from
        ``estimator``.
    """
    super().__init__(
        estimator,
        n_estimators=n_estimators,
        estimator_params=estimator_params,
        bootstrap=bootstrap,
        oob_score=oob_score,
        n_jobs=n_jobs,
        random_state=random_state,
        verbose=verbose,
        warm_start=warm_start,
        class_weight=class_weight,
        max_samples=max_samples,
    )

    # The estimator is checked against the class attribute for conformance.
    # This should only trigger if the user uses this class directly.
    tree_cls = self.estimator.__class__
    if tree_cls == DecisionTreeClassifier:
        expected_factory = onedal_RandomForestClassifier
    elif tree_cls == ExtraTreeClassifier:
        expected_factory = onedal_ExtraTreesClassifier
    else:
        expected_factory = None

    # Guard against an unset factory: issubclass(None, ...) raises TypeError,
    # which previously broke exactly the direct-use case handled here.
    current_factory = self._onedal_factory
    if expected_factory is not None and (
        current_factory is None or not issubclass(current_factory, expected_factory)
    ):
        self._onedal_factory = expected_factory

    if self._onedal_factory is None:
        raise TypeError(" oneDAL estimator has not been set.")
460
+
461
+ decision_path = support_input_format(_sklearn_ForestClassifier.decision_path)
462
+ apply = support_input_format(_sklearn_ForestClassifier.apply)
463
+
464
def _estimators_(self):
    """Build the cached trees, then attach the forest's classes to each."""
    super()._estimators_()
    for tree_estimator in self._cached_estimators_:
        setattr(tree_estimator, "classes_", self.classes_)
468
+
469
def fit(self, X, y, sample_weight=None):
    # No docstring here: ``fit.__doc__`` is assigned from scikit-learn below.
    backends = {
        "onedal": self.__class__._onedal_fit,
        "sklearn": _sklearn_ForestClassifier.fit,
    }
    dispatch(self, "fit", backends, X, y, sample_weight)
    return self
482
+
483
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
    """Validate fit inputs and record whether oneDAL can handle this fit.

    Parameters
    ----------
    patching_status : PatchingConditionsChain
        Condition chain that the checks below are appended to.
    X, y : array-like
        Training data and targets; converted by ``check_X_y`` only while the
        chain still reports a positive status.
    sample_weight : array-like or None
        Returned unchanged.

    Returns
    -------
    tuple
        ``(patching_status, X, y, sample_weight)``.

    Raises
    ------
    ValueError
        For sparse ``y``, for ``oob_score`` without ``bootstrap``, and for
        ``max_samples`` set while ``bootstrap`` is disabled.
    """
    # NOTE(review): block nesting in this method was restored from a listing
    # that had lost its indentation -- confirm against the packaged file.
    if sp.issparse(y):
        raise ValueError("sparse multilabel-indicator for y is not supported.")

    # Hyper-parameter validation: scikit-learn's own on 1.2+, local otherwise.
    if sklearn_check_version("1.2"):
        self._validate_params()
    else:
        self._check_parameters()

    if not self.bootstrap and self.oob_score:
        raise ValueError("Out of bag estimation only available" " if bootstrap=True")

    patching_status.and_conditions(
        [
            (
                self.oob_score
                and daal_check_version((2021, "P", 500))
                or not self.oob_score,
                "OOB score is only supported starting from 2021.5 version of oneDAL.",
            ),
            (self.warm_start is False, "Warm start is not supported."),
            (
                self.criterion == "gini",
                f"'{self.criterion}' criterion is not supported. "
                "Only 'gini' criterion is supported.",
            ),
            (
                self.ccp_alpha == 0.0,
                f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
            ),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            (
                self.n_estimators <= 6024,
                "More than 6024 estimators is not supported.",
            ),
        ]
    )

    if self.bootstrap:
        patching_status.and_conditions(
            [
                (
                    self.class_weight != "balanced_subsample",
                    "'balanced_subsample' for class_weight is not supported",
                )
            ]
        )

    if patching_status.get_status() and sklearn_check_version("1.4"):
        # Translate the finiteness assertion into a patching condition
        # instead of letting the ValueError propagate.
        try:
            _assert_all_finite(X)
            input_is_finite = True
        except ValueError:
            input_is_finite = False
        patching_status.and_conditions(
            [
                (input_is_finite, "Non-finite input is not supported."),
                (
                    self.monotonic_cst is None,
                    "Monotonicity constraints are not supported.",
                ),
            ]
        )

    if patching_status.get_status():
        # The finiteness keyword differs by scikit-learn version.
        if sklearn_check_version("1.6"):
            X, y = check_X_y(
                X,
                y,
                multi_output=True,
                accept_sparse=True,
                dtype=[np.float64, np.float32],
                ensure_all_finite=False,
            )
        else:
            X, y = check_X_y(
                X,
                y,
                multi_output=True,
                accept_sparse=True,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
            )

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        if y.ndim == 1:
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]

        patching_status.and_conditions(
            [
                (
                    self.n_outputs_ == 1,
                    f"Number of outputs ({self.n_outputs_}) is not 1.",
                ),
                (
                    y.dtype in [np.float32, np.float64, np.int32, np.int64],
                    f"Datatype ({y.dtype}) for y is not supported.",
                ),
            ]
        )
        # TODO: Fix to support integers as input

        if self.n_outputs_ == 1:
            # Count distinct labels with whichever namespace owns y.
            xp, is_array_api_compliant = get_namespace(y)
            sety = xp.unique_values(y) if is_array_api_compliant else np.unique(y)
            num_classes = sety.shape[0]
            patching_status.and_conditions(
                [
                    (
                        num_classes >= 2,
                        "Number of classes must be at least 2.",
                    ),
                ]
            )

        # Return value is discarded; presumably called only so scikit-learn
        # validates max_samples -- confirm.
        _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )

        if (
            patching_status.get_status()
            and (self.random_state is not None)
            and (not daal_check_version((2024, "P", 0)))
        ):
            warnings.warn(
                "Setting 'random_state' value is not supported. "
                "State set by oneDAL to default value (777).",
                RuntimeWarning,
            )

    return patching_status, X, y, sample_weight
629
+
630
@wrap_output_data
def predict(self, X):
    # No docstring here: ``predict.__doc__`` is assigned from scikit-learn below.
    check_is_fitted(self)
    backends = {
        "onedal": self.__class__._onedal_predict,
        "sklearn": _sklearn_ForestClassifier.predict,
    }
    return dispatch(self, "predict", backends, X)
642
+
643
@wrap_output_data
def predict_proba(self, X):
    # No docstring here: ``predict_proba.__doc__`` is assigned from
    # scikit-learn below.
    # TODO:
    # _check_proba()
    # self._check_proba()
    check_is_fitted(self)
    backends = {
        "onedal": self.__class__._onedal_predict_proba,
        "sklearn": _sklearn_ForestClassifier.predict_proba,
    }
    return dispatch(self, "predict_proba", backends, X)
658
+
659
def predict_log_proba(self, X):
    # No docstring here: ``predict_log_proba.__doc__`` is assigned from
    # scikit-learn below.
    xp, _ = get_namespace(X)
    proba = self.predict_proba(X)

    if self.n_outputs_ != 1:
        # Multi-output: take the log of each output's probabilities in place.
        for k in range(self.n_outputs_):
            proba[k] = xp.log(proba[k])
        return proba

    return xp.log(proba)
671
+
672
@wrap_output_data
def score(self, X, y, sample_weight=None):
    # No docstring here: ``score.__doc__`` is assigned from scikit-learn below.
    check_is_fitted(self)
    backends = {
        "onedal": self.__class__._onedal_score,
        "sklearn": _sklearn_ForestClassifier.score,
    }
    return dispatch(self, "score", backends, X, y, sample_weight=sample_weight)
686
+
687
+ fit.__doc__ = _sklearn_ForestClassifier.fit.__doc__
688
+ predict.__doc__ = _sklearn_ForestClassifier.predict.__doc__
689
+ predict_proba.__doc__ = _sklearn_ForestClassifier.predict_proba.__doc__
690
+ predict_log_proba.__doc__ = _sklearn_ForestClassifier.predict_log_proba.__doc__
691
+ score.__doc__ = _sklearn_ForestClassifier.score.__doc__
692
+
693
def _onedal_cpu_supported(self, method_name, *data):
    """Build the condition chain deciding whether oneDAL runs on CPU.

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"``, ``"predict_proba"`` or ``"score"``.
    *data
        Positional arguments of the dispatched method.

    Returns
    -------
    PatchingConditionsChain
        The chain with every applicable condition recorded.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the supported names.
    """
    owner = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{owner}.{method_name}"
    )

    inference_methods = ["predict", "predict_proba", "score"]
    if method_name != "fit" and method_name not in inference_methods:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    if method_name == "fit":
        patching_status, _, _, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        fit_conditions = [
            (
                daal_check_version((2023, "P", 200))
                or self.estimator.__class__ == DecisionTreeClassifier,
                "ExtraTrees only supported starting from oneDAL version 2023.2",
            ),
            (
                not sp.issparse(sample_weight),
                "sample_weight is sparse. " "Sparse input is not supported.",
            ),
        ]
        patching_status.and_conditions(fit_conditions)
        return patching_status

    # Inference path: predict / predict_proba / score.
    X = data[0]

    # NOTE(review): the version gate below checks 2023.1 while its message
    # says 2023.2 -- confirm which one is intended.
    inference_conditions = [
        (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
        (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
        (self.warm_start is False, "Warm start is not supported."),
        (
            daal_check_version((2023, "P", 100))
            or self.estimator.__class__ == DecisionTreeClassifier,
            "ExtraTrees only supported starting from oneDAL version 2023.2",
        ),
    ]
    patching_status.and_conditions(inference_conditions)

    if method_name == "predict_proba":
        patching_status.and_conditions(
            [
                (
                    daal_check_version((2021, "P", 400)),
                    "oneDAL version is lower than 2021.4.",
                )
            ]
        )

    if hasattr(self, "n_outputs_"):
        patching_status.and_conditions(
            [
                (
                    self.n_outputs_ == 1,
                    f"Number of outputs ({self.n_outputs_}) is not 1.",
                ),
            ]
        )

    return patching_status
760
+
761
def _onedal_gpu_supported(self, method_name, *data):
    """Build the condition chain deciding whether oneDAL runs on GPU.

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"``, ``"predict_proba"`` or ``"score"``.
    *data
        Positional arguments of the dispatched method.

    Returns
    -------
    PatchingConditionsChain
        The chain with every applicable condition recorded.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the supported names.
    """
    owner = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{owner}.{method_name}"
    )

    inference_methods = ["predict", "predict_proba", "score"]
    if method_name != "fit" and method_name not in inference_methods:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    if method_name == "fit":
        patching_status, _, _, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        fit_conditions = [
            (
                daal_check_version((2023, "P", 100))
                or self.estimator.__class__ == DecisionTreeClassifier,
                "ExtraTrees only supported starting from oneDAL version 2023.1",
            ),
            (
                not self.oob_score,
                "oob_scores using r2 or accuracy not implemented.",
            ),
            (sample_weight is None, "sample_weight is not supported."),
        ]
        patching_status.and_conditions(fit_conditions)
        return patching_status

    # Inference path: predict / predict_proba / score.
    X = data[0]

    inference_conditions = [
        (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained"),
        (
            not sp.issparse(X),
            "X is sparse. Sparse input is not supported.",
        ),
        (self.warm_start is False, "Warm start is not supported."),
        (
            daal_check_version((2023, "P", 100)),
            "ExtraTrees supported starting from oneDAL version 2023.1",
        ),
    ]
    patching_status.and_conditions(inference_conditions)

    if hasattr(self, "n_outputs_"):
        patching_status.and_conditions(
            [
                (
                    self.n_outputs_ == 1,
                    f"Number of outputs ({self.n_outputs_}) is not 1.",
                ),
            ]
        )

    return patching_status
820
+
821
def _onedal_predict(self, X, queue=None):
    """Predict class labels with the fitted oneDAL estimator.

    Parameters
    ----------
    X : array-like
        Samples to classify. Validated (2-D, float64/float32, feature count)
        unless the ``use_raw_input`` config flag is exactly ``True``.
    queue : object, default=None
        Forwarded unchanged to the oneDAL estimator's ``predict``.

    Returns
    -------
    array
        Entries of ``self.classes_`` selected by the predicted indices.

    Raises
    ------
    ValueError
        If the number of features in ``X`` differs from ``n_features_in_``.
    """
    # NOTE(review): block nesting in this method was restored from a listing
    # that had lost its indentation -- confirm against the packaged file.
    xp, _ = get_namespace(X)

    # Read the flag the same way `_onedal_fit` and `_onedal_predict_proba`
    # do, so fit and predict agree on when validation is skipped and a
    # missing key cannot raise KeyError.
    use_raw_input = get_config().get("use_raw_input", False) is True
    if not use_raw_input:
        X = validate_data(
            self,
            X,
            dtype=[np.float64, np.float32],
            ensure_all_finite=False,
            reset=False,
            ensure_2d=True,
        )
        if hasattr(self, "n_features_in_"):
            try:
                num_features = _num_features(X)
            except TypeError:
                num_features = _num_samples(X)
            if num_features != self.n_features_in_:
                raise ValueError(
                    (
                        f"X has {num_features} features, "
                        f"but {self.__class__.__name__} is expecting "
                        f"{self.n_features_in_} features as input"
                    )
                )
        check_n_features(self, X, reset=False)

    res = self._onedal_estimator.predict(X, queue=queue)
    try:
        # Used when the result exposes ``sycl_queue``; an AttributeError
        # selects the NumPy fallback below.
        return xp.take(
            xp.asarray(self.classes_, device=res.sycl_queue),
            xp.astype(xp.reshape(res, (-1,)), xp.int64),
        )
    except AttributeError:
        return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
855
+
856
def _onedal_predict_proba(self, X, queue=None):
    """Return class probabilities from the fitted oneDAL estimator.

    ``X`` is validated as a 2-D float64/float32 array unless the
    ``use_raw_input`` config flag is exactly ``True``; ``queue`` is
    forwarded unchanged.
    """
    if get_config().get("use_raw_input", False) is not True:
        X = validate_data(
            self,
            X,
            dtype=[np.float64, np.float32],
            ensure_all_finite=False,
            reset=False,
            ensure_2d=True,
        )

    backend = self._onedal_estimator
    return backend.predict_proba(X, queue=queue)
869
+
870
def _onedal_score(self, X, y, sample_weight=None, queue=None):
    """Return the accuracy of oneDAL predictions for ``X`` against ``y``."""
    predicted = self._onedal_predict(X, queue=queue)
    return accuracy_score(y, predicted, sample_weight=sample_weight)
874
+
875
+
876
+ class ForestRegressor(BaseForest, _sklearn_ForestRegressor):
877
+ _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
878
+ _get_tree_state = staticmethod(get_tree_state_reg)
879
+
880
def __init__(
    self,
    estimator,
    n_estimators=100,
    *,
    estimator_params=tuple(),
    bootstrap=False,
    oob_score=False,
    n_jobs=None,
    random_state=None,
    verbose=0,
    warm_start=False,
    max_samples=None,
):
    """Initialise the forest and pick the matching oneDAL factory.

    All arguments are forwarded unchanged to the scikit-learn parent
    constructor. Afterwards ``_onedal_factory`` is aligned with the class of
    ``estimator``.

    Raises
    ------
    TypeError
        If no oneDAL factory is set and none can be derived from
        ``estimator``.
    """
    super().__init__(
        estimator,
        n_estimators=n_estimators,
        estimator_params=estimator_params,
        bootstrap=bootstrap,
        oob_score=oob_score,
        n_jobs=n_jobs,
        random_state=random_state,
        verbose=verbose,
        warm_start=warm_start,
        max_samples=max_samples,
    )

    # The splitter is checked against the class attribute for conformance
    # This should only trigger if the user uses this class directly.
    tree_cls = self.estimator.__class__
    if tree_cls == DecisionTreeRegressor:
        expected_factory = onedal_RandomForestRegressor
    elif tree_cls == ExtraTreeRegressor:
        expected_factory = onedal_ExtraTreesRegressor
    else:
        expected_factory = None

    # Guard against an unset factory: issubclass(None, ...) raises TypeError,
    # which previously broke exactly the direct-use case handled here.
    current_factory = self._onedal_factory
    if expected_factory is not None and (
        current_factory is None or not issubclass(current_factory, expected_factory)
    ):
        self._onedal_factory = expected_factory

    if self._onedal_factory is None:
        raise TypeError(" oneDAL estimator has not been set.")
920
+
921
+ decision_path = support_input_format(_sklearn_ForestRegressor.decision_path)
922
+ apply = support_input_format(_sklearn_ForestRegressor.apply)
923
+
924
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
    """Validate fit inputs and record whether oneDAL can handle this fit.

    Parameters
    ----------
    patching_status : PatchingConditionsChain
        Condition chain that the checks below are appended to.
    X, y : array-like
        Training data and targets; converted by ``check_X_y`` only while the
        chain still reports a positive status.
    sample_weight : array-like or None
        Returned unchanged.

    Returns
    -------
    tuple
        ``(patching_status, X, y, sample_weight)``.

    Raises
    ------
    ValueError
        For sparse ``y``, for ``oob_score`` without ``bootstrap``, and for
        ``max_samples`` set while ``bootstrap`` is disabled.
    """
    # NOTE(review): block nesting in this method was restored from a listing
    # that had lost its indentation -- confirm against the packaged file.
    if sp.issparse(y):
        raise ValueError("sparse multilabel-indicator for y is not supported.")

    # Hyper-parameter validation: scikit-learn's own on 1.2+, local otherwise.
    if sklearn_check_version("1.2"):
        self._validate_params()
    else:
        self._check_parameters()

    if not self.bootstrap and self.oob_score:
        raise ValueError("Out of bag estimation only available" " if bootstrap=True")

    if not sklearn_check_version("1.2") and self.criterion == "mse":
        warnings.warn(
            "Criterion 'mse' was deprecated in v1.0 and will be "
            "removed in version 1.2. Use `criterion='squared_error'` "
            "which is equivalent.",
            FutureWarning,
        )

    patching_status.and_conditions(
        [
            (
                self.oob_score
                and daal_check_version((2021, "P", 500))
                or not self.oob_score,
                "OOB score is only supported starting from 2021.5 version of oneDAL.",
            ),
            (self.warm_start is False, "Warm start is not supported."),
            (
                self.criterion in ["mse", "squared_error"],
                f"'{self.criterion}' criterion is not supported. "
                "Only 'mse' and 'squared_error' criteria are supported.",
            ),
            (
                self.ccp_alpha == 0.0,
                f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
            ),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            (
                self.n_estimators <= 6024,
                "More than 6024 estimators is not supported.",
            ),
        ]
    )

    if patching_status.get_status() and sklearn_check_version("1.4"):
        # Translate the finiteness assertion into a patching condition
        # instead of letting the ValueError propagate.
        try:
            _assert_all_finite(X)
            input_is_finite = True
        except ValueError:
            input_is_finite = False
        patching_status.and_conditions(
            [
                (input_is_finite, "Non-finite input is not supported."),
                (
                    self.monotonic_cst is None,
                    "Monotonicity constraints are not supported.",
                ),
            ]
        )

    if patching_status.get_status():
        # The finiteness keyword differs by scikit-learn version.
        if sklearn_check_version("1.6"):
            X, y = check_X_y(
                X,
                y,
                multi_output=True,
                accept_sparse=True,
                dtype=[np.float64, np.float32],
                ensure_all_finite=False,
            )
        else:
            X, y = check_X_y(
                X,
                y,
                multi_output=True,
                accept_sparse=True,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
            )

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        if y.ndim == 1:
            # reshape is necessary to preserve the data contiguity against vs
            # [:, np.newaxis] that does not.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]

        patching_status.and_conditions(
            [
                (
                    self.n_outputs_ == 1,
                    f"Number of outputs ({self.n_outputs_}) is not 1.",
                )
            ]
        )

        # Sklearn function used for doing checks on max_samples attribute
        _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )

        if (
            patching_status.get_status()
            and (self.random_state is not None)
            and (not daal_check_version((2024, "P", 0)))
        ):
            warnings.warn(
                "Setting 'random_state' value is not supported. "
                "State set by oneDAL to default value (777).",
                RuntimeWarning,
            )

    return patching_status, X, y, sample_weight
1053
+
1054
def _onedal_cpu_supported(self, method_name, *data):
    """Collect the conditions under which oneDAL can run ``method_name`` on CPU.

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"`` or ``"score"``.
    *data : tuple
        Positional arguments of the dispatched call. For ``"fit"`` they are
        forwarded to ``_onedal_fit_ready``; for ``"predict"`` and ``"score"``
        only the first element (``X``) is inspected.

    Returns
    -------
    patching_status
        The conditions chain with every check for ``method_name`` applied.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not ``"fit"``, ``"predict"`` or ``"score"``.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{class_name}.{method_name}"
    )

    if method_name == "fit":
        # Only the status and the sample weights are needed here; the
        # validated X and y returned by the helper are not used.
        patching_status, _, _, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        patching_status.and_conditions(
            [
                (
                    # A plain random forest is built on DecisionTreeRegressor
                    # and needs no minimum oneDAL version. The class is
                    # compared exactly (not with isinstance) because
                    # ExtraTreeRegressor subclasses DecisionTreeRegressor.
                    daal_check_version((2023, "P", 200))
                    or self.estimator.__class__ is DecisionTreeRegressor,
                    "ExtraTrees only supported starting from oneDAL version 2023.2",
                ),
                (
                    not sp.issparse(sample_weight),
                    "sample_weight is sparse. " "Sparse input is not supported.",
                ),
            ]
        )

    elif method_name in ["predict", "score"]:
        X = data[0]

        patching_status.and_conditions(
            [
                (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    # Same exact-class check as in the "fit" branch.
                    daal_check_version((2023, "P", 200))
                    or self.estimator.__class__ is DecisionTreeRegressor,
                    "ExtraTrees only supported starting from oneDAL version 2023.2",
                ),
            ]
        )
        # n_outputs_ is only checked when it has already been set on self.
        if hasattr(self, "n_outputs_"):
            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                ]
            )

    else:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    return patching_status
1110
+
1111
def _onedal_gpu_supported(self, method_name, *data):
    """Collect the conditions under which oneDAL can run ``method_name`` on GPU.

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"`` or ``"score"``.
    *data : tuple
        Positional arguments of the dispatched call. For ``"fit"`` they are
        forwarded to ``_onedal_fit_ready``; for ``"predict"`` and ``"score"``
        only the first element (``X``) is inspected.

    Returns
    -------
    patching_status
        The conditions chain with every check for ``method_name`` applied.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not ``"fit"``, ``"predict"`` or ``"score"``.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{class_name}.{method_name}"
    )

    if method_name == "fit":
        # Only the status and the sample weights are needed here; the
        # validated X and y returned by the helper are not used.
        patching_status, _, _, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        patching_status.and_conditions(
            [
                (
                    # A plain random forest is built on DecisionTreeRegressor
                    # and needs no minimum oneDAL version. The class is
                    # compared exactly (not with isinstance) because
                    # ExtraTreeRegressor subclasses DecisionTreeRegressor.
                    daal_check_version((2023, "P", 100))
                    or self.estimator.__class__ is DecisionTreeRegressor,
                    "ExtraTrees only supported starting from oneDAL version 2023.1",
                ),
                (not self.oob_score, "oob_score value is not sklearn conformant."),
                (sample_weight is None, "sample_weight is not supported."),
            ]
        )

    elif method_name in ["predict", "score"]:
        X = data[0]

        patching_status.and_conditions(
            [
                (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    # Same exact-class check as in the "fit" branch.
                    daal_check_version((2023, "P", 100))
                    or self.estimator.__class__ is DecisionTreeRegressor,
                    "ExtraTrees only supported starting from oneDAL version 2023.1",
                ),
            ]
        )
        # n_outputs_ is only checked when it has already been set on self.
        if hasattr(self, "n_outputs_"):
            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                ]
            )

    else:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    return patching_status
1165
+
1166
def _onedal_predict(self, X, queue=None):
    """Predict with the fitted oneDAL estimator.

    Unless the ``use_raw_input`` configuration option is exactly ``True``,
    ``X`` is passed through ``validate_data`` before prediction.
    """
    check_is_fitted(self, "_onedal_estimator")

    skip_validation = get_config().get("use_raw_input", False) is True
    if skip_validation:
        return self._onedal_estimator.predict(X, queue=queue)

    # Warning, order of dtype matters
    accepted_dtypes = [np.float64, np.float32]
    checked_X = validate_data(
        self,
        X,
        dtype=accepted_dtypes,
        ensure_all_finite=False,
        reset=False,
        ensure_2d=True,
    )
    return self._onedal_estimator.predict(checked_X, queue=queue)
1181
+
1182
def _onedal_score(self, X, y, sample_weight=None, queue=None):
    """Return ``r2_score`` of the oneDAL predictions for ``X`` against ``y``."""
    predictions = self._onedal_predict(X, queue=queue)
    return r2_score(y, predictions, sample_weight=sample_weight)
1186
+
1187
def fit(self, X, y, sample_weight=None):
    # Candidate implementations handed to ``dispatch``: the oneDAL-backed
    # fit of this class and the stock scikit-learn one.
    backends = {
        "onedal": self.__class__._onedal_fit,
        "sklearn": _sklearn_ForestRegressor.fit,
    }
    dispatch(self, "fit", backends, X, y, sample_weight)
    return self
1200
+
1201
@wrap_output_data
def predict(self, X):
    # Fail early on an unfitted estimator, before any backend is chosen.
    check_is_fitted(self)
    backends = {
        "onedal": self.__class__._onedal_predict,
        "sklearn": _sklearn_ForestRegressor.predict,
    }
    return dispatch(self, "predict", backends, X)
1213
+
1214
@wrap_output_data
def score(self, X, y, sample_weight=None):
    # Fail early on an unfitted estimator, before any backend is chosen.
    check_is_fitted(self)
    backends = {
        "onedal": self.__class__._onedal_score,
        "sklearn": _sklearn_ForestRegressor.score,
    }
    return dispatch(self, "score", backends, X, y, sample_weight=sample_weight)
1228
+
1229
# The public methods reuse the docstrings of their scikit-learn counterparts.
score.__doc__ = _sklearn_ForestRegressor.score.__doc__
predict.__doc__ = _sklearn_ForestRegressor.predict.__doc__
fit.__doc__ = _sklearn_ForestRegressor.fit.__doc__
1232
+
1233
+
1234
@register_hyperparameters({"predict": ("decision_forest", "infer")})
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class RandomForestClassifier(ForestClassifier):
    # Documentation comes from the stock scikit-learn estimator;
    # _onedal_factory names the oneDAL implementation for this class.
    __doc__ = _sklearn_RandomForestClassifier.__doc__
    _onedal_factory = onedal_RandomForestClassifier

    if sklearn_check_version("1.2"):
        # Stock constraints extended with the two extra binning parameters.
        _parameter_constraints: dict = {
            **_sklearn_RandomForestClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    if sklearn_check_version("1.4"):
        # scikit-learn >= 1.4: the signature includes ``monotonic_cst``.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    else:
        # scikit-learn < 1.4: no ``monotonic_cst``; the ``max_features``
        # default depends on whether scikit-learn is at least 1.1.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1375
+
1376
+
1377
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class RandomForestRegressor(ForestRegressor):
    # Documentation comes from the stock scikit-learn estimator;
    # _onedal_factory names the oneDAL implementation for this class.
    __doc__ = _sklearn_RandomForestRegressor.__doc__
    _onedal_factory = onedal_RandomForestRegressor

    if sklearn_check_version("1.2"):
        # Stock constraints extended with the two extra binning parameters.
        _parameter_constraints: dict = {
            **_sklearn_RandomForestRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    if sklearn_check_version("1.4"):
        # scikit-learn >= 1.4: the signature includes ``monotonic_cst``.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    else:
        # scikit-learn < 1.4: no ``monotonic_cst``; the ``max_features``
        # default depends on whether scikit-learn is at least 1.1.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1513
+
1514
+
1515
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class ExtraTreesClassifier(ForestClassifier):
    # Documentation comes from the stock scikit-learn estimator;
    # _onedal_factory names the oneDAL implementation for this class.
    __doc__ = _sklearn_ExtraTreesClassifier.__doc__
    _onedal_factory = onedal_ExtraTreesClassifier

    if sklearn_check_version("1.2"):
        # Stock constraints extended with the two extra binning parameters.
        _parameter_constraints: dict = {
            **_sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    if sklearn_check_version("1.4"):
        # scikit-learn >= 1.4: the signature includes ``monotonic_cst``.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    else:
        # scikit-learn < 1.4: no ``monotonic_cst``; the ``max_features``
        # default depends on whether scikit-learn is at least 1.1.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1655
+
1656
+
1657
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class ExtraTreesRegressor(ForestRegressor):
    # Documentation comes from the stock scikit-learn estimator;
    # _onedal_factory names the oneDAL implementation for this class.
    __doc__ = _sklearn_ExtraTreesRegressor.__doc__
    _onedal_factory = onedal_ExtraTreesRegressor

    if sklearn_check_version("1.2"):
        # Stock constraints extended with the two extra binning parameters.
        _parameter_constraints: dict = {
            **_sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    if sklearn_check_version("1.4"):
        # scikit-learn >= 1.4: the signature includes ``monotonic_cst``.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
                "monotonic_cst",
            )
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    else:
        # scikit-learn < 1.4: no ``monotonic_cst``; the ``max_features``
        # default depends on whether scikit-learn is at least 1.1.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            # Parameter names passed to the base class as ``estimator_params``.
            tree_params = (
                "criterion",
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "min_impurity_decrease",
                "random_state",
                "ccp_alpha",
            )
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=tree_params,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Remaining constructor arguments are stored unchanged.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1793
+
1794
+
1795
# Register each accelerated estimator as a virtual subclass of its
# scikit-learn counterpart (ABCMeta.register), so isinstance checks against
# the stock classes succeed without changing the inheritance chain.
_sklearn_ExtraTreesRegressor.register(ExtraTreesRegressor)
_sklearn_ExtraTreesClassifier.register(ExtraTreesClassifier)
_sklearn_RandomForestRegressor.register(RandomForestRegressor)
_sklearn_RandomForestClassifier.register(RandomForestClassifier)