scikit-learn-intelex 2025.1.0__py39-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (280)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-39-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-39-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +222 -0
  62. onedal/_onedal_py_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-39-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +564 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +125 -0
  83. onedal/common/tests/test_policy.py +76 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +154 -0
  91. onedal/datatypes/tests/common.py +126 -0
  92. onedal/datatypes/tests/test_data.py +414 -0
  93. onedal/decomposition/__init__.py +20 -0
  94. onedal/decomposition/incremental_pca.py +204 -0
  95. onedal/decomposition/pca.py +186 -0
  96. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  97. onedal/ensemble/__init__.py +29 -0
  98. onedal/ensemble/forest.py +727 -0
  99. onedal/ensemble/tests/test_random_forest.py +97 -0
  100. onedal/linear_model/__init__.py +27 -0
  101. onedal/linear_model/incremental_linear_model.py +258 -0
  102. onedal/linear_model/linear_model.py +329 -0
  103. onedal/linear_model/logistic_regression.py +249 -0
  104. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  105. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  106. onedal/linear_model/tests/test_linear_regression.py +250 -0
  107. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  108. onedal/linear_model/tests/test_ridge.py +95 -0
  109. onedal/neighbors/__init__.py +19 -0
  110. onedal/neighbors/neighbors.py +767 -0
  111. onedal/neighbors/tests/test_knn_classification.py +49 -0
  112. onedal/primitives/__init__.py +27 -0
  113. onedal/primitives/get_tree.py +25 -0
  114. onedal/primitives/kernel_functions.py +153 -0
  115. onedal/primitives/tests/test_kernel_functions.py +159 -0
  116. onedal/spmd/__init__.py +25 -0
  117. onedal/spmd/_base.py +30 -0
  118. onedal/spmd/basic_statistics/__init__.py +20 -0
  119. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  120. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  121. onedal/spmd/cluster/__init__.py +28 -0
  122. onedal/spmd/cluster/dbscan.py +23 -0
  123. onedal/spmd/cluster/kmeans.py +56 -0
  124. onedal/spmd/covariance/__init__.py +20 -0
  125. onedal/spmd/covariance/covariance.py +26 -0
  126. onedal/spmd/covariance/incremental_covariance.py +82 -0
  127. onedal/spmd/decomposition/__init__.py +20 -0
  128. onedal/spmd/decomposition/incremental_pca.py +117 -0
  129. onedal/spmd/decomposition/pca.py +26 -0
  130. onedal/spmd/ensemble/__init__.py +19 -0
  131. onedal/spmd/ensemble/forest.py +28 -0
  132. onedal/spmd/linear_model/__init__.py +21 -0
  133. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  134. onedal/spmd/linear_model/linear_model.py +30 -0
  135. onedal/spmd/linear_model/logistic_regression.py +38 -0
  136. onedal/spmd/neighbors/__init__.py +19 -0
  137. onedal/spmd/neighbors/neighbors.py +75 -0
  138. onedal/svm/__init__.py +19 -0
  139. onedal/svm/svm.py +556 -0
  140. onedal/svm/tests/test_csr_svm.py +351 -0
  141. onedal/svm/tests/test_nusvc.py +204 -0
  142. onedal/svm/tests/test_nusvr.py +210 -0
  143. onedal/svm/tests/test_svc.py +176 -0
  144. onedal/svm/tests/test_svr.py +243 -0
  145. onedal/tests/test_common.py +57 -0
  146. onedal/tests/utils/_dataframes_support.py +162 -0
  147. onedal/tests/utils/_device_selection.py +102 -0
  148. onedal/utils/__init__.py +49 -0
  149. onedal/utils/_array_api.py +81 -0
  150. onedal/utils/_dpep_helpers.py +56 -0
  151. onedal/utils/validation.py +440 -0
  152. scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
  153. scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
  154. scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
  155. scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
  156. scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
  157. sklearnex/__init__.py +66 -0
  158. sklearnex/__main__.py +58 -0
  159. sklearnex/_config.py +116 -0
  160. sklearnex/_device_offload.py +126 -0
  161. sklearnex/_utils.py +132 -0
  162. sklearnex/basic_statistics/__init__.py +20 -0
  163. sklearnex/basic_statistics/basic_statistics.py +230 -0
  164. sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  165. sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  166. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  167. sklearnex/cluster/__init__.py +20 -0
  168. sklearnex/cluster/dbscan.py +197 -0
  169. sklearnex/cluster/k_means.py +395 -0
  170. sklearnex/cluster/tests/test_dbscan.py +38 -0
  171. sklearnex/cluster/tests/test_kmeans.py +159 -0
  172. sklearnex/conftest.py +82 -0
  173. sklearnex/covariance/__init__.py +19 -0
  174. sklearnex/covariance/incremental_covariance.py +398 -0
  175. sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  176. sklearnex/decomposition/__init__.py +19 -0
  177. sklearnex/decomposition/pca.py +425 -0
  178. sklearnex/decomposition/tests/test_pca.py +58 -0
  179. sklearnex/dispatcher.py +543 -0
  180. sklearnex/doc/third-party-programs.txt +424 -0
  181. sklearnex/ensemble/__init__.py +29 -0
  182. sklearnex/ensemble/_forest.py +2029 -0
  183. sklearnex/ensemble/tests/test_forest.py +135 -0
  184. sklearnex/glob/__main__.py +72 -0
  185. sklearnex/glob/dispatcher.py +101 -0
  186. sklearnex/linear_model/__init__.py +32 -0
  187. sklearnex/linear_model/coordinate_descent.py +30 -0
  188. sklearnex/linear_model/incremental_linear.py +482 -0
  189. sklearnex/linear_model/incremental_ridge.py +425 -0
  190. sklearnex/linear_model/linear.py +341 -0
  191. sklearnex/linear_model/logistic_regression.py +413 -0
  192. sklearnex/linear_model/ridge.py +24 -0
  193. sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  194. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  195. sklearnex/linear_model/tests/test_linear.py +167 -0
  196. sklearnex/linear_model/tests/test_logreg.py +134 -0
  197. sklearnex/manifold/__init__.py +19 -0
  198. sklearnex/manifold/t_sne.py +21 -0
  199. sklearnex/manifold/tests/test_tsne.py +26 -0
  200. sklearnex/metrics/__init__.py +23 -0
  201. sklearnex/metrics/pairwise.py +22 -0
  202. sklearnex/metrics/ranking.py +20 -0
  203. sklearnex/metrics/tests/test_metrics.py +39 -0
  204. sklearnex/model_selection/__init__.py +21 -0
  205. sklearnex/model_selection/split.py +22 -0
  206. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  207. sklearnex/neighbors/__init__.py +27 -0
  208. sklearnex/neighbors/_lof.py +236 -0
  209. sklearnex/neighbors/common.py +310 -0
  210. sklearnex/neighbors/knn_classification.py +231 -0
  211. sklearnex/neighbors/knn_regression.py +207 -0
  212. sklearnex/neighbors/knn_unsupervised.py +178 -0
  213. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  214. sklearnex/preview/__init__.py +17 -0
  215. sklearnex/preview/covariance/__init__.py +19 -0
  216. sklearnex/preview/covariance/covariance.py +138 -0
  217. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  218. sklearnex/preview/decomposition/__init__.py +19 -0
  219. sklearnex/preview/decomposition/incremental_pca.py +233 -0
  220. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  221. sklearnex/preview/linear_model/__init__.py +19 -0
  222. sklearnex/preview/linear_model/ridge.py +424 -0
  223. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  224. sklearnex/spmd/__init__.py +25 -0
  225. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  226. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  227. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  228. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  229. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  230. sklearnex/spmd/cluster/__init__.py +30 -0
  231. sklearnex/spmd/cluster/dbscan.py +50 -0
  232. sklearnex/spmd/cluster/kmeans.py +21 -0
  233. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  234. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  235. sklearnex/spmd/covariance/__init__.py +20 -0
  236. sklearnex/spmd/covariance/covariance.py +21 -0
  237. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  238. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  239. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  240. sklearnex/spmd/decomposition/__init__.py +20 -0
  241. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  242. sklearnex/spmd/decomposition/pca.py +21 -0
  243. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  244. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  245. sklearnex/spmd/ensemble/__init__.py +19 -0
  246. sklearnex/spmd/ensemble/forest.py +71 -0
  247. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  248. sklearnex/spmd/linear_model/__init__.py +21 -0
  249. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  250. sklearnex/spmd/linear_model/linear_model.py +21 -0
  251. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  252. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  253. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  254. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  255. sklearnex/spmd/neighbors/__init__.py +19 -0
  256. sklearnex/spmd/neighbors/neighbors.py +25 -0
  257. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  258. sklearnex/svm/__init__.py +29 -0
  259. sklearnex/svm/_common.py +339 -0
  260. sklearnex/svm/nusvc.py +371 -0
  261. sklearnex/svm/nusvr.py +170 -0
  262. sklearnex/svm/svc.py +399 -0
  263. sklearnex/svm/svr.py +167 -0
  264. sklearnex/svm/tests/test_svm.py +93 -0
  265. sklearnex/tests/test_common.py +390 -0
  266. sklearnex/tests/test_config.py +123 -0
  267. sklearnex/tests/test_memory_usage.py +379 -0
  268. sklearnex/tests/test_monkeypatch.py +276 -0
  269. sklearnex/tests/test_n_jobs_support.py +108 -0
  270. sklearnex/tests/test_parallel.py +48 -0
  271. sklearnex/tests/test_patching.py +385 -0
  272. sklearnex/tests/test_run_to_run_stability.py +321 -0
  273. sklearnex/tests/utils/__init__.py +44 -0
  274. sklearnex/tests/utils/base.py +371 -0
  275. sklearnex/tests/utils/spmd.py +198 -0
  276. sklearnex/utils/__init__.py +19 -0
  277. sklearnex/utils/_array_api.py +82 -0
  278. sklearnex/utils/parallel.py +59 -0
  279. sklearnex/utils/tests/test_finite.py +89 -0
  280. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,2029 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+ import warnings
