scikit-learn-intelex 2025.0.0__py310-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (278) hide show
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-310-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-310-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +242 -0
  10. daal4py/sklearn/_utils.py +241 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +155 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +53 -0
  61. onedal/_device_offload.py +229 -0
  62. onedal/_onedal_py_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-310-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +560 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +116 -0
  83. onedal/common/tests/test_policy.py +75 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +95 -0
  91. onedal/datatypes/tests/test_data.py +235 -0
  92. onedal/decomposition/__init__.py +20 -0
  93. onedal/decomposition/incremental_pca.py +204 -0
  94. onedal/decomposition/pca.py +186 -0
  95. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  96. onedal/ensemble/__init__.py +29 -0
  97. onedal/ensemble/forest.py +720 -0
  98. onedal/ensemble/tests/test_random_forest.py +97 -0
  99. onedal/linear_model/__init__.py +27 -0
  100. onedal/linear_model/incremental_linear_model.py +258 -0
  101. onedal/linear_model/linear_model.py +329 -0
  102. onedal/linear_model/logistic_regression.py +249 -0
  103. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  104. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  105. onedal/linear_model/tests/test_linear_regression.py +149 -0
  106. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  107. onedal/linear_model/tests/test_ridge.py +95 -0
  108. onedal/neighbors/__init__.py +19 -0
  109. onedal/neighbors/neighbors.py +778 -0
  110. onedal/neighbors/tests/test_knn_classification.py +49 -0
  111. onedal/primitives/__init__.py +27 -0
  112. onedal/primitives/get_tree.py +25 -0
  113. onedal/primitives/kernel_functions.py +153 -0
  114. onedal/primitives/tests/test_kernel_functions.py +159 -0
  115. onedal/spmd/__init__.py +25 -0
  116. onedal/spmd/_base.py +30 -0
  117. onedal/spmd/basic_statistics/__init__.py +20 -0
  118. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  119. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  120. onedal/spmd/cluster/__init__.py +28 -0
  121. onedal/spmd/cluster/dbscan.py +23 -0
  122. onedal/spmd/cluster/kmeans.py +56 -0
  123. onedal/spmd/covariance/__init__.py +20 -0
  124. onedal/spmd/covariance/covariance.py +26 -0
  125. onedal/spmd/covariance/incremental_covariance.py +82 -0
  126. onedal/spmd/decomposition/__init__.py +20 -0
  127. onedal/spmd/decomposition/incremental_pca.py +117 -0
  128. onedal/spmd/decomposition/pca.py +26 -0
  129. onedal/spmd/ensemble/__init__.py +19 -0
  130. onedal/spmd/ensemble/forest.py +28 -0
  131. onedal/spmd/linear_model/__init__.py +21 -0
  132. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  133. onedal/spmd/linear_model/linear_model.py +30 -0
  134. onedal/spmd/linear_model/logistic_regression.py +38 -0
  135. onedal/spmd/neighbors/__init__.py +19 -0
  136. onedal/spmd/neighbors/neighbors.py +75 -0
  137. onedal/svm/__init__.py +19 -0
  138. onedal/svm/svm.py +556 -0
  139. onedal/svm/tests/test_csr_svm.py +351 -0
  140. onedal/svm/tests/test_nusvc.py +204 -0
  141. onedal/svm/tests/test_nusvr.py +210 -0
  142. onedal/svm/tests/test_svc.py +168 -0
  143. onedal/svm/tests/test_svr.py +243 -0
  144. onedal/tests/test_common.py +41 -0
  145. onedal/tests/utils/_dataframes_support.py +168 -0
  146. onedal/tests/utils/_device_selection.py +107 -0
  147. onedal/utils/__init__.py +49 -0
  148. onedal/utils/_array_api.py +91 -0
  149. onedal/utils/validation.py +432 -0
  150. scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
  151. scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
  152. scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
  153. scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
  154. scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
  155. sklearnex/__init__.py +65 -0
  156. sklearnex/__main__.py +58 -0
  157. sklearnex/_config.py +98 -0
  158. sklearnex/_device_offload.py +121 -0
  159. sklearnex/_utils.py +109 -0
  160. sklearnex/basic_statistics/__init__.py +20 -0
  161. sklearnex/basic_statistics/basic_statistics.py +140 -0
  162. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  163. sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
  164. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
  165. sklearnex/cluster/__init__.py +20 -0
  166. sklearnex/cluster/dbscan.py +192 -0
  167. sklearnex/cluster/k_means.py +383 -0
  168. sklearnex/cluster/tests/test_dbscan.py +38 -0
  169. sklearnex/cluster/tests/test_kmeans.py +153 -0
  170. sklearnex/conftest.py +73 -0
  171. sklearnex/covariance/__init__.py +19 -0
  172. sklearnex/covariance/incremental_covariance.py +368 -0
  173. sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
  174. sklearnex/decomposition/__init__.py +19 -0
  175. sklearnex/decomposition/pca.py +414 -0
  176. sklearnex/decomposition/tests/test_pca.py +58 -0
  177. sklearnex/dispatcher.py +543 -0
  178. sklearnex/doc/third-party-programs.txt +424 -0
  179. sklearnex/ensemble/__init__.py +29 -0
  180. sklearnex/ensemble/_forest.py +2016 -0
  181. sklearnex/ensemble/tests/test_forest.py +120 -0
  182. sklearnex/glob/__main__.py +72 -0
  183. sklearnex/glob/dispatcher.py +101 -0
  184. sklearnex/linear_model/__init__.py +32 -0
  185. sklearnex/linear_model/coordinate_descent.py +30 -0
  186. sklearnex/linear_model/incremental_linear.py +463 -0
  187. sklearnex/linear_model/incremental_ridge.py +418 -0
  188. sklearnex/linear_model/linear.py +302 -0
  189. sklearnex/linear_model/logistic_path.py +17 -0
  190. sklearnex/linear_model/logistic_regression.py +403 -0
  191. sklearnex/linear_model/ridge.py +24 -0
  192. sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
  193. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  194. sklearnex/linear_model/tests/test_linear.py +142 -0
  195. sklearnex/linear_model/tests/test_logreg.py +134 -0
  196. sklearnex/manifold/__init__.py +19 -0
  197. sklearnex/manifold/t_sne.py +21 -0
  198. sklearnex/manifold/tests/test_tsne.py +26 -0
  199. sklearnex/metrics/__init__.py +23 -0
  200. sklearnex/metrics/pairwise.py +22 -0
  201. sklearnex/metrics/ranking.py +20 -0
  202. sklearnex/metrics/tests/test_metrics.py +39 -0
  203. sklearnex/model_selection/__init__.py +21 -0
  204. sklearnex/model_selection/split.py +22 -0
  205. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  206. sklearnex/neighbors/__init__.py +27 -0
  207. sklearnex/neighbors/_lof.py +231 -0
  208. sklearnex/neighbors/common.py +310 -0
  209. sklearnex/neighbors/knn_classification.py +226 -0
  210. sklearnex/neighbors/knn_regression.py +203 -0
  211. sklearnex/neighbors/knn_unsupervised.py +170 -0
  212. sklearnex/neighbors/tests/test_neighbors.py +80 -0
  213. sklearnex/preview/__init__.py +17 -0
  214. sklearnex/preview/covariance/__init__.py +19 -0
  215. sklearnex/preview/covariance/covariance.py +133 -0
  216. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  217. sklearnex/preview/decomposition/__init__.py +19 -0
  218. sklearnex/preview/decomposition/incremental_pca.py +228 -0
  219. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  220. sklearnex/preview/linear_model/__init__.py +19 -0
  221. sklearnex/preview/linear_model/ridge.py +419 -0
  222. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  223. sklearnex/spmd/__init__.py +25 -0
  224. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  225. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  226. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  227. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  228. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  229. sklearnex/spmd/cluster/__init__.py +30 -0
  230. sklearnex/spmd/cluster/dbscan.py +50 -0
  231. sklearnex/spmd/cluster/kmeans.py +21 -0
  232. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  233. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  234. sklearnex/spmd/covariance/__init__.py +20 -0
  235. sklearnex/spmd/covariance/covariance.py +21 -0
  236. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  237. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  238. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  239. sklearnex/spmd/decomposition/__init__.py +20 -0
  240. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  241. sklearnex/spmd/decomposition/pca.py +21 -0
  242. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  243. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  244. sklearnex/spmd/ensemble/__init__.py +19 -0
  245. sklearnex/spmd/ensemble/forest.py +71 -0
  246. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  247. sklearnex/spmd/linear_model/__init__.py +21 -0
  248. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  249. sklearnex/spmd/linear_model/linear_model.py +21 -0
  250. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  251. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  252. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  253. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
  254. sklearnex/spmd/neighbors/__init__.py +19 -0
  255. sklearnex/spmd/neighbors/neighbors.py +25 -0
  256. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  257. sklearnex/svm/__init__.py +29 -0
  258. sklearnex/svm/_common.py +328 -0
  259. sklearnex/svm/nusvc.py +332 -0
  260. sklearnex/svm/nusvr.py +148 -0
  261. sklearnex/svm/svc.py +360 -0
  262. sklearnex/svm/svr.py +149 -0
  263. sklearnex/svm/tests/test_svm.py +93 -0
  264. sklearnex/tests/_utils.py +328 -0
  265. sklearnex/tests/_utils_spmd.py +198 -0
  266. sklearnex/tests/test_common.py +54 -0
  267. sklearnex/tests/test_config.py +43 -0
  268. sklearnex/tests/test_memory_usage.py +291 -0
  269. sklearnex/tests/test_monkeypatch.py +276 -0
  270. sklearnex/tests/test_n_jobs_support.py +103 -0
  271. sklearnex/tests/test_parallel.py +48 -0
  272. sklearnex/tests/test_patching.py +385 -0
  273. sklearnex/tests/test_run_to_run_stability.py +296 -0
  274. sklearnex/utils/__init__.py +19 -0
  275. sklearnex/utils/_array_api.py +82 -0
  276. sklearnex/utils/parallel.py +59 -0
  277. sklearnex/utils/tests/test_finite.py +89 -0
  278. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,2016 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+ import warnings
