scikit-learn-intelex 2024.4.0__py312-none-win_amd64.whl → 2025.10.0__py312-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
  2. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
  3. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp312-win_amd64.pyd +0 -0
  4. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
  5. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
  6. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
  7. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
  8. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp312-win_amd64.pyd +0 -0
  9. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
  10. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
  11. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
  12. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
  13. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
  14. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
  15. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  16. {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn}/decomposition/__init__.py +2 -2
  17. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
  18. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  19. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
  20. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
  21. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
  22. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  23. {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn}/linear_model/__init__.py +29 -28
  24. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
  25. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
  26. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
  27. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
  28. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
  29. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  30. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
  31. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
  32. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
  33. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
  34. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  35. {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +3 -3
  36. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
  37. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
  38. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
  39. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
  40. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +4 -2
  41. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
  42. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  43. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
  44. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  45. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
  46. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  47. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  48. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  49. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
  50. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
  51. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
  52. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
  53. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  54. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  55. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
  56. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
  57. {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/covariance → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils}/__init__.py +5 -3
  58. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
  59. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
  60. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
  61. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
  62. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
  63. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
  64. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp312-win_amd64.pyd +0 -0
  65. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp312-win_amd64.pyd +0 -0
  66. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
  67. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
  68. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
  69. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
  70. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
  71. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
  72. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
  73. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
  74. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
  75. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
  76. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
  77. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
  78. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
  79. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
  80. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
  81. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
  82. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
  83. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
  84. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
  85. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
  86. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
  87. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
  88. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
  89. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
  90. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
  91. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
  92. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
  93. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
  94. {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
  95. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
  96. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
  97. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
  98. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
  99. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
  100. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
  101. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
  102. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
  103. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
  104. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
  105. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
  106. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
  107. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
  108. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
  109. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
  110. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
  111. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
  112. {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal}/neighbors/__init__.py +19 -19
  113. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
  114. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
  115. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
  116. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
  117. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
  118. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
  119. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
  120. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
  121. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
  122. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
  123. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
  124. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
  125. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
  126. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
  127. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
  128. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
  129. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
  130. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
  131. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
  132. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
  133. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
  134. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/__init__.py +7 -3
  135. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/__main__.py +2 -2
  136. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
  137. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
  138. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
  139. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
  140. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  141. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
  142. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +128 -78
  143. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
  144. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +101 -32
  145. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
  146. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +38 -29
  147. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
  148. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
  149. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
  150. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/conftest.py +20 -1
  151. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
  152. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
  153. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
  154. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
  155. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +199 -21
  156. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +207 -2
  157. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -17
  158. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
  159. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
  160. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +288 -440
  161. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
  162. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +1 -1
  163. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +17 -3
  164. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
  165. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
  166. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
  167. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
  168. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
  169. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
  170. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
  171. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
  172. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
  173. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
  174. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
  175. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
  176. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +11 -0
  177. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
  178. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +3 -0
  179. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
  180. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +3 -0
  181. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +30 -62
  182. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +56 -9
  183. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +45 -101
  184. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +63 -94
  185. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +49 -25
  186. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +6 -4
  187. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  188. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
  189. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +54 -8
  190. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  191. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
  192. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
  193. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
  194. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
  195. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  196. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
  197. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
  198. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +9 -4
  199. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
  200. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
  201. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +3 -4
  202. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
  203. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
  204. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
  205. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
  206. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
  207. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
  208. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +6 -4
  209. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
  210. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
  211. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
  212. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
  213. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
  214. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
  215. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +7 -4
  216. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
  217. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
  218. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
  219. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +1 -3
  220. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
  221. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
  222. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +99 -117
  223. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +55 -16
  224. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +95 -113
  225. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +51 -16
  226. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +43 -20
  227. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
  228. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
  229. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
  230. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
  231. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +5 -4
  232. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
  233. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +122 -75
  234. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
  235. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
  236. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
  237. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
  238. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/validation.py → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -1
  239. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
  240. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
  241. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
  242. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
  243. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
  244. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
  245. scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
  246. scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
  247. {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2025.10.0.dist-info}/WHEEL +1 -1
  248. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/_config.py +0 -110
  249. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -250
  250. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/_utils.py +0 -109
  251. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
  252. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
  253. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -130
  254. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -143
  255. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -335
  256. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -56
  257. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -113
  258. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -316
  259. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -17
  260. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +0 -385
  261. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -117
  262. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -91
  263. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -26
  264. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
  265. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -303
  266. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -133
  267. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -50
  268. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -71
  269. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -185
  270. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +0 -164
  271. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -39
  272. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -227
  273. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -99
  274. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
  275. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -20
  276. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py +0 -97
  277. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -59
  278. scikit_learn_intelex-2024.4.0.dist-info/METADATA +0 -230
  279. scikit_learn_intelex-2024.4.0.dist-info/RECORD +0 -101
  280. {scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal}/basic_statistics/__init__.py +0 -0
  281. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  282. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  283. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  284. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  285. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  286. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  287. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  288. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  289. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  290. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  291. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  292. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  293. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  294. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  295. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  296. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  297. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2025.10.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
  298. {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2025.10.0.dist-info}/LICENSE.txt +0 -0
  299. {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2025.10.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,781 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+ import warnings
19
+ from abc import ABCMeta, abstractmethod
20
+ from math import ceil
21
+
22
+ import numpy as np
23
+ from sklearn.ensemble import BaseEnsemble
24
+ from sklearn.utils import check_random_state
25
+
26
+ from daal4py.sklearn._utils import daal_check_version
27
+ from onedal._device_offload import supports_queue
28
+ from onedal.common._backend import bind_default_backend
29
+ from onedal.utils import _sycl_queue_manager as QM
30
+ from sklearnex import get_hyperparameters
31
+
32
+ from .._config import _get_config
33
+ from ..common._estimator_checks import _check_is_fitted
34
+ from ..common._mixin import ClassifierMixin, RegressorMixin
35
+ from ..datatypes import from_table, to_table
36
+ from ..utils._array_api import _get_sycl_namespace
37
+ from ..utils.validation import (
38
+ _check_array,
39
+ _check_n_features,
40
+ _check_X_y,
41
+ _column_or_1d,
42
+ _validate_targets,
43
+ )
44
+
45
+
46
class BaseForest(BaseEnsemble, metaclass=ABCMeta):
    """Common machinery shared by the forest estimators defined in this module.

    The class stores the hyper-parameters, validates them, converts them into
    the parameter dictionary that is handed to the ``train`` / ``infer`` entry
    points, and implements the shared fit / predict / predict_proba plumbing.
    Concrete subclasses provide ``train`` and ``infer``.
    """

    @abstractmethod
    def __init__(
        self,
        n_estimators,
        criterion,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        min_weight_fraction_leaf,
        max_features,
        max_leaf_nodes,
        min_impurity_decrease,
        min_impurity_split,
        bootstrap,
        oob_score,
        random_state,
        warm_start,
        class_weight,
        ccp_alpha,
        max_samples,
        max_bins,
        min_bin_size,
        infer_mode,
        splitter_mode,
        voting_mode,
        error_metric_mode,
        variable_importance_mode,
        algorithm,
        **kwargs,
    ):
        """Store every hyper-parameter on the instance without validating it.

        Validation is performed separately by ``_check_parameters`` and
        ``_get_onedal_params``.  Extra keyword arguments are accepted but are
        neither stored nor forwarded.
        """
        self.n_estimators = n_estimators
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.random_state = random_state
        self.warm_start = warm_start
        self.class_weight = class_weight
        self.max_samples = max_samples
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split
        self.ccp_alpha = ccp_alpha
        self.max_bins = max_bins
        self.min_bin_size = min_bin_size
        self.infer_mode = infer_mode
        self.splitter_mode = splitter_mode
        self.voting_mode = voting_mode
        self.error_metric_mode = error_metric_mode
        self.variable_importance_mode = variable_importance_mode
        self.algorithm = algorithm

    @abstractmethod
    def train(self, *args, **kwargs): ...

    @abstractmethod
    def infer(self, *args, **kwargs): ...

    def _to_absolute_max_features(self, n_features):
        """Translate ``max_features`` into an absolute number of features.

        * ``None`` -> ``n_features``
        * ``str`` -> the NumPy function of that name applied to ``n_features``
          (e.g. ``"sqrt"``, ``"log2"``), truncated to int and at least 1
        * integer -> returned unchanged
        * positive float -> that fraction of ``n_features``, at least 1
        * anything else (non-positive float) -> 0
        """
        max_features = self.max_features
        if max_features is None:
            return n_features
        if isinstance(max_features, str):
            # The string is resolved as an attribute of the numpy module; an
            # unknown name therefore surfaces as AttributeError.
            return max(1, int(getattr(np, max_features)(n_features)))
        if isinstance(max_features, (numbers.Integral, np.integer)):
            return max_features
        if max_features > 0.0:
            return max(1, int(max_features * n_features))
        return 0

    def _get_observations_per_tree_fraction(self, n_samples, max_samples):
        """Return the fraction of the samples each tree is trained on.

        Parameters
        ----------
        n_samples : int
            Number of rows in the data the fraction refers to.
        max_samples : None, int or float
            ``None`` means all samples, an int is an absolute count in
            ``[1, n_samples]``, a float is used as the fraction directly.

        Returns
        -------
        float
            The fraction, never smaller than ``1 / n_samples``.

        Raises
        ------
        ValueError
            If an integer ``max_samples`` is outside ``[1, n_samples]``.
        TypeError
            If ``max_samples`` is neither ``None``, an integer nor a real.
        """
        if max_samples is None:
            return 1.0

        if isinstance(max_samples, numbers.Integral):
            if not (1 <= max_samples <= n_samples):
                msg = "`max_samples` must be in range 1 to {} but got value {}"
                raise ValueError(msg.format(n_samples, max_samples))
            return max(float(max_samples / n_samples), 1 / n_samples)

        if isinstance(max_samples, numbers.Real):
            # No upper/lower bound is enforced on a float value here.
            return max(float(max_samples), 1 / n_samples)

        msg = "`max_samples` should be int or float, but got type '{}'"
        raise TypeError(msg.format(type(max_samples)))

    def _get_onedal_params(self, data):
        """Build the parameter dictionary passed to ``train`` / ``infer``.

        ``data`` only needs ``shape`` and ``dtype``.  As a side effect the
        method (re)sets ``self.observations_per_tree_fraction``.

        NOTE(review): this is also called from ``_predict`` and
        ``_predict_proba`` with the inference table, so an integer
        ``max_samples`` is range-checked against the number of rows being
        predicted there - confirm that this is intended.

        Raises
        ------
        ValueError
            If ``max_samples`` or ``oob_score`` is requested together with
            ``bootstrap=False`` (or ``max_samples`` is out of range).
        TypeError
            If ``max_samples`` has an unsupported type.
        """
        n_samples, n_features = data.shape

        fraction = self._get_observations_per_tree_fraction(
            n_samples=n_samples, max_samples=self.max_samples
        )
        # Without bootstrapping every tree sees the full data set.
        self.observations_per_tree_fraction = fraction if bool(self.bootstrap) else 1.0

        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )
        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        # Float values of min_samples_leaf / min_samples_split are fractions of
        # the number of rows and are rounded up to an absolute count.
        min_observations_in_leaf_node = (
            self.min_samples_leaf
            if isinstance(self.min_samples_leaf, numbers.Integral)
            else int(ceil(self.min_samples_leaf * n_samples))
        )
        min_observations_in_split_node = (
            self.min_samples_split
            if isinstance(self.min_samples_split, numbers.Integral)
            else int(ceil(self.min_samples_split * n_samples))
        )

        # Draw one integer seed from the (possibly shared) random state.
        rs = check_random_state(self.random_state)
        seed = rs.randint(0, np.iinfo("i").max)

        onedal_params = {
            "fptype": data.dtype,
            "method": self.algorithm,
            "infer_mode": self.infer_mode,
            "voting_mode": self.voting_mode,
            "observations_per_tree_fraction": self.observations_per_tree_fraction,
            "impurity_threshold": float(
                0.0 if self.min_impurity_split is None else self.min_impurity_split
            ),
            "min_weight_fraction_in_leaf_node": self.min_weight_fraction_leaf,
            "min_impurity_decrease_in_split_node": self.min_impurity_decrease,
            "tree_count": int(self.n_estimators),
            "features_per_node": self._to_absolute_max_features(n_features),
            # Unset depth / leaf limits are passed as 0.
            "max_tree_depth": int(0 if self.max_depth is None else self.max_depth),
            "min_observations_in_leaf_node": min_observations_in_leaf_node,
            "min_observations_in_split_node": min_observations_in_split_node,
            "max_leaf_nodes": (0 if self.max_leaf_nodes is None else self.max_leaf_nodes),
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "seed": seed,
            "memory_saving_mode": False,
            "bootstrap": bool(self.bootstrap),
            "error_metric_mode": self.error_metric_mode,
            "variable_importance_mode": self.variable_importance_mode,
        }
        if isinstance(self, ClassifierMixin):
            onedal_params["class_count"] = (
                0 if self.classes_ is None else len(self.classes_)
            )
        if daal_check_version((2023, "P", 101)):
            onedal_params["splitter_mode"] = self.splitter_mode
        return onedal_params

    def _check_parameters(self):
        """Validate the stored hyper-parameters.

        Emits a ``FutureWarning`` when ``min_impurity_split`` is set.

        Raises
        ------
        ValueError
            If any of ``min_samples_leaf``, ``min_samples_split``,
            ``min_weight_fraction_leaf``, ``min_impurity_split``,
            ``min_impurity_decrease``, ``max_leaf_nodes``, ``max_bins`` or
            ``min_bin_size`` is outside its accepted range or type.
        """
        # min_samples_leaf: integer >= 1, otherwise a float in (0, 0.5].
        min_samples_leaf = self.min_samples_leaf
        if isinstance(min_samples_leaf, numbers.Integral):
            leaf_ok = 1 <= min_samples_leaf
        else:
            leaf_ok = 0.0 < min_samples_leaf <= 0.5
        if not leaf_ok:
            raise ValueError(
                "min_samples_leaf must be at least 1 "
                "or in (0, 0.5], got %s" % min_samples_leaf
            )

        # min_samples_split: integer >= 2, otherwise a float in (0, 1].
        min_samples_split = self.min_samples_split
        split_msg = (
            "min_samples_split must be an integer "
            "greater than 1 or a float in (0.0, 1.0]; "
        )
        if isinstance(min_samples_split, numbers.Integral):
            if not 2 <= min_samples_split:
                raise ValueError(split_msg + "got the integer %s" % min_samples_split)
        elif not 0.0 < min_samples_split <= 1.0:
            raise ValueError(split_msg + "got the float %s" % min_samples_split)

        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")

        if self.min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )
            if self.min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than " "or equal to 0"
                )

        if self.min_impurity_decrease < 0.0:
            raise ValueError(
                "min_impurity_decrease must be greater than " "or equal to 0"
            )

        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" % self.max_leaf_nodes
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
                        self.max_leaf_nodes
                    )
                )

        if not isinstance(self.max_bins, numbers.Integral):
            raise ValueError(
                "max_bins must be integral number but was " "%r" % self.max_bins
            )
        if not 2 <= self.max_bins:
            raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)

        if not isinstance(self.min_bin_size, numbers.Integral):
            raise ValueError(
                "min_bin_size must be integral number but was " "%r" % self.min_bin_size
            )
        if not 1 <= self.min_bin_size:
            raise ValueError(
                "min_bin_size must be at least 1, got %s" % self.min_bin_size
            )

    def _validate_targets(self, y, dtype):
        """Flatten ``y`` to 1-D and cast it to ``dtype``.

        This default resets ``class_weight_`` and ``classes_`` to ``None``.
        """
        self.class_weight_ = None
        self.classes_ = None
        return _column_or_1d(y, warn=True).astype(dtype, copy=False)

    def _get_sample_weight(self, sample_weight, X):
        """Return ``sample_weight`` as a flat, C-ordered array of ``X.dtype``.

        Raises
        ------
        ValueError
            If the number of weights differs from the number of rows in ``X``.
        """
        sample_weight = np.asarray(sample_weight, dtype=X.dtype).ravel()

        sample_weight = _check_array(
            sample_weight, accept_sparse=False, ensure_2d=False, dtype=X.dtype, order="C"
        )

        if sample_weight.size != X.shape[0]:
            raise ValueError(
                "sample_weight and X have incompatible shapes: "
                "%r vs %r\n"
                "Note: Sparse matrices cannot be indexed w/"
                "boolean masks (use `indices=True` in CV)."
                % (sample_weight.shape, X.shape)
            )

        return sample_weight

    @staticmethod
    def _warn_if_oob_incomplete(oob_values):
        """Warn when any out-of-bag value equals zero.

        NOTE(review): assumes ``oob_values`` is a host (NumPy-compatible)
        array as produced by ``from_table`` without ``like`` - confirm.
        """
        if np.any(oob_values == 0):
            warnings.warn(
                "Some inputs do not have OOB scores. This probably means "
                "too few trees were used to compute any reliable OOB "
                "estimates.",
                UserWarning,
                # Attribute the warning to the caller (``_fit``), as before
                # the check was factored out.
                stacklevel=2,
            )

    def _fit(self, X, y, sample_weight):
        """Train the forest and store the resulting model on ``self``.

        Sets ``n_features_in_``, ``_onedal_model`` and, when ``oob_score`` is
        enabled, ``oob_score_`` plus ``oob_decision_function_`` (classifiers)
        or ``oob_prediction_`` (regressors).  When the ``use_raw_input``
        configuration flag is ``True`` the inputs are not validated.

        Returns
        -------
        self
        """
        use_raw_input = _get_config().get("use_raw_input", False) is True
        _, xp, _ = _get_sycl_namespace(X)

        if not use_raw_input:
            X, y = _check_X_y(
                X,
                y,
                dtype=[np.float64, np.float32],
                force_all_finite=True,
                accept_sparse="csr",
            )
            y = self._validate_targets(y, X.dtype)
        else:
            # try/except is needed for raw inputs + array API data where,
            # unlike numpy, unique values are obtained via `unique_values`.
            # This should be removed when refactored for gpu zero-copy.
            try:
                self.classes_ = xp.unique(y)
            except AttributeError:
                self.classes_ = xp.unique_values(y)

        self.n_features_in_ = X.shape[1]

        if sample_weight is not None and len(sample_weight) > 0:
            if not use_raw_input:
                sample_weight = self._get_sample_weight(sample_weight, X)
            data = (X, y, sample_weight)
        else:
            data = (X, y)
        data = to_table(*data, queue=QM.get_global_queue())
        params = self._get_onedal_params(data[0])
        train_result = self.train(params, *data)

        self._onedal_model = train_result.model

        if self.oob_score:
            if isinstance(self, ClassifierMixin):
                self.oob_score_ = from_table(train_result.oob_err_accuracy).item()
                self.oob_decision_function_ = from_table(
                    train_result.oob_err_decision_function
                )
                self._warn_if_oob_incomplete(self.oob_decision_function_)
            else:
                self.oob_score_ = from_table(train_result.oob_err_r2).item()
                self.oob_prediction_ = from_table(
                    train_result.oob_err_prediction
                ).reshape(-1)
                self._warn_if_oob_incomplete(self.oob_prediction_)

        return self

    def _create_model(self, module):
        """Not supported; always raises ``NotImplementedError``."""
        # TODO:
        # update error msg.
        raise NotImplementedError("Creating model is not supported.")

    def _predict(self, X, hparams=None):
        """Run inference on ``X`` and return the responses shaped like ``X``.

        ``hparams``, when given and not default, supplies the backend that is
        passed to ``infer`` as an additional argument.  Inference always uses
        the global queue.
        """
        _check_is_fitted(self)

        use_raw_input = _get_config().get("use_raw_input", False) is True

        if not use_raw_input:
            X = _check_array(
                X,
                dtype=[np.float64, np.float32],
                force_all_finite=True,
                accept_sparse=False,
            )
            _check_n_features(self, X, False)

        model = self._onedal_model
        X_table = to_table(X, queue=QM.get_global_queue())
        params = self._get_onedal_params(X_table)
        if hparams is not None and not hparams.is_default:
            result = self.infer(params, hparams.backend, model, X_table)
        else:
            result = self.infer(params, model, X_table)

        return from_table(result.responses, like=X)

    def _predict_proba(self, X, hparams=None):
        """Run inference in ``class_probabilities`` mode.

        Returns the probabilities clipped to ``[0, 1]``.
        """
        _check_is_fitted(self)
        use_raw_input = _get_config().get("use_raw_input", False) is True
        sua_iface, _, _ = _get_sycl_namespace(X)

        # All data should use the same sycl queue
        if use_raw_input and sua_iface is not None:
            queue = X.sycl_queue
        else:
            queue = QM.get_global_queue()

        if not use_raw_input:
            X = _check_array(
                X,
                dtype=[np.float64, np.float32],
                force_all_finite=True,
                accept_sparse=False,
            )
            _check_n_features(self, X, False)
        X = to_table(X, queue=queue)
        params = self._get_onedal_params(X)
        params["infer_mode"] = "class_probabilities"

        model = self._onedal_model
        if hparams is not None and not hparams.is_default:
            result = self.infer(params, hparams.backend, model, X)
        else:
            result = self.infer(params, model, X)

        # TODO: fix probabilities out of [0, 1] interval on oneDAL side
        pred = from_table(result.probabilities)
        return pred.clip(0.0, 1.0)