19
+ from abc import ABC
20
+
21
+ import numpy as np
22
+ from scipy import sparse as sp
23
+ from sklearn.base import BaseEstimator, clone
24
+ from sklearn.ensemble import ExtraTreesClassifier as _sklearn_ExtraTreesClassifier
25
+ from sklearn.ensemble import ExtraTreesRegressor as _sklearn_ExtraTreesRegressor
26
+ from sklearn.ensemble import RandomForestClassifier as _sklearn_RandomForestClassifier
27
+ from sklearn.ensemble import RandomForestRegressor as _sklearn_RandomForestRegressor
28
+ from sklearn.ensemble._forest import ForestClassifier as _sklearn_ForestClassifier
29
+ from sklearn.ensemble._forest import ForestRegressor as _sklearn_ForestRegressor
30
+ from sklearn.ensemble._forest import _get_n_samples_bootstrap
31
+ from sklearn.exceptions import DataConversionWarning
32
+ from sklearn.metrics import accuracy_score, r2_score
33
+ from sklearn.tree import (
34
+ DecisionTreeClassifier,
35
+ DecisionTreeRegressor,
36
+ ExtraTreeClassifier,
37
+ ExtraTreeRegressor,
38
+ )
39
+ from sklearn.tree._tree import Tree
40
+ from sklearn.utils import check_random_state, deprecated
41
+ from sklearn.utils.validation import (
42
+ _check_sample_weight,
43
+ check_array,
44
+ check_is_fitted,
45
+ check_X_y,
46
+ )
47
+
48
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
49
+ from daal4py.sklearn._utils import (
50
+ check_tree_nodes,
51
+ daal_check_version,
52
+ sklearn_check_version,
53
+ )
54
+ from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
55
+ from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
56
+ from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
57
+ from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
58
+ from onedal.primitives import get_tree_state_cls, get_tree_state_reg
59
+ from onedal.utils import _num_features, _num_samples
60
+ from sklearnex import get_hyperparameters
61
+ from sklearnex._utils import register_hyperparameters
62
+
63
+ from .._device_offload import dispatch, wrap_output_data
64
+ from .._utils import PatchingConditionsChain
65
+ from ..utils._array_api import get_namespace
66
+
67
+ if sklearn_check_version("1.2"):
68
+ from sklearn.utils._param_validation import Interval
69
+ if sklearn_check_version("1.4"):
70
+ from daal4py.sklearn.utils import _assert_all_finite
71
+
72
+ if sklearn_check_version("1.6"):
73
+ from sklearn.utils.validation import validate_data
74
+ else:
75
+ validate_data = BaseEstimator._validate_data
76
+
77
+
78
+ class BaseForest(ABC):
79
+ _onedal_factory = None
80
+
81
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
    """Fit the forest through the oneDAL backend.

    Validates ``X``/``y``, reshapes ``y`` into a 2D column layout, folds the
    expanded class weights into ``sample_weight``, builds the backend
    estimator via ``self._onedal_factory`` and copies its fitted attributes
    back onto ``self``.

    Parameters
    ----------
    X : array-like
        Training data; validated as dense, 2D, float64 or float32.
    y : array-like
        Target values, 1D or 2D (``multi_output=True`` during validation).
    sample_weight : array-like, default=None
        Optional sample weights, checked with ``_check_sample_weight``.
    queue : object, default=None
        Passed through unchanged to the backend estimator's ``fit``.

    Returns
    -------
    self
    """
    # scikit-learn 1.6 renamed ``force_all_finite`` to ``ensure_all_finite``
    # and emits a FutureWarning for the old spelling, so select the keyword
    # that matches the installed version.
    finite_kwarg = (
        "ensure_all_finite" if sklearn_check_version("1.6") else "force_all_finite"
    )
    X, y = validate_data(
        self,
        X,
        y,
        multi_output=True,
        accept_sparse=False,
        dtype=[np.float64, np.float32],
        ensure_2d=True,
        **{finite_kwarg: False},
    )

    if sample_weight is not None:
        sample_weight = _check_sample_weight(sample_weight, X)

    if y.ndim == 2 and y.shape[1] == 1:
        warnings.warn(
            "A column-vector y was passed when a 1d array was"
            " expected. Please change the shape of y to "
            "(n_samples,), for example using ravel().",
            DataConversionWarning,
            stacklevel=2,
        )

    if y.ndim == 1:
        # reshape is necessary to preserve the data contiguity against vs
        # [:, np.newaxis] that does not.
        y = np.reshape(y, (-1, 1))

    self._n_samples, self.n_outputs_ = y.shape

    y, expanded_class_weight = self._validate_y_class_weight(y)

    # Combine user-provided sample weights with the expanded class weights.
    if expanded_class_weight is not None:
        if sample_weight is not None:
            sample_weight = sample_weight * expanded_class_weight
        else:
            sample_weight = expanded_class_weight
    if sample_weight is not None:
        # The backend ``fit`` call below receives the weights wrapped in a list.
        sample_weight = [sample_weight]

    onedal_params = {
        "n_estimators": self.n_estimators,
        "criterion": self.criterion,
        "max_depth": self.max_depth,
        "min_samples_split": self.min_samples_split,
        "min_samples_leaf": self.min_samples_leaf,
        "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
        # ``n_features_in_`` is expected to be available after validation.
        "max_features": self._to_absolute_max_features(
            self.max_features, self.n_features_in_
        ),
        "max_leaf_nodes": self.max_leaf_nodes,
        "min_impurity_decrease": self.min_impurity_decrease,
        "bootstrap": self.bootstrap,
        "oob_score": self.oob_score,
        "n_jobs": self.n_jobs,
        "random_state": self.random_state,
        "verbose": self.verbose,
        "warm_start": self.warm_start,
        # OOB error metrics are only requested when an OOB score is wanted.
        "error_metric_mode": self._err if self.oob_score else "none",
        "variable_importance_mode": "mdi",
        "class_weight": self.class_weight,
        "max_bins": self.max_bins,
        "min_bin_size": self.min_bin_size,
        "max_samples": self.max_samples,
    }

    if not sklearn_check_version("1.0"):
        onedal_params["min_impurity_split"] = self.min_impurity_split
    else:
        onedal_params["min_impurity_split"] = None

    # Lazy evaluation of estimators_
    self._cached_estimators_ = None

    # Compute
    self._onedal_estimator = self._onedal_factory(**onedal_params)
    self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)

    self._save_attributes()

    # Decapsulate classes_ attributes
    if hasattr(self, "classes_") and self.n_outputs_ == 1:
        self.n_classes_ = self.n_classes_[0]
        self.classes_ = self.classes_[0]

    return self