19
+ from abc import ABC
20
+
21
+ import numpy as np
22
+ from scipy import sparse as sp
23
+ from sklearn.base import clone
24
+ from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifier
25
+ from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
26
+ from sklearn.ensemble import RandomForestClassifier as sklearn_RandomForestClassifier
27
+ from sklearn.ensemble import RandomForestRegressor as sklearn_RandomForestRegressor
28
+ from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
29
+ from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
30
+ from sklearn.ensemble._forest import _get_n_samples_bootstrap
31
+ from sklearn.exceptions import DataConversionWarning
32
+ from sklearn.metrics import accuracy_score, r2_score
33
+ from sklearn.tree import (
34
+ DecisionTreeClassifier,
35
+ DecisionTreeRegressor,
36
+ ExtraTreeClassifier,
37
+ ExtraTreeRegressor,
38
+ )
39
+ from sklearn.tree._tree import Tree
40
+ from sklearn.utils import check_random_state, deprecated
41
+ from sklearn.utils.validation import (
42
+ _check_sample_weight,
43
+ check_array,
44
+ check_is_fitted,
45
+ check_X_y,
46
+ )
47
+
48
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
49
+ from daal4py.sklearn._utils import (
50
+ check_tree_nodes,
51
+ daal_check_version,
52
+ sklearn_check_version,
53
+ )
54
+ from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
55
+ from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
56
+ from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
57
+ from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
58
+ from onedal.primitives import get_tree_state_cls, get_tree_state_reg
59
+ from onedal.utils import _num_features, _num_samples
60
+
61
+ from .._device_offload import dispatch, wrap_output_data
62
+ from .._utils import PatchingConditionsChain
63
+ from ..utils._array_api import get_namespace
64
+
65
+ if sklearn_check_version("1.2"):
66
+ from sklearn.utils._param_validation import Interval
67
+ if sklearn_check_version("1.4"):
68
+ from daal4py.sklearn.utils import _assert_all_finite
69
+
70
+
71
class BaseForest(ABC):
    """Shared oneDAL-backed logic for the forest estimators in this module.

    The class is a mixin: it relies on attributes and helpers supplied by the
    scikit-learn forest class it is combined with (for example
    ``_validate_data``, ``_validate_y_class_weight``, ``_validate_estimator``
    and the usual hyper-parameter attributes) and on two hooks supplied by
    subclasses: ``_onedal_factory`` (callable building the oneDAL estimator)
    and ``_get_tree_state`` / ``_err`` (used when exporting trees and when
    out-of-bag scoring is requested).
    """

    # Callable that builds the backend estimator; subclasses are expected to
    # override it.
    _onedal_factory = None

    def _onedal_fit(self, X, y, sample_weight=None, queue=None):
        """Fit the forest through the oneDAL backend estimator.

        Parameters
        ----------
        X, y
            Training data and targets; validated via ``self._validate_data``.
        sample_weight : array-like or None
            Optional per-sample weights. Combined multiplicatively with the
            expanded class weights returned by ``_validate_y_class_weight``.
        queue
            Forwarded unchanged to the backend estimator's ``fit``.

        Returns
        -------
        self
        """
        X, y = self._validate_data(
            X,
            y,
            multi_output=True,
            accept_sparse=False,
            dtype=[np.float64, np.float32],
            force_all_finite=False,
            ensure_2d=True,
        )

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        if y.ndim == 1:
            # reshape is used because it preserves the data contiguity,
            # whereas indexing with [:, np.newaxis] does not.
            y = np.reshape(y, (-1, 1))

        self._n_samples, self.n_outputs_ = y.shape

        y, expanded_class_weight = self._validate_y_class_weight(y)

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight
        if sample_weight is not None:
            # The backend receives the weights wrapped in a one-element list.
            sample_weight = [sample_weight]

        onedal_params = {
            "n_estimators": self.n_estimators,
            "criterion": self.criterion,
            "max_depth": self.max_depth,
            "min_samples_split": self.min_samples_split,
            "min_samples_leaf": self.min_samples_leaf,
            "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
            "max_features": self._to_absolute_max_features(
                self.max_features, self.n_features_in_
            ),
            "max_leaf_nodes": self.max_leaf_nodes,
            "min_impurity_decrease": self.min_impurity_decrease,
            "bootstrap": self.bootstrap,
            "oob_score": self.oob_score,
            "n_jobs": self.n_jobs,
            "random_state": self.random_state,
            "verbose": self.verbose,
            "warm_start": self.warm_start,
            # The error metric is only requested when OOB scoring is enabled.
            "error_metric_mode": self._err if self.oob_score else "none",
            "variable_importance_mode": "mdi",
            "class_weight": self.class_weight,
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "max_samples": self.max_samples,
        }

        if not sklearn_check_version("1.0"):
            onedal_params["min_impurity_split"] = self.min_impurity_split
        else:
            onedal_params["min_impurity_split"] = None

        # Lazy evaluation of estimators_: reset the cache so that the trees
        # are rebuilt from the new model on next access.
        self._cached_estimators_ = None

        # Compute
        self._onedal_estimator = self._onedal_factory(**onedal_params)
        self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)

        self._save_attributes()

        # Decapsulate classes_ attributes: for a single output the per-output
        # containers are replaced by their only element.
        if hasattr(self, "classes_") and self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        return self

    def _save_attributes(self):
        """Copy fitted attributes from the backend estimator onto ``self``.

        Out-of-bag attributes are copied only when ``self.oob_score`` is
        truthy, and only those the backend actually exposes.
        ``_n_samples_bootstrap`` is derived from the backend's
        ``observations_per_tree_fraction`` when bootstrapping, else ``None``.

        Returns
        -------
        self
        """
        if self.oob_score:
            self.oob_score_ = self._onedal_estimator.oob_score_
            if hasattr(self._onedal_estimator, "oob_prediction_"):
                self.oob_prediction_ = self._onedal_estimator.oob_prediction_
            if hasattr(self._onedal_estimator, "oob_decision_function_"):
                self.oob_decision_function_ = (
                    self._onedal_estimator.oob_decision_function_
                )
        if self.bootstrap:
            # Never report fewer than one bootstrap sample.
            self._n_samples_bootstrap = max(
                round(
                    self._onedal_estimator.observations_per_tree_fraction
                    * self._n_samples
                ),
                1,
            )
        else:
            self._n_samples_bootstrap = None
        self._validate_estimator()
        return self

    def _to_absolute_max_features(self, max_features, n_features):
        """Translate ``max_features`` into an absolute number of features.

        Parameters
        ----------
        max_features : None, str, int or float
            ``None`` means all features; ``"sqrt"`` / ``"log2"`` apply the
            corresponding function (at least 1); ``"auto"`` is only honoured
            when ``sklearn_check_version("1.3")`` is false; integers are
            returned unchanged; any other number is treated as a fraction of
            ``n_features``.
        n_features : int
            Total number of features.

        Returns
        -------
        int
            Absolute feature count (``0`` for a non-positive fraction).

        Raises
        ------
        ValueError
            If ``max_features`` is a string that is not accepted.
        """
        if max_features is None:
            return n_features
        if isinstance(max_features, str):
            if max_features == "auto":
                if not sklearn_check_version("1.3"):
                    if sklearn_check_version("1.1"):
                        warnings.warn(
                            "`max_features='auto'` has been deprecated in 1.1 "
                            "and will be removed in 1.3. To keep the past behaviour, "
                            "explicitly set `max_features=1.0` or remove this "
                            "parameter as it is also the default value for "
                            "RandomForestRegressors and ExtraTreesRegressors.",
                            FutureWarning,
                        )
                    # "auto" is sqrt(n_features) for classifiers and all
                    # features otherwise.
                    return (
                        max(1, int(np.sqrt(n_features)))
                        if isinstance(self, ForestClassifier)
                        else n_features
                    )
                # Otherwise "auto" falls through to the ValueError below.
            if max_features == "sqrt":
                return max(1, int(np.sqrt(n_features)))
            if max_features == "log2":
                return max(1, int(np.log2(n_features)))
            allowed_string_values = (
                '"sqrt" or "log2"'
                if sklearn_check_version("1.3")
                else '"auto", "sqrt" or "log2"'
            )
            raise ValueError(
                "Invalid value for max_features. Allowed string "
                f"values are {allowed_string_values}."
            )
        if isinstance(max_features, (numbers.Integral, np.integer)):
            return max_features
        if max_features > 0.0:
            return max(1, int(max_features * n_features))
        return 0

    def _check_parameters(self):
        """Validate hyper-parameters stored on ``self``.

        Raises
        ------
        ValueError
            If any of ``min_samples_leaf``, ``min_samples_split``,
            ``min_weight_fraction_leaf``, ``min_impurity_split``,
            ``min_impurity_decrease``, ``max_leaf_nodes``, ``max_bins`` or
            ``min_bin_size`` is outside its accepted range or type.

        Warns
        -----
        FutureWarning
            If ``min_impurity_split`` is present and set to a value.
        """
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the integer %s" % self.min_samples_split
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the float %s" % self.min_samples_split
                )
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")

        # The attribute may be missing entirely or present but unset (None).
        # Only an explicitly supplied value is deprecated and range-checked;
        # comparing None with 0.0 would raise TypeError.
        min_impurity_split = getattr(self, "min_impurity_split", None)
        if min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            if min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than " "or equal to 0"
                )
        if self.min_impurity_decrease < 0.0:
            raise ValueError(
                "min_impurity_decrease must be greater than " "or equal to 0"
            )
        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" % self.max_leaf_nodes
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
                        self.max_leaf_nodes
                    )
                )
        if isinstance(self.max_bins, numbers.Integral):
            if not 2 <= self.max_bins:
                raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
        else:
            raise ValueError(
                "max_bins must be integral number but was " "%r" % self.max_bins
            )
        if isinstance(self.min_bin_size, numbers.Integral):
            if not 1 <= self.min_bin_size:
                raise ValueError(
                    "min_bin_size must be at least 1, got %s" % self.min_bin_size
                )
        else:
            raise ValueError(
                "min_bin_size must be integral number but was " "%r" % self.min_bin_size
            )

    @property
    def estimators_(self):
        """List of per-tree estimators, built lazily on first access.

        Raises
        ------
        AttributeError
            If no cache slot exists yet (``_cached_estimators_`` was never
            assigned, e.g. before fitting).
        """
        if hasattr(self, "_cached_estimators_"):
            if self._cached_estimators_ is None:
                # Populate the cache from the backend model.
                self._estimators_()
            return self._cached_estimators_
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'estimators_'"
            )

    @estimators_.setter
    def estimators_(self, estimators):
        # Needed to allow for proper sklearn operation in fallback mode
        self._cached_estimators_ = estimators

    def _estimators_(self):
        """Rebuild scikit-learn tree estimators from the backend model.

        One tree estimator is created per backend tree, its ``tree_`` is
        restored from the state returned by ``self._get_tree_state`` and the
        resulting list is stored in ``self._cached_estimators_``.
        """
        # _estimators_ should only be called if _onedal_estimator exists
        check_is_fitted(self, "_onedal_estimator")
        if hasattr(self, "n_classes_"):
            # n_classes_ is either a scalar (already unwrapped) or a
            # per-output sequence. numbers.Integral also matches NumPy
            # integer scalars, which a plain ``int`` check would miss.
            n_classes_ = (
                self.n_classes_
                if isinstance(self.n_classes_, numbers.Integral)
                else self.n_classes_[0]
            )
        else:
            n_classes_ = 1

        # convert model to estimators
        params = {
            "criterion": self._onedal_estimator.criterion,
            "max_depth": self._onedal_estimator.max_depth,
            "min_samples_split": self._onedal_estimator.min_samples_split,
            "min_samples_leaf": self._onedal_estimator.min_samples_leaf,
            "min_weight_fraction_leaf": self._onedal_estimator.min_weight_fraction_leaf,
            "max_features": self._onedal_estimator.max_features,
            "max_leaf_nodes": self._onedal_estimator.max_leaf_nodes,
            "min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
            "random_state": None,
        }
        if not sklearn_check_version("1.0"):
            params["min_impurity_split"] = self._onedal_estimator.min_impurity_split
        est = self.estimator.__class__(**params)
        # we need to set est.tree_ field with Trees constructed from Intel(R)
        # oneAPI Data Analytics Library solution
        estimators_ = []

        random_state_checked = check_random_state(self.random_state)

        for i in range(self._onedal_estimator.n_estimators):
            est_i = clone(est)
            # Each tree gets its own seed drawn from the forest's generator.
            est_i.set_params(
                random_state=random_state_checked.randint(np.iinfo(np.int32).max)
            )
            if sklearn_check_version("1.0"):
                est_i.n_features_in_ = self.n_features_in_
            else:
                est_i.n_features_ = self.n_features_in_
            est_i.n_outputs_ = self.n_outputs_
            est_i.n_classes_ = n_classes_
            tree_i_state_class = self._get_tree_state(
                self._onedal_estimator._onedal_model, i, n_classes_
            )
            tree_i_state_dict = {
                "max_depth": tree_i_state_class.max_depth,
                "node_count": tree_i_state_class.node_count,
                "nodes": check_tree_nodes(tree_i_state_class.node_ar),
                "values": tree_i_state_class.value_ar,
            }
            est_i.tree_ = Tree(
                self.n_features_in_,
                np.array([n_classes_], dtype=np.intp),
                self.n_outputs_,
            )
            est_i.tree_.__setstate__(tree_i_state_dict)
            estimators_.append(est_i)

        self._cached_estimators_ = estimators_

    if sklearn_check_version("1.0"):

        # Deprecated alias of ``n_features_in_``.
        @deprecated(
            "Attribute `n_features_` was deprecated in version 1.0 and will be "
            "removed in 1.2. Use `n_features_in_` instead."
        )
        @property
        def n_features_(self):
            return self.n_features_in_

    if not sklearn_check_version("1.2"):

        # ``base_estimator`` is exposed as a read/write alias of ``estimator``.
        @property
        def base_estimator(self):
            return self.estimator

        @base_estimator.setter
        def base_estimator(self, estimator):
            self.estimator = estimator