435
+
436
+
437
class RandomForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """Forest classifier whose ``train`` / ``infer`` are bound to the
    ``decision_forest.classification`` backend.

    Defaults differ from the base class only by being concrete values; note
    ``bootstrap=True`` and ``splitter_mode="best"``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="gini",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="best",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # Every named hyper-parameter is handed to BaseForest; the extra
        # keyword arguments are not forwarded.
        forest_options = {
            "n_estimators": n_estimators,
            "criterion": criterion,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_samples_leaf": min_samples_leaf,
            "min_weight_fraction_leaf": min_weight_fraction_leaf,
            "max_features": max_features,
            "max_leaf_nodes": max_leaf_nodes,
            "min_impurity_decrease": min_impurity_decrease,
            "min_impurity_split": min_impurity_split,
            "bootstrap": bootstrap,
            "oob_score": oob_score,
            "random_state": random_state,
            "warm_start": warm_start,
            "class_weight": class_weight,
            "ccp_alpha": ccp_alpha,
            "max_samples": max_samples,
            "max_bins": max_bins,
            "min_bin_size": min_bin_size,
            "infer_mode": infer_mode,
            "splitter_mode": splitter_mode,
            "voting_mode": voting_mode,
            "error_metric_mode": error_metric_mode,
            "variable_importance_mode": variable_importance_mode,
            "algorithm": algorithm,
        }
        super().__init__(**forest_options)

    # Backend entry points: the bodies are placeholders, the decorator
    # supplies the implementation.
    @bind_default_backend("decision_forest.classification")
    def train(self, *args, **kwargs): ...

    @bind_default_backend("decision_forest.classification")
    def infer(self, *args, **kwargs): ...

    def _validate_targets(self, y, dtype):
        """Validate ``y`` with the module-level helper and remember the
        resulting class weights and class labels on the estimator."""
        validated_y, weights, labels = _validate_targets(y, self.class_weight, dtype)
        self.class_weight_ = weights
        self.classes_ = labels

        # TODO:
        # align `n_classes_` and `classes_` attributes with the daal4py
        # implementations.
        return validated_y

    @supports_queue
    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the classifier; returns ``self``."""
        return self._fit(X, y, sample_weight)

    @supports_queue
    def predict(self, X, queue=None):
        """Predict class labels by mapping the inferred class indices onto
        ``classes_``."""
        _, xp, _ = _get_sycl_namespace(X)
        infer_hparams = get_hyperparameters("decision_forest", "infer")
        pred = xp.reshape(self._predict(X, infer_hparams), -1)

        try:
            # Array-namespace path; an AttributeError anywhere in here (for
            # instance `pred` lacking `sycl_queue`) selects the NumPy path.
            labels = xp.asarray(self.classes_, device=pred.sycl_queue)
            indices = xp.astype(xp.reshape(pred, (-1,)), xp.int64)
            return xp.take(labels, indices)
        except AttributeError:
            indices = pred.ravel().astype(np.int64, casting="unsafe")
            return np.take(self.classes_, indices)

    @supports_queue
    def predict_proba(self, X, queue=None):
        """Predict class probabilities using the inference hyper-parameters."""
        infer_hparams = get_hyperparameters("decision_forest", "infer")
        return super()._predict_proba(X, infer_hparams)