168
+
169
+ def _save_attributes(self):
170
+ if self.oob_score:
171
+ self.oob_score_ = self._onedal_estimator.oob_score_
172
+ if hasattr(self._onedal_estimator, "oob_prediction_"):
173
+ self.oob_prediction_ = self._onedal_estimator.oob_prediction_
174
+ if hasattr(self._onedal_estimator, "oob_decision_function_"):
175
+ self.oob_decision_function_ = (
176
+ self._onedal_estimator.oob_decision_function_
177
+ )
178
+ if self.bootstrap:
179
+ self._n_samples_bootstrap = max(
180
+ round(
181
+ self._onedal_estimator.observations_per_tree_fraction
182
+ * self._n_samples
183
+ ),
184
+ 1,
185
+ )
186
+ else:
187
+ self._n_samples_bootstrap = None
188
+ self._validate_estimator()
189
+ return self
190
+
191
+ def _to_absolute_max_features(self, max_features, n_features):
192
+ if max_features is None:
193
+ return n_features
194
+ if isinstance(max_features, str):
195
+ if max_features == "auto":
196
+ if not sklearn_check_version("1.3"):
197
+ if sklearn_check_version("1.1"):
198
+ warnings.warn(
199
+ "`max_features='auto'` has been deprecated in 1.1 "
200
+ "and will be removed in 1.3. To keep the past behaviour, "
201
+ "explicitly set `max_features=1.0` or remove this "
202
+ "parameter as it is also the default value for "
203
+ "RandomForestRegressors and ExtraTreesRegressors.",
204
+ FutureWarning,
205
+ )
206
+ return (
207
+ max(1, int(np.sqrt(n_features)))
208
+ if isinstance(self, ForestClassifier)
209
+ else n_features
210
+ )
211
+ if max_features == "sqrt":
212
+ return max(1, int(np.sqrt(n_features)))
213
+ if max_features == "log2":
214
+ return max(1, int(np.log2(n_features)))
215
+ allowed_string_values = (
216
+ '"sqrt" or "log2"'
217
+ if sklearn_check_version("1.3")
218
+ else '"auto", "sqrt" or "log2"'
219
+ )
220
+ raise ValueError(
221
+ "Invalid value for max_features. Allowed string "
222
+ f"values are {allowed_string_values}."
223
+ )
224
+ if isinstance(max_features, (numbers.Integral, np.integer)):
225
+ return max_features
226
+ if max_features > 0.0:
227
+ return max(1, int(max_features * n_features))
228
+ return 0
229
+
230
+ def _check_parameters(self):
231
+ if isinstance(self.min_samples_leaf, numbers.Integral):
232
+ if not 1 <= self.min_samples_leaf:
233
+ raise ValueError(
234
+ "min_samples_leaf must be at least 1 "
235
+ "or in (0, 0.5], got %s" % self.min_samples_leaf
236
+ )
237
+ else: # float
238
+ if not 0.0 < self.min_samples_leaf <= 0.5:
239
+ raise ValueError(
240
+ "min_samples_leaf must be at least 1 "
241
+ "or in (0, 0.5], got %s" % self.min_samples_leaf
242
+ )
243
+ if isinstance(self.min_samples_split, numbers.Integral):
244
+ if not 2 <= self.min_samples_split:
245
+ raise ValueError(
246
+ "min_samples_split must be an integer "
247
+ "greater than 1 or a float in (0.0, 1.0]; "
248
+ "got the integer %s" % self.min_samples_split
249
+ )
250
+ else: # float
251
+ if not 0.0 < self.min_samples_split <= 1.0:
252
+ raise ValueError(
253
+ "min_samples_split must be an integer "
254
+ "greater than 1 or a float in (0.0, 1.0]; "
255
+ "got the float %s" % self.min_samples_split
256
+ )
257
+ if not 0 <= self.min_weight_fraction_leaf <= 0.5:
258
+ raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
259
+ if hasattr(self, "min_impurity_split"):
260
+ warnings.warn(
261
+ "The min_impurity_split parameter is deprecated. "
262
+ "Its default value has changed from 1e-7 to 0 in "
263
+ "version 0.23, and it will be removed in 0.25. "
264
+ "Use the min_impurity_decrease parameter instead.",
265
+ FutureWarning,
266
+ )
267
+
268
+ if getattr(self, "min_impurity_split") < 0.0:
269
+ raise ValueError(
270
+ "min_impurity_split must be greater than " "or equal to 0"
271
+ )
272
+ if self.min_impurity_decrease < 0.0:
273
+ raise ValueError(
274
+ "min_impurity_decrease must be greater than " "or equal to 0"
275
+ )
276
+ if self.max_leaf_nodes is not None:
277
+ if not isinstance(self.max_leaf_nodes, numbers.Integral):
278
+ raise ValueError(
279
+ "max_leaf_nodes must be integral number but was "
280
+ "%r" % self.max_leaf_nodes
281
+ )
282
+ if self.max_leaf_nodes < 2:
283
+ raise ValueError(
284
+ ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
285
+ self.max_leaf_nodes
286
+ )
287
+ )
288
+ if isinstance(self.max_bins, numbers.Integral):
289
+ if not 2 <= self.max_bins:
290
+ raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
291
+ else:
292
+ raise ValueError(
293
+ "max_bins must be integral number but was " "%r" % self.max_bins
294
+ )
295
+ if isinstance(self.min_bin_size, numbers.Integral):
296
+ if not 1 <= self.min_bin_size:
297
+ raise ValueError(
298
+ "min_bin_size must be at least 1, got %s" % self.min_bin_size
299
+ )
300
+ else:
301
+ raise ValueError(
302
+ "min_bin_size must be integral number but was " "%r" % self.min_bin_size
303
+ )
304
+
305
@property
def estimators_(self):
    """Per-tree estimators, built on first access.

    Raises ``AttributeError`` while ``_cached_estimators_`` has not been
    created; otherwise populates the cache through ``self._estimators_()``
    when it is still ``None`` and returns it.
    """
    if not hasattr(self, "_cached_estimators_"):
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute 'estimators_'"
        )
    if self._cached_estimators_ is None:
        self._estimators_()
    return self._cached_estimators_