395
+
396
+
397
+ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
398
+ # Surprisingly, even though scikit-learn warns against using
399
+ # their ForestClassifier directly, it actually has a more stable
400
+ # API than the user-facing objects (over time). If they change it
401
+ # significantly at some point then this may need to be versioned.
402
+
403
+ _err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
404
+ _get_tree_state = staticmethod(get_tree_state_cls)
405
+
406
def __init__(
    self,
    estimator,
    n_estimators=100,
    *,
    estimator_params=tuple(),
    bootstrap=False,
    oob_score=False,
    n_jobs=None,
    random_state=None,
    verbose=0,
    warm_start=False,
    class_weight=None,
    max_samples=None,
):
    """Initialise the forest and make sure a oneDAL factory is selected.

    All arguments are forwarded unchanged to the parent initialiser. After
    that, ``_onedal_factory`` is aligned with the type of ``estimator``:
    a ``DecisionTreeClassifier`` selects the random-forest backend and an
    ``ExtraTreeClassifier`` selects the extra-trees backend, unless a
    compatible factory is already set.

    Raises
    ------
    TypeError
        If no oneDAL factory could be determined.
    """
    super().__init__(
        estimator,
        n_estimators=n_estimators,
        estimator_params=estimator_params,
        bootstrap=bootstrap,
        oob_score=oob_score,
        n_jobs=n_jobs,
        random_state=random_state,
        verbose=verbose,
        warm_start=warm_start,
        class_weight=class_weight,
        max_samples=max_samples,
    )

    # The estimator is checked against the class attribute for conformance.
    # This should only trigger if the user uses this class directly. In that
    # case the factory is still None, and issubclass(None, ...) would raise
    # TypeError, so None is handled explicitly before the subclass check.
    factory = self._onedal_factory
    estimator_cls = self.estimator.__class__
    if estimator_cls == DecisionTreeClassifier and (
        factory is None or not issubclass(factory, onedal_RandomForestClassifier)
    ):
        self._onedal_factory = onedal_RandomForestClassifier
    elif estimator_cls == ExtraTreeClassifier and (
        factory is None or not issubclass(factory, onedal_ExtraTreesClassifier)
    ):
        self._onedal_factory = onedal_ExtraTreesClassifier

    if self._onedal_factory is None:
        raise TypeError(" oneDAL estimator has not been set.")