536
+
537
+
538
class RandomForestRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Forest regressor whose ``train`` / ``infer`` are bound to the
    ``decision_forest.regression`` backend.

    Notable defaults: ``bootstrap=True``, ``splitter_mode="best"`` and
    ``max_features=1.0``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="best",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # Every named hyper-parameter is handed to BaseForest; the extra
        # keyword arguments are not forwarded.
        forest_options = {
            "n_estimators": n_estimators,
            "criterion": criterion,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_samples_leaf": min_samples_leaf,
            "min_weight_fraction_leaf": min_weight_fraction_leaf,
            "max_features": max_features,
            "max_leaf_nodes": max_leaf_nodes,
            "min_impurity_decrease": min_impurity_decrease,
            "min_impurity_split": min_impurity_split,
            "bootstrap": bootstrap,
            "oob_score": oob_score,
            "random_state": random_state,
            "warm_start": warm_start,
            "class_weight": class_weight,
            "ccp_alpha": ccp_alpha,
            "max_samples": max_samples,
            "max_bins": max_bins,
            "min_bin_size": min_bin_size,
            "infer_mode": infer_mode,
            "splitter_mode": splitter_mode,
            "voting_mode": voting_mode,
            "error_metric_mode": error_metric_mode,
            "variable_importance_mode": variable_importance_mode,
            "algorithm": algorithm,
        }
        super().__init__(**forest_options)

    # Backend entry points: the bodies are placeholders, the decorator
    # supplies the implementation.
    @bind_default_backend("decision_forest.regression")
    def train(self, *args, **kwargs): ...

    @bind_default_backend("decision_forest.regression")
    def infer(self, *args, **kwargs): ...

    @supports_queue
    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the regressor; returns ``self``.

        When weights are given they are wrapped in a one-element list before
        being passed on.
        """
        if sample_weight is None:
            return self._fit(X, y, sample_weight)

        if hasattr(sample_weight, "__array__"):
            # NOTE: zero weights are replaced by 1.0 directly in the array
            # supplied by the caller (in-place assignment).
            sample_weight[sample_weight == 0.0] = 1.0
        return self._fit(X, y, [sample_weight])

    @supports_queue
    def predict(self, X, queue=None):
        """Predict targets for ``X`` as a flattened array."""
        _, xp, _ = _get_sycl_namespace(X)
        responses = self._predict(X)
        return xp.reshape(responses, -1)
614
+
615
+
616
class ExtraTreesClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """Extra-trees classifier whose ``train`` / ``infer`` are bound to the
    ``decision_forest.classification`` backend.

    Notable defaults: ``bootstrap=False`` and ``splitter_mode="random"``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="gini",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # Every named hyper-parameter is handed to BaseForest; the extra
        # keyword arguments are not forwarded.
        forest_options = {
            "n_estimators": n_estimators,
            "criterion": criterion,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_samples_leaf": min_samples_leaf,
            "min_weight_fraction_leaf": min_weight_fraction_leaf,
            "max_features": max_features,
            "max_leaf_nodes": max_leaf_nodes,
            "min_impurity_decrease": min_impurity_decrease,
            "min_impurity_split": min_impurity_split,
            "bootstrap": bootstrap,
            "oob_score": oob_score,
            "random_state": random_state,
            "warm_start": warm_start,
            "class_weight": class_weight,
            "ccp_alpha": ccp_alpha,
            "max_samples": max_samples,
            "max_bins": max_bins,
            "min_bin_size": min_bin_size,
            "infer_mode": infer_mode,
            "splitter_mode": splitter_mode,
            "voting_mode": voting_mode,
            "error_metric_mode": error_metric_mode,
            "variable_importance_mode": variable_importance_mode,
            "algorithm": algorithm,
        }
        super().__init__(**forest_options)

    # Backend entry points: the bodies are placeholders, the decorator
    # supplies the implementation.
    @bind_default_backend("decision_forest.classification")
    def train(self, *args, **kwargs): ...

    @bind_default_backend("decision_forest.classification")
    def infer(self, *args, **kwargs): ...

    def _validate_targets(self, y, dtype):
        """Validate ``y`` with the module-level helper and remember the
        resulting class weights and class labels on the estimator."""
        validated_y, weights, labels = _validate_targets(y, self.class_weight, dtype)
        self.class_weight_ = weights
        self.classes_ = labels

        # TODO:
        # align `n_classes_` and `classes_` attributes with the daal4py
        # implementations.
        return validated_y

    @supports_queue
    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the classifier; returns ``self``."""
        return self._fit(X, y, sample_weight)

    @supports_queue
    def predict(self, X, queue=None):
        """Predict class labels by mapping the inferred class indices onto
        ``classes_`` with NumPy."""
        class_indices = self._predict(X).ravel().astype(np.int64, casting="unsafe")
        return np.take(self.classes_, class_indices)

    @supports_queue
    def predict_proba(self, X, queue=None):
        """Predict class probabilities for ``X``."""
        return super()._predict_proba(X)
705
+
706
+
707
class ExtraTreesRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Extra-trees regressor whose ``train`` / ``infer`` are bound to the
    ``decision_forest.regression`` backend.

    Notable defaults: ``bootstrap=False``, ``splitter_mode="random"`` and
    ``max_features=1.0``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # Every named hyper-parameter is handed to BaseForest; the extra
        # keyword arguments are not forwarded.
        forest_options = {
            "n_estimators": n_estimators,
            "criterion": criterion,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_samples_leaf": min_samples_leaf,
            "min_weight_fraction_leaf": min_weight_fraction_leaf,
            "max_features": max_features,
            "max_leaf_nodes": max_leaf_nodes,
            "min_impurity_decrease": min_impurity_decrease,
            "min_impurity_split": min_impurity_split,
            "bootstrap": bootstrap,
            "oob_score": oob_score,
            "random_state": random_state,
            "warm_start": warm_start,
            "class_weight": class_weight,
            "ccp_alpha": ccp_alpha,
            "max_samples": max_samples,
            "max_bins": max_bins,
            "min_bin_size": min_bin_size,
            "infer_mode": infer_mode,
            "splitter_mode": splitter_mode,
            "voting_mode": voting_mode,
            "error_metric_mode": error_metric_mode,
            "variable_importance_mode": variable_importance_mode,
            "algorithm": algorithm,
        }
        super().__init__(**forest_options)

    # Backend entry points: the bodies are placeholders, the decorator
    # supplies the implementation.
    @bind_default_backend("decision_forest.regression")
    def train(self, *args, **kwargs): ...

    @bind_default_backend("decision_forest.regression")
    def infer(self, *args, **kwargs): ...

    @supports_queue
    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the regressor; returns ``self``.

        When weights are given they are wrapped in a one-element list before
        being passed on.
        """
        if sample_weight is None:
            return self._fit(X, y, sample_weight)

        if hasattr(sample_weight, "__array__"):
            # NOTE: zero weights are replaced by 1.0 directly in the array
            # supplied by the caller (in-place assignment).
            sample_weight[sample_weight == 0.0] = 1.0
        return self._fit(X, y, [sample_weight])

    @supports_queue
    def predict(self, X, queue=None):
        """Predict targets for ``X`` as a flattened array."""
        responses = self._predict(X)
        return responses.ravel()