@estimators_.setter
def estimators_(self, estimators):
    """Store ``estimators`` directly in the cache."""
    # Needed to allow for proper sklearn operation in fallback mode
    self._cached_estimators_ = estimators
320
+
321
def _estimators_(self):
    """Build per-tree scikit-learn estimators from the fitted backend model.

    One estimator of the same class as ``self.estimator`` is created per
    backend tree, given a ``Tree`` restored from the backend's tree state,
    and the resulting list is stored in ``self._cached_estimators_``.
    Nothing is returned.
    """
    # _estimators_ should only be called if _onedal_estimator exists
    check_is_fitted(self, "_onedal_estimator")
    if hasattr(self, "n_classes_"):
        # n_classes_ is either a plain int or an indexable whose first
        # element is used.
        n_classes_ = (
            self.n_classes_
            if isinstance(self.n_classes_, int)
            else self.n_classes_[0]
        )
    else:
        n_classes_ = 1

    # convert model to estimators
    params = {
        "criterion": self._onedal_estimator.criterion,
        "max_depth": self._onedal_estimator.max_depth,
        "min_samples_split": self._onedal_estimator.min_samples_split,
        "min_samples_leaf": self._onedal_estimator.min_samples_leaf,
        "min_weight_fraction_leaf": self._onedal_estimator.min_weight_fraction_leaf,
        "max_features": self._onedal_estimator.max_features,
        "max_leaf_nodes": self._onedal_estimator.max_leaf_nodes,
        "min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
        "random_state": None,
    }
    if not sklearn_check_version("1.0"):
        params["min_impurity_split"] = self._onedal_estimator.min_impurity_split
    # Template estimator; every tree below is a clone of it.
    est = self.estimator.__class__(**params)
    # we need to set est.tree_ field with Trees constructed from Intel(R)
    # oneAPI Data Analytics Library solution
    estimators_ = []

    random_state_checked = check_random_state(self.random_state)

    for i in range(self._onedal_estimator.n_estimators):
        est_i = clone(est)
        # Each tree gets its own seed drawn from the checked random state.
        est_i.set_params(
            random_state=random_state_checked.randint(np.iinfo(np.int32).max)
        )
        if sklearn_check_version("1.0"):
            est_i.n_features_in_ = self.n_features_in_
        else:
            est_i.n_features_ = self.n_features_in_
        est_i.n_outputs_ = self.n_outputs_
        est_i.n_classes_ = n_classes_
        # Fetch the state of the i-th tree from the backend model.
        tree_i_state_class = self._get_tree_state(
            self._onedal_estimator._onedal_model, i, n_classes_
        )
        tree_i_state_dict = {
            "max_depth": tree_i_state_class.max_depth,
            "node_count": tree_i_state_class.node_count,
            "nodes": check_tree_nodes(tree_i_state_class.node_ar),
            "values": tree_i_state_class.value_ar,
        }
        est_i.tree_ = Tree(
            self.n_features_in_,
            np.array([n_classes_], dtype=np.intp),
            self.n_outputs_,
        )
        # Load the backend tree state into the freshly created Tree.
        est_i.tree_.__setstate__(tree_i_state_dict)
        estimators_.append(est_i)

    self._cached_estimators_ = estimators_
383
+
384
# Only defined when sklearn_check_version("1.0") is true.
if sklearn_check_version("1.0"):

    @deprecated(
        "Attribute `n_features_` was deprecated in version 1.0 and will be "
        "removed in 1.2. Use `n_features_in_` instead."
    )
    @property
    def n_features_(self):
        """Deprecated alias that returns ``self.n_features_in_``."""
        return self.n_features_in_
393
+
394
# Only defined when sklearn_check_version("1.2") is false: expose
# ``base_estimator`` as an alias that reads and writes ``self.estimator``.
if not sklearn_check_version("1.2"):

    @property
    def base_estimator(self):
        """Return ``self.estimator``."""
        return self.estimator

    @base_estimator.setter
    def base_estimator(self, estimator):
        """Assign ``estimator`` to ``self.estimator``."""
        self.estimator = estimator