448
+
449
def _estimators_(self):
    """Build the cached tree estimators and attach class labels to each.

    The parent implementation fills ``self._cached_estimators_``; this
    override then sets ``classes_`` on every tree in that list.
    """
    super()._estimators_()
    # ``_onedal_fit`` replaces ``classes_`` by its first element when there is
    # a single output, so ``classes_`` may be either a flat array of labels or
    # a per-output container of label arrays. Taking ``[0]`` of the flat form
    # would yield one label instead of the label array, so detect which form
    # is present: a scalar first element means ``classes_`` is already flat.
    first = self.classes_[0]
    classes_ = self.classes_ if np.ndim(first) == 0 else first
    for est in self._cached_estimators_:
        est.classes_ = classes_
454
+
455
def fit(self, X, y, sample_weight=None):
    """Fit the classifier through ``dispatch`` and return ``self``.

    Two candidate implementations are handed to ``dispatch``: this class's
    ``_onedal_fit`` under the ``"onedal"`` key and scikit-learn's
    ``ForestClassifier.fit`` under the ``"sklearn"`` key. Which one runs is
    decided inside ``dispatch`` (defined outside this block).

    Parameters
    ----------
    X, y
        Training data and targets, forwarded as-is.
    sample_weight : array-like or None
        Optional sample weights, forwarded as-is.

    Returns
    -------
    self
    """
    backends = {
        "onedal": self.__class__._onedal_fit,
        "sklearn": sklearn_ForestClassifier.fit,
    }
    dispatch(self, "fit", backends, X, y, sample_weight)
    return self
468
+
469
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
    """Validate ``fit`` arguments and accumulate oneDAL support conditions.

    Parameters
    ----------
    patching_status : PatchingConditionsChain
        Chain the support conditions are added to.
    X, y, sample_weight
        The arguments given to ``fit``.

    Returns
    -------
    tuple
        ``(patching_status, X, y, sample_weight)``. ``X`` and ``y`` are the
        result of ``check_X_y`` (a 1d ``y`` reshaped to one column) when the
        chain is still satisfied at that point, otherwise the inputs as given.

    Raises
    ------
    ValueError
        For a sparse ``y``, for ``oob_score`` without ``bootstrap`` and, while
        the chain is satisfied, for ``max_samples`` without ``bootstrap``.
    """
    if sp.issparse(y):
        raise ValueError("sparse multilabel-indicator for y is not supported.")

    # The parameter-validation entry point depends on the scikit-learn version.
    if sklearn_check_version("1.2"):
        self._validate_params()
    else:
        self._check_parameters()

    if self.oob_score and not self.bootstrap:
        raise ValueError("Out of bag estimation only available if bootstrap=True")

    oob_is_supported = not self.oob_score or daal_check_version((2021, "P", 500))
    parameter_conditions = [
        (
            oob_is_supported,
            "OOB score is only supported starting from 2021.5 version of oneDAL.",
        ),
        (self.warm_start is False, "Warm start is not supported."),
        (
            self.criterion == "gini",
            f"'{self.criterion}' criterion is not supported. "
            "Only 'gini' criterion is supported.",
        ),
        (
            self.ccp_alpha == 0.0,
            f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
        ),
        (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
        (
            self.n_estimators <= 6024,
            "More than 6024 estimators is not supported.",
        ),
    ]
    patching_status.and_conditions(parameter_conditions)

    if self.bootstrap:
        patching_status.and_conditions(
            [
                (
                    self.class_weight != "balanced_subsample",
                    "'balanced_subsample' for class_weight is not supported",
                )
            ]
        )

    if patching_status.get_status() and sklearn_check_version("1.4"):
        try:
            _assert_all_finite(X)
        except ValueError:
            input_is_finite = False
        else:
            input_is_finite = True
        patching_status.and_conditions(
            [
                (input_is_finite, "Non-finite input is not supported."),
                (
                    self.monotonic_cst is None,
                    "Monotonicity constraints are not supported.",
                ),
            ]
        )

    if patching_status.get_status():
        X, y = check_X_y(
            X,
            y,
            multi_output=True,
            accept_sparse=True,
            dtype=[np.float64, np.float32],
            force_all_finite=False,
        )

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        if y.ndim == 1:
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]

        supported_y_dtypes = [np.float32, np.float64, np.int32, np.int64]
        patching_status.and_conditions(
            [
                (
                    self.n_outputs_ == 1,
                    f"Number of outputs ({self.n_outputs_}) is not 1.",
                ),
                (
                    y.dtype in supported_y_dtypes,
                    f"Datatype ({y.dtype}) for y is not supported.",
                ),
            ]
        )
        # TODO: Fix to support integers as input

        # The return value is discarded; presumably the call is made only for
        # its checks on ``max_samples`` -- confirm against scikit-learn.
        _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

        if self.max_samples is not None and not self.bootstrap:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )

    seed_is_ignored = (
        patching_status.get_status()
        and (self.random_state is not None)
        and (not daal_check_version((2024, "P", 0)))
    )
    if seed_is_ignored:
        warnings.warn(
            "Setting 'random_state' value is not supported. "
            "State set by oneDAL to default value (777).",
            RuntimeWarning,
        )

    return patching_status, X, y, sample_weight