403
+
404
+
405
+ class ForestClassifier(_sklearn_ForestClassifier, BaseForest):
406
+ # Surprisingly, even though scikit-learn warns against using
407
+ # their ForestClassifier directly, it actually has a more stable
408
+ # API than the user-facing objects (over time). If they change it
409
+ # significantly at some point then this may need to be versioned.
410
+
411
+ _err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
412
+ _get_tree_state = staticmethod(get_tree_state_cls)
413
+
414
def __init__(
    self,
    estimator,
    n_estimators=100,
    *,
    estimator_params=tuple(),
    bootstrap=False,
    oob_score=False,
    n_jobs=None,
    random_state=None,
    verbose=0,
    warm_start=False,
    class_weight=None,
    max_samples=None,
):
    """Initialize the forest and select the matching oneDAL factory.

    All arguments are forwarded to the parent ``__init__``. Afterwards
    ``_onedal_factory`` is set from the estimator class when it is missing
    or does not match: ``DecisionTreeClassifier`` selects the oneDAL random
    forest classifier, ``ExtraTreeClassifier`` the oneDAL extra-trees
    classifier.

    Raises
    ------
    TypeError
        If no oneDAL factory is set after the check above.
    """
    super().__init__(
        estimator,
        n_estimators=n_estimators,
        estimator_params=estimator_params,
        bootstrap=bootstrap,
        oob_score=oob_score,
        n_jobs=n_jobs,
        random_state=random_state,
        verbose=verbose,
        warm_start=warm_start,
        class_weight=class_weight,
        max_samples=max_samples,
    )

    # The estimator is checked against the class attribute for conformance.
    # This should only trigger if the user uses this class directly.
    estimator_cls = self.estimator.__class__
    if estimator_cls == DecisionTreeClassifier:
        expected_factory = onedal_RandomForestClassifier
    elif estimator_cls == ExtraTreeClassifier:
        expected_factory = onedal_ExtraTreesClassifier
    else:
        expected_factory = None

    # ``_onedal_factory`` may still be None here (the value inherited from
    # BaseForest); issubclass(None, ...) would raise an unrelated TypeError,
    # so treat an unset factory as "needs to be assigned".
    current_factory = self._onedal_factory
    if expected_factory is not None and (
        current_factory is None or not issubclass(current_factory, expected_factory)
    ):
        self._onedal_factory = expected_factory

    if self._onedal_factory is None:
        raise TypeError(" oneDAL estimator has not been set.")
456
+
457
def _estimators_(self):
    """Build the per-tree estimators and attach class labels to each one."""
    super()._estimators_()
    # NOTE(review): this indexes ``self.classes_`` with [0]. If ``classes_``
    # has already been unwrapped to a single array by the time this runs,
    # this would pick the first label rather than the label array. Confirm
    # against the fit path.
    shared_classes = self.classes_[0]
    for tree_estimator in self._cached_estimators_:
        tree_estimator.classes_ = shared_classes