592
+
593
@wrap_output_data
def predict(self, X):
    # The docstring is copied from scikit-learn after the method definitions.
    backends = {
        "onedal": self.__class__._onedal_predict,
        "sklearn": sklearn_ForestClassifier.predict,
    }
    return dispatch(self, "predict", backends, X)
604
+
605
@wrap_output_data
def predict_proba(self, X):
    # The docstring is copied from scikit-learn after the method definitions.
    # TODO:
    # _check_proba()
    # self._check_proba()
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)

    # Reject input whose feature count differs from the one recorded on the
    # estimator, before any backend is chosen.
    if hasattr(self, "n_features_in_"):
        try:
            seen_features = _num_features(X)
        except TypeError:
            seen_features = _num_samples(X)
        expected_features = self.n_features_in_
        if seen_features != expected_features:
            raise ValueError(
                f"X has {seen_features} features, "
                f"but {self.__class__.__name__} is expecting "
                f"{expected_features} features as input"
            )

    backends = {
        "onedal": self.__class__._onedal_predict_proba,
        "sklearn": sklearn_ForestClassifier.predict_proba,
    }
    return dispatch(self, "predict_proba", backends, X)
634
+
635
def predict_log_proba(self, X):
    # The docstring is copied from scikit-learn after the method definitions.
    xp, _ = get_namespace(X)
    proba = self.predict_proba(X)

    if self.n_outputs_ != 1:
        # Multi-output: the logarithm is applied to each indexed entry in place.
        for output_idx in range(self.n_outputs_):
            proba[output_idx] = xp.log(proba[output_idx])
        return proba

    return xp.log(proba)
647
+
648
@wrap_output_data
def score(self, X, y, sample_weight=None):
    # The docstring is copied from scikit-learn after the method definitions.
    backends = {
        "onedal": self.__class__._onedal_score,
        "sklearn": sklearn_ForestClassifier.score,
    }
    return dispatch(self, "score", backends, X, y, sample_weight=sample_weight)
661
+
662
+ fit.__doc__ = sklearn_ForestClassifier.fit.__doc__
663
+ predict.__doc__ = sklearn_ForestClassifier.predict.__doc__
664
+ predict_proba.__doc__ = sklearn_ForestClassifier.predict_proba.__doc__
665
+ predict_log_proba.__doc__ = sklearn_ForestClassifier.predict_log_proba.__doc__
666
+ score.__doc__ = sklearn_ForestClassifier.score.__doc__
667
+
668
def _onedal_cpu_supported(self, method_name, *data):
    """Collect the conditions under which ``method_name`` can run on oneDAL (CPU).

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"``, ``"predict_proba"`` or ``"score"``.
    *data
        Positional arguments of the dispatched method.

    Returns
    -------
    PatchingConditionsChain
        The chain holding every evaluated condition.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the names above.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{class_name}.{method_name}"
    )

    if method_name == "fit":
        patching_status, X, y, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        patching_status.and_conditions(
            [
                (
                    daal_check_version((2023, "P", 200))
                    or self.estimator.__class__ == DecisionTreeClassifier,
                    "ExtraTrees only supported starting from oneDAL version 2023.2",
                ),
                (
                    not sp.issparse(sample_weight),
                    "sample_weight is sparse. " "Sparse input is not supported.",
                ),
            ]
        )

    elif method_name in ["predict", "predict_proba", "score"]:
        X = data[0]

        patching_status.and_conditions(
            [
                (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    # Same 2023.2 threshold as the fit branch and as the
                    # message below (the check previously used 2023.1).
                    daal_check_version((2023, "P", 200))
                    or self.estimator.__class__ == DecisionTreeClassifier,
                    "ExtraTrees only supported starting from oneDAL version 2023.2",
                ),
            ]
        )

        if method_name == "predict_proba":
            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2021, "P", 400)),
                        "oneDAL version is lower than 2021.4.",
                    )
                ]
            )

        if hasattr(self, "n_outputs_"):
            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                ]
            )

    else:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    return patching_status
735
+
736
def _onedal_gpu_supported(self, method_name, *data):
    """Collect the conditions under which ``method_name`` can run on oneDAL (GPU).

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"``, ``"predict_proba"`` or ``"score"``.
    *data
        Positional arguments of the dispatched method.

    Returns
    -------
    PatchingConditionsChain
        The chain holding every evaluated condition.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the names above.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{class_name}.{method_name}"
    )

    if method_name == "fit":
        patching_status, X, y, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        patching_status.and_conditions(
            [
                (
                    daal_check_version((2023, "P", 100))
                    or self.estimator.__class__ == DecisionTreeClassifier,
                    "ExtraTrees only supported starting from oneDAL version 2023.1",
                ),
                (
                    not self.oob_score,
                    "oob_scores using r2 or accuracy not implemented.",
                ),
                (sample_weight is None, "sample_weight is not supported."),
            ]
        )

    elif method_name in ["predict", "predict_proba", "score"]:
        X = data[0]

        patching_status.and_conditions(
            [
                (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained"),
                (
                    not sp.issparse(X),
                    "X is sparse. Sparse input is not supported.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    # The version gate concerns ExtraTrees only; random
                    # forests pass it regardless, exactly as in the fit branch.
                    daal_check_version((2023, "P", 100))
                    or self.estimator.__class__ == DecisionTreeClassifier,
                    "ExtraTrees supported starting from oneDAL version 2023.1",
                ),
            ]
        )
        if hasattr(self, "n_outputs_"):
            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                ]
            )

    else:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    return patching_status
795
+
796
def _onedal_predict(self, X, queue=None):
    """Predict class labels with the fitted oneDAL estimator.

    ``X`` is validated first; the oneDAL result is flattened, cast to int64
    and used as indices into ``self.classes_``.
    """
    check_is_fitted(self, "_onedal_estimator")

    accepted_dtypes = [np.float64, np.float32]  # Warning, order of dtype matters
    if sklearn_check_version("1.0"):
        X = self._validate_data(
            X,
            dtype=accepted_dtypes,
            force_all_finite=False,
            reset=False,
            ensure_2d=True,
        )
    else:
        X = check_array(X, dtype=accepted_dtypes, force_all_finite=False)
        self._check_n_features(X, reset=False)

    raw = self._onedal_estimator.predict(X, queue=queue)
    class_indices = raw.ravel().astype(np.int64, casting="unsafe")
    return np.take(self.classes_, class_indices)
817
+
818
def _onedal_predict_proba(self, X, queue=None):
    """Return the ``predict_proba`` result of the fitted oneDAL estimator.

    ``X`` is validated first, the same way as in ``_onedal_predict``.
    """
    check_is_fitted(self, "_onedal_estimator")

    accepted_dtypes = [np.float64, np.float32]  # Warning, order of dtype matters
    if sklearn_check_version("1.0"):
        X = self._validate_data(
            X,
            dtype=accepted_dtypes,
            force_all_finite=False,
            reset=False,
            ensure_2d=True,
        )
    else:
        X = check_array(X, dtype=accepted_dtypes, force_all_finite=False)
        self._check_n_features(X, reset=False)

    return self._onedal_estimator.predict_proba(X, queue=queue)
838
+
839
def _onedal_score(self, X, y, sample_weight=None, queue=None):
    """Return ``accuracy_score`` of the oneDAL predictions for ``X`` against ``y``."""
    predicted = self._onedal_predict(X, queue=queue)
    return accuracy_score(y, predicted, sample_weight=sample_weight)