462
+
463
+ def fit(self, X, y, sample_weight=None):
464
+ dispatch(
465
+ self,
466
+ "fit",
467
+ {
468
+ "onedal": self.__class__._onedal_fit,
469
+ "sklearn": _sklearn_ForestClassifier.fit,
470
+ },
471
+ X,
472
+ y,
473
+ sample_weight,
474
+ )
475
+ return self
476
+
477
+ def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
478
+ if sp.issparse(y):
479
+ raise ValueError("sparse multilabel-indicator for y is not supported.")
480
+
481
+ if sklearn_check_version("1.2"):
482
+ self._validate_params()
483
+ else:
484
+ self._check_parameters()
485
+
486
+ if not self.bootstrap and self.oob_score:
487
+ raise ValueError("Out of bag estimation only available" " if bootstrap=True")
488
+
489
+ patching_status.and_conditions(
490
+ [
491
+ (
492
+ self.oob_score
493
+ and daal_check_version((2021, "P", 500))
494
+ or not self.oob_score,
495
+ "OOB score is only supported starting from 2021.5 version of oneDAL.",
496
+ ),
497
+ (self.warm_start is False, "Warm start is not supported."),
498
+ (
499
+ self.criterion == "gini",
500
+ f"'{self.criterion}' criterion is not supported. "
501
+ "Only 'gini' criterion is supported.",
502
+ ),
503
+ (
504
+ self.ccp_alpha == 0.0,
505
+ f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
506
+ ),
507
+ (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
508
+ (
509
+ self.n_estimators <= 6024,
510
+ "More than 6024 estimators is not supported.",
511
+ ),
512
+ ]
513
+ )
514
+
515
+ if self.bootstrap:
516
+ patching_status.and_conditions(
517
+ [
518
+ (
519
+ self.class_weight != "balanced_subsample",
520
+ "'balanced_subsample' for class_weight is not supported",
521
+ )
522
+ ]
523
+ )
524
+
525
+ if patching_status.get_status() and sklearn_check_version("1.4"):
526
+ try:
527
+ _assert_all_finite(X)
528
+ input_is_finite = True
529
+ except ValueError:
530
+ input_is_finite = False
531
+ patching_status.and_conditions(
532
+ [
533
+ (input_is_finite, "Non-finite input is not supported."),
534
+ (
535
+ self.monotonic_cst is None,
536
+ "Monotonicity constraints are not supported.",
537
+ ),
538
+ ]
539
+ )
540
+
541
+ if patching_status.get_status():
542
+ X, y = check_X_y(
543
+ X,
544
+ y,
545
+ multi_output=True,
546
+ accept_sparse=True,
547
+ dtype=[np.float64, np.float32],
548
+ force_all_finite=False,
549
+ )
550
+
551
+ if y.ndim == 2 and y.shape[1] == 1:
552
+ warnings.warn(
553
+ "A column-vector y was passed when a 1d array was"
554
+ " expected. Please change the shape of y to "
555
+ "(n_samples,), for example using ravel().",
556
+ DataConversionWarning,
557
+ stacklevel=2,
558
+ )
559
+
560
+ if y.ndim == 1:
561
+ y = np.reshape(y, (-1, 1))
562
+
563
+ self.n_outputs_ = y.shape[1]
564
+
565
+ patching_status.and_conditions(
566
+ [
567
+ (
568
+ self.n_outputs_ == 1,
569
+ f"Number of outputs ({self.n_outputs_}) is not 1.",
570
+ ),
571
+ (
572
+ y.dtype in [np.float32, np.float64, np.int32, np.int64],
573
+ f"Datatype ({y.dtype}) for y is not supported.",
574
+ ),
575
+ ]
576
+ )
577
+ # TODO: Fix to support integers as input
578
+
579
+ _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
580
+
581
+ if not self.bootstrap and self.max_samples is not None:
582
+ raise ValueError(
583
+ "`max_sample` cannot be set if `bootstrap=False`. "
584
+ "Either switch to `bootstrap=True` or set "
585
+ "`max_sample=None`."
586
+ )
587
+
588
+ if (
589
+ patching_status.get_status()
590
+ and (self.random_state is not None)
591
+ and (not daal_check_version((2024, "P", 0)))
592
+ ):
593
+ warnings.warn(
594
+ "Setting 'random_state' value is not supported. "
595
+ "State set by oneDAL to default value (777).",
596
+ RuntimeWarning,
597
+ )
598
+
599
+ return patching_status, X, y, sample_weight
600
+
601
+ @wrap_output_data
602
+ def predict(self, X):
603
+ check_is_fitted(self)
604
+ return dispatch(
605
+ self,
606
+ "predict",
607
+ {
608
+ "onedal": self.__class__._onedal_predict,
609
+ "sklearn": _sklearn_ForestClassifier.predict,
610
+ },
611
+ X,
612
+ )
613
+
614
+ @wrap_output_data
615
+ def predict_proba(self, X):
616
+ # TODO:
617
+ # _check_proba()
618
+ # self._check_proba()
619
+ check_is_fitted(self)
620
+ return dispatch(
621
+ self,
622
+ "predict_proba",
623
+ {
624
+ "onedal": self.__class__._onedal_predict_proba,
625
+ "sklearn": _sklearn_ForestClassifier.predict_proba,
626
+ },
627
+ X,
628
+ )
629
+
630
+ def predict_log_proba(self, X):
631
+ xp, _ = get_namespace(X)
632
+ proba = self.predict_proba(X)
633
+
634
+ if self.n_outputs_ == 1:
635
+ return xp.log(proba)
636
+
637
+ else:
638
+ for k in range(self.n_outputs_):
639
+ proba[k] = xp.log(proba[k])
640
+
641
+ return proba
642
+
643
+ @wrap_output_data
644
+ def score(self, X, y, sample_weight=None):
645
+ check_is_fitted(self)
646
+ return dispatch(
647
+ self,
648
+ "score",
649
+ {
650
+ "onedal": self.__class__._onedal_score,
651
+ "sklearn": _sklearn_ForestClassifier.score,
652
+ },
653
+ X,
654
+ y,
655
+ sample_weight=sample_weight,
656
+ )
657
+
658
+ fit.__doc__ = _sklearn_ForestClassifier.fit.__doc__
659
+ predict.__doc__ = _sklearn_ForestClassifier.predict.__doc__
660
+ predict_proba.__doc__ = _sklearn_ForestClassifier.predict_proba.__doc__
661
+ predict_log_proba.__doc__ = _sklearn_ForestClassifier.predict_log_proba.__doc__
662
+ score.__doc__ = _sklearn_ForestClassifier.score.__doc__
663
+
664
+ def _onedal_cpu_supported(self, method_name, *data):
665
+ class_name = self.__class__.__name__
666
+ patching_status = PatchingConditionsChain(
667
+ f"sklearn.ensemble.{class_name}.{method_name}"
668
+ )
669
+
670
+ if method_name == "fit":
671
+ patching_status, X, y, sample_weight = self._onedal_fit_ready(
672
+ patching_status, *data
673
+ )
674
+
675
+ patching_status.and_conditions(
676
+ [
677
+ (
678
+ daal_check_version((2023, "P", 200))
679
+ or self.estimator.__class__ == DecisionTreeClassifier,
680
+ "ExtraTrees only supported starting from oneDAL version 2023.2",
681
+ ),
682
+ (
683
+ not sp.issparse(sample_weight),
684
+ "sample_weight is sparse. " "Sparse input is not supported.",
685
+ ),
686
+ ]
687
+ )
688
+
689
+ elif method_name in ["predict", "predict_proba", "score"]:
690
+ X = data[0]
691
+
692
+ patching_status.and_conditions(
693
+ [
694
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
695
+ (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
696
+ (self.warm_start is False, "Warm start is not supported."),
697
+ (
698
+ daal_check_version((2023, "P", 100))
699
+ or self.estimator.__class__ == DecisionTreeClassifier,
700
+ "ExtraTrees only supported starting from oneDAL version 2023.2",
701
+ ),
702
+ ]
703
+ )
704
+
705
+ if method_name == "predict_proba":
706
+ patching_status.and_conditions(
707
+ [
708
+ (
709
+ daal_check_version((2021, "P", 400)),
710
+ "oneDAL version is lower than 2021.4.",
711
+ )
712
+ ]
713
+ )
714
+
715
+ if hasattr(self, "n_outputs_"):
716
+ patching_status.and_conditions(
717
+ [
718
+ (
719
+ self.n_outputs_ == 1,
720
+ f"Number of outputs ({self.n_outputs_}) is not 1.",
721
+ ),
722
+ ]
723
+ )
724
+
725
+ else:
726
+ raise RuntimeError(
727
+ f"Unknown method {method_name} in {self.__class__.__name__}"
728
+ )
729
+
730
+ return patching_status
731
+
732
+ def _onedal_gpu_supported(self, method_name, *data):
733
+ class_name = self.__class__.__name__
734
+ patching_status = PatchingConditionsChain(
735
+ f"sklearn.ensemble.{class_name}.{method_name}"
736
+ )
737
+
738
+ if method_name == "fit":
739
+ patching_status, X, y, sample_weight = self._onedal_fit_ready(
740
+ patching_status, *data
741
+ )
742
+
743
+ patching_status.and_conditions(
744
+ [
745
+ (
746
+ daal_check_version((2023, "P", 100))
747
+ or self.estimator.__class__ == DecisionTreeClassifier,
748
+ "ExtraTrees only supported starting from oneDAL version 2023.1",
749
+ ),
750
+ (
751
+ not self.oob_score,
752
+ "oob_scores using r2 or accuracy not implemented.",
753
+ ),
754
+ (sample_weight is None, "sample_weight is not supported."),
755
+ ]
756
+ )
757
+
758
+ elif method_name in ["predict", "predict_proba", "score"]:
759
+ X = data[0]
760
+
761
+ patching_status.and_conditions(
762
+ [
763
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained"),
764
+ (
765
+ not sp.issparse(X),
766
+ "X is sparse. Sparse input is not supported.",
767
+ ),
768
+ (self.warm_start is False, "Warm start is not supported."),
769
+ (
770
+ daal_check_version((2023, "P", 100)),
771
+ "ExtraTrees supported starting from oneDAL version 2023.1",
772
+ ),
773
+ ]
774
+ )
775
+ if hasattr(self, "n_outputs_"):
776
+ patching_status.and_conditions(
777
+ [
778
+ (
779
+ self.n_outputs_ == 1,
780
+ f"Number of outputs ({self.n_outputs_}) is not 1.",
781
+ ),
782
+ ]
783
+ )
784
+
785
+ else:
786
+ raise RuntimeError(
787
+ f"Unknown method {method_name} in {self.__class__.__name__}"
788
+ )
789
+
790
+ return patching_status
791
+
792
+ def _onedal_predict(self, X, queue=None):
793
+
794
+ if sklearn_check_version("1.0"):
795
+ X = validate_data(
796
+ self,
797
+ X,
798
+ dtype=[np.float64, np.float32],
799
+ force_all_finite=False,
800
+ reset=False,
801
+ ensure_2d=True,
802
+ )
803
+ else:
804
+ X = check_array(
805
+ X,
806
+ dtype=[np.float64, np.float32],
807
+ force_all_finite=False,
808
+ ) # Warning, order of dtype matters
809
+ if hasattr(self, "n_features_in_"):
810
+ try:
811
+ num_features = _num_features(X)
812
+ except TypeError:
813
+ num_features = _num_samples(X)
814
+ if num_features != self.n_features_in_:
815
+ raise ValueError(
816
+ (
817
+ f"X has {num_features} features, "
818
+ f"but {self.__class__.__name__} is expecting "
819
+ f"{self.n_features_in_} features as input"
820
+ )
821
+ )
822
+ self._check_n_features(X, reset=False)
823
+
824
+ res = self._onedal_estimator.predict(X, queue=queue)
825
+ return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
826
+
827
+ def _onedal_predict_proba(self, X, queue=None):
828
+
829
+ if sklearn_check_version("1.0"):
830
+ X = validate_data(
831
+ self,
832
+ X,
833
+ dtype=[np.float64, np.float32],
834
+ force_all_finite=False,
835
+ reset=False,
836
+ ensure_2d=True,
837
+ )
838
+ else:
839
+ X = check_array(
840
+ X,
841
+ dtype=[np.float64, np.float32],
842
+ force_all_finite=False,
843
+ ) # Warning, order of dtype matters
844
+ self._check_n_features(X, reset=False)
845
+
846
+ return self._onedal_estimator.predict_proba(X, queue=queue)
847
+
848
+ def _onedal_score(self, X, y, sample_weight=None, queue=None):
849
+ return accuracy_score(
850
+ y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
851
+ )
852
+
853
+
854
class ForestRegressor(_sklearn_ForestRegressor, BaseForest):
    # Shared base of the oneDAL-backed forest regressors defined in this file.
    # Public methods hand a {"onedal": ..., "sklearn": ...} pair to
    # ``dispatch``; ``_onedal_cpu_supported`` / ``_onedal_gpu_supported`` build
    # the PatchingConditionsChain listing what the oneDAL branch requires.

    # NOTE(review): looks like a '|'-joined list of oneDAL out-of-bag result
    # options; it is consumed outside this class body -- confirm.
    _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
    _get_tree_state = staticmethod(get_tree_state_reg)

    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
    ):
        super().__init__(
            estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            max_samples=max_samples,
        )

        # The splitter is checked against the class attribute for conformance
        # This should only trigger if the user uses this class directly.
        if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
            self._onedal_factory, onedal_RandomForestRegressor
        ):
            self._onedal_factory = onedal_RandomForestRegressor
        elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
            self._onedal_factory, onedal_ExtraTreesRegressor
        ):
            self._onedal_factory = onedal_ExtraTreesRegressor

        if self._onedal_factory is None:
            # Plain literal: the original f-string had no placeholders.
            raise TypeError(" oneDAL estimator has not been set.")

    def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
        """Validate fit inputs and record the oneDAL fit conditions.

        Parameters
        ----------
        patching_status : PatchingConditionsChain
            Chain that the conditions are and-ed into.
        X, y, sample_weight
            The arguments given to ``fit``. ``X`` and ``y`` are replaced by
            the output of ``check_X_y`` (and ``y`` reshaped to two dimensions)
            only while the chain status is still true at that point.

        Returns
        -------
        tuple
            ``(patching_status, X, y, sample_weight)``.

        Raises
        ------
        ValueError
            If ``y`` is sparse, if ``oob_score`` is set without ``bootstrap``,
            or if ``max_samples`` is set without ``bootstrap`` (the last check
            only runs while the chain status is true).

        Notes
        -----
        Sets ``self.n_outputs_`` as a side effect when the chain status is
        true after the parameter-level conditions.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        # Parameter validation entry point depends on the scikit-learn version.
        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        if sklearn_check_version("1.0") and self.criterion == "mse":
            warnings.warn(
                "Criterion 'mse' was deprecated in v1.0 and will be "
                "removed in version 1.2. Use `criterion='squared_error'` "
                "which is equivalent.",
                FutureWarning,
            )

        patching_status.and_conditions(
            [
                (
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion in ["mse", "squared_error"],
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'mse' and 'squared_error' criteria are supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    self.n_estimators <= 6024,
                    "More than 6024 estimators is not supported.",
                ),
            ]
        )

        if patching_status.get_status() and sklearn_check_version("1.4"):
            # Probe finiteness without raising to the caller; a failure only
            # turns the oneDAL conditions false.
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            patching_status.and_conditions(
                [
                    (input_is_finite, "Non-finite input is not supported."),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if patching_status.get_status():
            X, y = check_X_y(
                X,
                y,
                multi_output=True,
                accept_sparse=True,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
            )

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            if y.ndim == 1:
                # reshape is necessary to preserve the data contiguity against vs
                # [:, np.newaxis] that does not.
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]

            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

            # Sklearn function used for doing checks on max_samples attribute
            _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

            if not self.bootstrap and self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )

            # Status is re-read because the conditions above may have flipped it.
            if (
                patching_status.get_status()
                and (self.random_state is not None)
                and (not daal_check_version((2024, "P", 0)))
            ):
                warnings.warn(
                    "Setting 'random_state' value is not supported. "
                    "State set by oneDAL to default value (777).",
                    RuntimeWarning,
                )

        return patching_status, X, y, sample_weight

    def _onedal_cpu_supported(self, method_name, *data):
        """Build the condition chain for running ``method_name`` via oneDAL on CPU.

        ``data`` holds the positional arguments of the public method:
        ``(X, y, sample_weight)`` for ``"fit"``; for ``"predict"`` / ``"score"``
        only ``data[0]`` (``X``) is inspected.

        Raises ``RuntimeError`` for any other ``method_name``.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, X, y, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        # A regressor's estimator is never a classifier tree;
                        # compare against DecisionTreeRegressor so plain
                        # decision-tree forests are exempt from the
                        # ExtraTrees-only version requirement.
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                    (
                        not sp.issparse(sample_weight),
                        "sample_weight is sparse. " "Sparse input is not supported.",
                    ),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                ]
            )
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_gpu_supported(self, method_name, *data):
        """Build the condition chain for running ``method_name`` via oneDAL on GPU.

        Same layout as ``_onedal_cpu_supported``; the fit branch additionally
        requires ``oob_score`` to be falsy and ``sample_weight`` to be ``None``.

        Raises ``RuntimeError`` for an unknown ``method_name``.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, X, y, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        # See _onedal_cpu_supported: regressor trees must be
                        # compared against DecisionTreeRegressor.
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                    (not self.oob_score, "oob_score value is not sklearn conformant."),
                    (sample_weight is None, "sample_weight is not supported."),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeRegressor,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                ]
            )
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_predict(self, X, queue=None):
        """Validate ``X`` and return the fitted oneDAL estimator's prediction."""
        check_is_fitted(self, "_onedal_estimator")

        if sklearn_check_version("1.0"):
            X = validate_data(
                self,
                X,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
                reset=False,
                ensure_2d=True,
            )  # Warning, order of dtype matters
        else:
            X = check_array(
                X, dtype=[np.float64, np.float32], force_all_finite=False
            )  # Warning, order of dtype matters

        return self._onedal_estimator.predict(X, queue=queue)

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Return ``r2_score`` of ``_onedal_predict(X)`` against ``y``."""
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    def fit(self, X, y, sample_weight=None):
        # No docstring here: ``fit.__doc__`` is assigned from scikit-learn below.
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": _sklearn_ForestRegressor.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):
        check_is_fitted(self)
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": _sklearn_ForestRegressor.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        check_is_fitted(self)
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": _sklearn_ForestRegressor.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    # Public methods reuse the scikit-learn docstrings verbatim.
    fit.__doc__ = _sklearn_ForestRegressor.fit.__doc__
    predict.__doc__ = _sklearn_ForestRegressor.predict.__doc__
    score.__doc__ = _sklearn_ForestRegressor.score.__doc__
1200
+
1201
+
1202
@register_hyperparameters({"infer": get_hyperparameters("decision_forest", "infer")})
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class RandomForestClassifier(ForestClassifier):
    # The class docstring is taken verbatim from scikit-learn, so no literal
    # docstring is written here.
    __doc__ = _sklearn_RandomForestClassifier.__doc__
    _onedal_factory = onedal_RandomForestClassifier

    if sklearn_check_version("1.2"):
        # scikit-learn's constraints plus the two extra constructor
        # parameters: integers with max_bins >= 2 and min_bin_size >= 1.
        _parameter_constraints: dict = {
            **_sklearn_RandomForestClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # Exactly one ``__init__`` is defined, chosen at import time by the
    # installed scikit-learn version. Every variant stores its tree parameters
    # on ``self`` after calling the parent constructor.
    if sklearn_check_version("1.4"):

        # scikit-learn >= 1.4: signature includes ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        # 1.0 <= scikit-learn < 1.4: no ``monotonic_cst``; the ``max_features``
        # default is "sqrt" from 1.1 on and "auto" before that.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        # scikit-learn < 1.0: signature still carries ``min_impurity_split``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            # Assigned once; the original repeated these two assignments.
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1411
+
1412
+
1413
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class RandomForestRegressor(ForestRegressor):
    # The class docstring is taken verbatim from scikit-learn, so no literal
    # docstring is written here.
    __doc__ = _sklearn_RandomForestRegressor.__doc__
    _onedal_factory = onedal_RandomForestRegressor

    if sklearn_check_version("1.2"):
        # scikit-learn's constraints plus the two extra constructor
        # parameters: integers with max_bins >= 2 and min_bin_size >= 1.
        _parameter_constraints: dict = {
            **_sklearn_RandomForestRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # Exactly one ``__init__`` is defined, chosen at import time by the
    # installed scikit-learn version. Every variant stores its tree parameters
    # on ``self`` after calling the parent constructor.
    if sklearn_check_version("1.4"):

        # scikit-learn >= 1.4: signature includes ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        # 1.0 <= scikit-learn < 1.4: no ``monotonic_cst``; the ``max_features``
        # default is 1.0 from 1.1 on and "auto" before that.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        # scikit-learn < 1.0: signature still carries ``min_impurity_split``
        # and the criterion default is "mse".
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # Two separate entries: without the comma the adjacent
                    # literals merged into "min_impurity_splitrandom_state".
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1612
+
1613
+
1614
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class ExtraTreesClassifier(ForestClassifier):
    # The class docstring is taken verbatim from the scikit-learn estimator.
    __doc__ = _sklearn_ExtraTreesClassifier.__doc__
    _onedal_factory = onedal_ExtraTreesClassifier

    if sklearn_check_version("1.2"):
        # Extend the scikit-learn constraint table with the two extra
        # constructor parameters accepted by this class.
        _parameter_constraints: dict = {
            **_sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # One constructor is defined per supported scikit-learn signature. Every
    # variant forwards the ensemble-level options to the parent constructor
    # and stores the per-tree hyper-parameters (named in estimator_params)
    # plus max_bins / min_bin_size on the instance.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            """Constructor variant that accepts ``monotonic_cst``."""
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            """Constructor variant without ``monotonic_cst``.

            The ``max_features`` default is chosen once, when the class body
            is executed, from the installed scikit-learn version.
            """
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            """Legacy constructor variant that accepts ``min_impurity_split``."""
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            # Assigned once; an accidental second copy of these two
            # assignments was dropped.
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1822
+
1823
+
1824
@control_n_jobs(decorated_methods=["fit", "predict", "score"])
class ExtraTreesRegressor(ForestRegressor):
    # The class docstring is taken verbatim from the scikit-learn estimator.
    __doc__ = _sklearn_ExtraTreesRegressor.__doc__
    _onedal_factory = onedal_ExtraTreesRegressor

    if sklearn_check_version("1.2"):
        # Extend the scikit-learn constraint table with the two extra
        # constructor parameters accepted by this class.
        _parameter_constraints: dict = {
            **_sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # One constructor is defined per supported scikit-learn signature. Every
    # variant forwards the ensemble-level options to the parent constructor
    # and stores the per-tree hyper-parameters (named in estimator_params)
    # plus max_bins / min_bin_size on the instance.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            """Constructor variant that accepts ``monotonic_cst``."""
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            """Constructor variant without ``monotonic_cst``.

            The ``max_features`` default is chosen once, when the class body
            is executed, from the installed scikit-learn version.
            """
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            """Legacy constructor variant that accepts ``min_impurity_split``."""
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # These two entries used to be written without a
                    # separating comma; Python concatenates adjacent string
                    # literals, so the tuple held one bogus name
                    # "min_impurity_splitrandom_state".
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
2023
+
2024
+
2025
# Register each estimator defined here as a virtual subclass of its
# scikit-learn counterpart (ABCMeta.register), so isinstance checks against
# the scikit-learn classes succeed without changing the inheritance chain.
for _sklearn_cls, _local_cls in (
    (_sklearn_RandomForestClassifier, RandomForestClassifier),
    (_sklearn_RandomForestRegressor, RandomForestRegressor),
    (_sklearn_ExtraTreesClassifier, ExtraTreesClassifier),
    (_sklearn_ExtraTreesRegressor, ExtraTreesRegressor),
):
    _sklearn_cls.register(_local_cls)
# Do not leave the loop variables behind in the module namespace.
del _sklearn_cls, _local_cls