843
+
844
+
845
class ForestRegressor(sklearn_ForestRegressor, BaseForest):
    # Regression forest that dispatches fit/predict/score either to a oneDAL
    # implementation or to the stock scikit-learn one.

    _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
    _get_tree_state = staticmethod(get_tree_state_reg)

    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
    ):
        """Initialise the forest and select the matching oneDAL factory.

        All arguments are forwarded unchanged to the parent ``__init__``.

        Raises
        ------
        TypeError
            If no oneDAL factory is set on the class and none can be derived
            from the type of ``estimator``.
        """
        super().__init__(
            estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            max_samples=max_samples,
        )

        # The splitter is checked against the class attribute for conformance
        # This should only trigger if the user uses this class directly.
        # ``issubclass`` raises TypeError for a non-class first argument (for
        # example ``None``), so an unset factory is treated as non-conforming
        # instead of being handed to ``issubclass``.
        factory = self._onedal_factory
        factory_is_class = isinstance(factory, type)
        estimator_cls = self.estimator.__class__

        if estimator_cls == DecisionTreeRegressor and not (
            factory_is_class and issubclass(factory, onedal_RandomForestRegressor)
        ):
            self._onedal_factory = onedal_RandomForestRegressor
        elif estimator_cls == ExtraTreeRegressor and not (
            factory_is_class and issubclass(factory, onedal_ExtraTreesRegressor)
        ):
            self._onedal_factory = onedal_ExtraTreesRegressor

        if self._onedal_factory is None:
            raise TypeError(" oneDAL estimator has not been set.")

    def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
        """Validate ``fit`` arguments and accumulate oneDAL support conditions.

        Returns ``(patching_status, X, y, sample_weight)``. ``X`` and ``y`` are
        the result of ``check_X_y`` (a 1d ``y`` reshaped to one column) when
        the chain is still satisfied at that point, otherwise the inputs as
        given.

        Raises
        ------
        ValueError
            For a sparse ``y``, for ``oob_score`` without ``bootstrap`` and,
            while the chain is satisfied, for ``max_samples`` without
            ``bootstrap``.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        # The parameter-validation entry point depends on the scikit-learn version.
        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        if sklearn_check_version("1.0") and self.criterion == "mse":
            warnings.warn(
                "Criterion 'mse' was deprecated in v1.0 and will be "
                "removed in version 1.2. Use `criterion='squared_error'` "
                "which is equivalent.",
                FutureWarning,
            )

        patching_status.and_conditions(
            [
                (
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion in ["mse", "squared_error"],
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'mse' and 'squared_error' criteria are supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    self.n_estimators <= 6024,
                    "More than 6024 estimators is not supported.",
                ),
            ]
        )

        if patching_status.get_status() and sklearn_check_version("1.4"):
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            patching_status.and_conditions(
                [
                    (input_is_finite, "Non-finite input is not supported."),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if patching_status.get_status():
            X, y = check_X_y(
                X,
                y,
                multi_output=True,
                accept_sparse=True,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
            )

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            if y.ndim == 1:
                # reshape is necessary to preserve the data contiguity against vs
                # [:, np.newaxis] that does not.
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]

            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

            # Sklearn function used for doing checks on max_samples attribute
            _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

            if not self.bootstrap and self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )

        if (
            patching_status.get_status()
            and (self.random_state is not None)
            and (not daal_check_version((2024, "P", 0)))
        ):
            warnings.warn(
                "Setting 'random_state' value is not supported. "
                "State set by oneDAL to default value (777).",
                RuntimeWarning,
            )

        return patching_status, X, y, sample_weight

    def _onedal_cpu_supported(self, method_name, *data):
        """Collect the conditions under which ``method_name`` can run on oneDAL (CPU).

        ``method_name`` must be ``"fit"``, ``"predict"`` or ``"score"``; any
        other value raises ``RuntimeError``. Returns the conditions chain.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, X, y, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        # Only ExtraTrees is version-gated. The random-forest
                        # base estimator of a regressor is
                        # DecisionTreeRegressor, never the classifier tree.
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                    (
                        not sp.issparse(sample_weight),
                        "sample_weight is sparse. " "Sparse input is not supported.",
                    ),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                ]
            )
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_gpu_supported(self, method_name, *data):
        """Collect the conditions under which ``method_name`` can run on oneDAL (GPU).

        ``method_name`` must be ``"fit"``, ``"predict"`` or ``"score"``; any
        other value raises ``RuntimeError``. Returns the conditions chain.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, X, y, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        # See the CPU variant: the exemption has to name the
                        # regressor tree to have any effect.
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                    (not self.oob_score, "oob_score value is not sklearn conformant."),
                    (sample_weight is None, "sample_weight is not supported."),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                ]
            )
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_predict(self, X, queue=None):
        """Validate ``X`` and return the fitted oneDAL estimator's prediction."""
        check_is_fitted(self, "_onedal_estimator")

        if sklearn_check_version("1.0"):
            X = self._validate_data(
                X,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
                reset=False,
                ensure_2d=True,
            )  # Warning, order of dtype matters
        else:
            X = check_array(
                X, dtype=[np.float64, np.float32], force_all_finite=False
            )  # Warning, order of dtype matters

        return self._onedal_estimator.predict(X, queue=queue)

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Return ``r2_score`` of the oneDAL predictions for ``X`` against ``y``."""
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    # The public methods below get their docstrings from scikit-learn at the
    # end of the class body, so they carry comments instead of docstrings.

    def fit(self, X, y, sample_weight=None):
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_ForestRegressor.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_ForestRegressor.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": sklearn_ForestRegressor.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    fit.__doc__ = sklearn_ForestRegressor.fit.__doc__
    predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
    score.__doc__ = sklearn_ForestRegressor.score.__doc__
1188
+
1189
+
1190
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class RandomForestClassifier(ForestClassifier):
    __doc__ = sklearn_RandomForestClassifier.__doc__
    _onedal_factory = onedal_RandomForestClassifier

    if sklearn_check_version("1.2"):
        # scikit-learn's own constraints plus the two extra binning parameters.
        _parameter_constraints: dict = {
            **sklearn_RandomForestClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # One ``__init__`` is defined per supported scikit-learn signature:
    # >= 1.4 adds ``monotonic_cst``; < 1.0 still has ``min_impurity_split``.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            # Each binning parameter is stored once (it used to be assigned twice).
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1398
+
1399
+
1400
@control_n_jobs(decorated_methods=["fit", "predict"])
class RandomForestRegressor(ForestRegressor):
    __doc__ = sklearn_RandomForestRegressor.__doc__
    _onedal_factory = onedal_RandomForestRegressor

    if sklearn_check_version("1.2"):
        # scikit-learn's own constraints plus the two extra binning parameters.
        _parameter_constraints: dict = {
            **sklearn_RandomForestRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # One ``__init__`` is defined per supported scikit-learn signature:
    # >= 1.4 adds ``monotonic_cst``; < 1.0 still has ``min_impurity_split``.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # These two must be separate entries: without the comma
                    # the adjacent literals merge into one bogus name.
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1599
+
1600
+
1601
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class ExtraTreesClassifier(ForestClassifier):
    # The class documentation is taken verbatim from scikit-learn's
    # ExtraTreesClassifier.
    __doc__ = sklearn_ExtraTreesClassifier.__doc__
    _onedal_factory = onedal_ExtraTreesClassifier

    if sklearn_check_version("1.2"):
        # Start from scikit-learn's own constraints and add entries for the
        # two extra constructor arguments of this class: ``max_bins`` must be
        # an integer >= 2 and ``min_bin_size`` an integer >= 1.
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # Exactly one ``__init__`` is defined, chosen at class-creation time by
    # the scikit-learn version checks below. The variants differ only in the
    # arguments they accept and forward:
    #   * "1.4" branch   -- adds ``monotonic_cst``;
    #   * "1.0" branch   -- no ``monotonic_cst``; the ``max_features`` default
    #                       depends on the "1.1" check;
    #   * fallback       -- additionally accepts ``min_impurity_split``.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Arguments not forwarded to the parent constructor are stored
            # unchanged under the same attribute names.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            # Each extra argument is stored exactly once.
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1809
+
1810
+
1811
@control_n_jobs(decorated_methods=["fit", "predict"])
class ExtraTreesRegressor(ForestRegressor):
    # The class documentation is taken verbatim from scikit-learn's
    # ExtraTreesRegressor.
    __doc__ = sklearn_ExtraTreesRegressor.__doc__
    _onedal_factory = onedal_ExtraTreesRegressor

    if sklearn_check_version("1.2"):
        # Start from scikit-learn's own constraints and add entries for the
        # two extra constructor arguments of this class: ``max_bins`` must be
        # an integer >= 2 and ``min_bin_size`` an integer >= 1.
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # Exactly one ``__init__`` is defined, chosen at class-creation time by
    # the scikit-learn version checks below. The variants differ only in the
    # arguments they accept and forward:
    #   * "1.4" branch   -- adds ``monotonic_cst``;
    #   * "1.0" branch   -- no ``monotonic_cst``; the ``max_features`` default
    #                       depends on the "1.1" check;
    #   * fallback       -- ``criterion`` defaults to "mse" and
    #                       ``min_impurity_split`` is additionally accepted.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Arguments not forwarded to the parent constructor are stored
            # unchanged under the same attribute names.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                # Each name must be its own tuple element: a missing comma
                # silently concatenates two adjacent string literals into
                # one bogus name.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
2010
+
2011
+
2012
+ # Register these estimators as virtual subclasses of the stock scikit-learn classes (ABCMeta.register) so isinstance()/issubclass() checks against the scikit-learn types succeed without changing the inheritance chain
2013
+ sklearn_RandomForestClassifier.register(RandomForestClassifier)
2014
+ sklearn_RandomForestRegressor.register(RandomForestRegressor)
2015
+ sklearn_ExtraTreesClassifier.register(ExtraTreesClassifier)
2016
+ sklearn_ExtraTreesRegressor.register(ExtraTreesRegressor)