scikit-learn-intelex 2025.1.0__py311-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (280) hide show
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-311-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-311-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +222 -0
  62. onedal/_onedal_py_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-311-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +564 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +125 -0
  83. onedal/common/tests/test_policy.py +76 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +154 -0
  91. onedal/datatypes/tests/common.py +126 -0
  92. onedal/datatypes/tests/test_data.py +414 -0
  93. onedal/decomposition/__init__.py +20 -0
  94. onedal/decomposition/incremental_pca.py +204 -0
  95. onedal/decomposition/pca.py +186 -0
  96. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  97. onedal/ensemble/__init__.py +29 -0
  98. onedal/ensemble/forest.py +727 -0
  99. onedal/ensemble/tests/test_random_forest.py +97 -0
  100. onedal/linear_model/__init__.py +27 -0
  101. onedal/linear_model/incremental_linear_model.py +258 -0
  102. onedal/linear_model/linear_model.py +329 -0
  103. onedal/linear_model/logistic_regression.py +249 -0
  104. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  105. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  106. onedal/linear_model/tests/test_linear_regression.py +250 -0
  107. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  108. onedal/linear_model/tests/test_ridge.py +95 -0
  109. onedal/neighbors/__init__.py +19 -0
  110. onedal/neighbors/neighbors.py +767 -0
  111. onedal/neighbors/tests/test_knn_classification.py +49 -0
  112. onedal/primitives/__init__.py +27 -0
  113. onedal/primitives/get_tree.py +25 -0
  114. onedal/primitives/kernel_functions.py +153 -0
  115. onedal/primitives/tests/test_kernel_functions.py +159 -0
  116. onedal/spmd/__init__.py +25 -0
  117. onedal/spmd/_base.py +30 -0
  118. onedal/spmd/basic_statistics/__init__.py +20 -0
  119. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  120. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  121. onedal/spmd/cluster/__init__.py +28 -0
  122. onedal/spmd/cluster/dbscan.py +23 -0
  123. onedal/spmd/cluster/kmeans.py +56 -0
  124. onedal/spmd/covariance/__init__.py +20 -0
  125. onedal/spmd/covariance/covariance.py +26 -0
  126. onedal/spmd/covariance/incremental_covariance.py +82 -0
  127. onedal/spmd/decomposition/__init__.py +20 -0
  128. onedal/spmd/decomposition/incremental_pca.py +117 -0
  129. onedal/spmd/decomposition/pca.py +26 -0
  130. onedal/spmd/ensemble/__init__.py +19 -0
  131. onedal/spmd/ensemble/forest.py +28 -0
  132. onedal/spmd/linear_model/__init__.py +21 -0
  133. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  134. onedal/spmd/linear_model/linear_model.py +30 -0
  135. onedal/spmd/linear_model/logistic_regression.py +38 -0
  136. onedal/spmd/neighbors/__init__.py +19 -0
  137. onedal/spmd/neighbors/neighbors.py +75 -0
  138. onedal/svm/__init__.py +19 -0
  139. onedal/svm/svm.py +556 -0
  140. onedal/svm/tests/test_csr_svm.py +351 -0
  141. onedal/svm/tests/test_nusvc.py +204 -0
  142. onedal/svm/tests/test_nusvr.py +210 -0
  143. onedal/svm/tests/test_svc.py +176 -0
  144. onedal/svm/tests/test_svr.py +243 -0
  145. onedal/tests/test_common.py +57 -0
  146. onedal/tests/utils/_dataframes_support.py +162 -0
  147. onedal/tests/utils/_device_selection.py +102 -0
  148. onedal/utils/__init__.py +49 -0
  149. onedal/utils/_array_api.py +81 -0
  150. onedal/utils/_dpep_helpers.py +56 -0
  151. onedal/utils/validation.py +440 -0
  152. scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
  153. scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
  154. scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
  155. scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
  156. scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
  157. sklearnex/__init__.py +66 -0
  158. sklearnex/__main__.py +58 -0
  159. sklearnex/_config.py +116 -0
  160. sklearnex/_device_offload.py +126 -0
  161. sklearnex/_utils.py +132 -0
  162. sklearnex/basic_statistics/__init__.py +20 -0
  163. sklearnex/basic_statistics/basic_statistics.py +230 -0
  164. sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  165. sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  166. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  167. sklearnex/cluster/__init__.py +20 -0
  168. sklearnex/cluster/dbscan.py +197 -0
  169. sklearnex/cluster/k_means.py +395 -0
  170. sklearnex/cluster/tests/test_dbscan.py +38 -0
  171. sklearnex/cluster/tests/test_kmeans.py +159 -0
  172. sklearnex/conftest.py +82 -0
  173. sklearnex/covariance/__init__.py +19 -0
  174. sklearnex/covariance/incremental_covariance.py +398 -0
  175. sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  176. sklearnex/decomposition/__init__.py +19 -0
  177. sklearnex/decomposition/pca.py +425 -0
  178. sklearnex/decomposition/tests/test_pca.py +58 -0
  179. sklearnex/dispatcher.py +543 -0
  180. sklearnex/doc/third-party-programs.txt +424 -0
  181. sklearnex/ensemble/__init__.py +29 -0
  182. sklearnex/ensemble/_forest.py +2029 -0
  183. sklearnex/ensemble/tests/test_forest.py +135 -0
  184. sklearnex/glob/__main__.py +72 -0
  185. sklearnex/glob/dispatcher.py +101 -0
  186. sklearnex/linear_model/__init__.py +32 -0
  187. sklearnex/linear_model/coordinate_descent.py +30 -0
  188. sklearnex/linear_model/incremental_linear.py +482 -0
  189. sklearnex/linear_model/incremental_ridge.py +425 -0
  190. sklearnex/linear_model/linear.py +341 -0
  191. sklearnex/linear_model/logistic_regression.py +413 -0
  192. sklearnex/linear_model/ridge.py +24 -0
  193. sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  194. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  195. sklearnex/linear_model/tests/test_linear.py +167 -0
  196. sklearnex/linear_model/tests/test_logreg.py +134 -0
  197. sklearnex/manifold/__init__.py +19 -0
  198. sklearnex/manifold/t_sne.py +21 -0
  199. sklearnex/manifold/tests/test_tsne.py +26 -0
  200. sklearnex/metrics/__init__.py +23 -0
  201. sklearnex/metrics/pairwise.py +22 -0
  202. sklearnex/metrics/ranking.py +20 -0
  203. sklearnex/metrics/tests/test_metrics.py +39 -0
  204. sklearnex/model_selection/__init__.py +21 -0
  205. sklearnex/model_selection/split.py +22 -0
  206. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  207. sklearnex/neighbors/__init__.py +27 -0
  208. sklearnex/neighbors/_lof.py +236 -0
  209. sklearnex/neighbors/common.py +310 -0
  210. sklearnex/neighbors/knn_classification.py +231 -0
  211. sklearnex/neighbors/knn_regression.py +207 -0
  212. sklearnex/neighbors/knn_unsupervised.py +178 -0
  213. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  214. sklearnex/preview/__init__.py +17 -0
  215. sklearnex/preview/covariance/__init__.py +19 -0
  216. sklearnex/preview/covariance/covariance.py +138 -0
  217. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  218. sklearnex/preview/decomposition/__init__.py +19 -0
  219. sklearnex/preview/decomposition/incremental_pca.py +233 -0
  220. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  221. sklearnex/preview/linear_model/__init__.py +19 -0
  222. sklearnex/preview/linear_model/ridge.py +424 -0
  223. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  224. sklearnex/spmd/__init__.py +25 -0
  225. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  226. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  227. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  228. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  229. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  230. sklearnex/spmd/cluster/__init__.py +30 -0
  231. sklearnex/spmd/cluster/dbscan.py +50 -0
  232. sklearnex/spmd/cluster/kmeans.py +21 -0
  233. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  234. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  235. sklearnex/spmd/covariance/__init__.py +20 -0
  236. sklearnex/spmd/covariance/covariance.py +21 -0
  237. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  238. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  239. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  240. sklearnex/spmd/decomposition/__init__.py +20 -0
  241. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  242. sklearnex/spmd/decomposition/pca.py +21 -0
  243. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  244. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  245. sklearnex/spmd/ensemble/__init__.py +19 -0
  246. sklearnex/spmd/ensemble/forest.py +71 -0
  247. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  248. sklearnex/spmd/linear_model/__init__.py +21 -0
  249. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  250. sklearnex/spmd/linear_model/linear_model.py +21 -0
  251. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  252. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  253. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  254. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  255. sklearnex/spmd/neighbors/__init__.py +19 -0
  256. sklearnex/spmd/neighbors/neighbors.py +25 -0
  257. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  258. sklearnex/svm/__init__.py +29 -0
  259. sklearnex/svm/_common.py +339 -0
  260. sklearnex/svm/nusvc.py +371 -0
  261. sklearnex/svm/nusvr.py +170 -0
  262. sklearnex/svm/svc.py +399 -0
  263. sklearnex/svm/svr.py +167 -0
  264. sklearnex/svm/tests/test_svm.py +93 -0
  265. sklearnex/tests/test_common.py +390 -0
  266. sklearnex/tests/test_config.py +123 -0
  267. sklearnex/tests/test_memory_usage.py +379 -0
  268. sklearnex/tests/test_monkeypatch.py +276 -0
  269. sklearnex/tests/test_n_jobs_support.py +108 -0
  270. sklearnex/tests/test_parallel.py +48 -0
  271. sklearnex/tests/test_patching.py +385 -0
  272. sklearnex/tests/test_run_to_run_stability.py +321 -0
  273. sklearnex/tests/utils/__init__.py +44 -0
  274. sklearnex/tests/utils/base.py +371 -0
  275. sklearnex/tests/utils/spmd.py +198 -0
  276. sklearnex/utils/__init__.py +19 -0
  277. sklearnex/utils/_array_api.py +82 -0
  278. sklearnex/utils/parallel.py +59 -0
  279. sklearnex/utils/tests/test_finite.py +89 -0
  280. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,1397 @@
1
+ # ==============================================================================
2
+ # Copyright 2014 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+ import warnings
19
+ from math import ceil
20
+
21
+ import numpy as np
22
+ from scipy import sparse as sp
23
+ from sklearn.base import clone
24
+ from sklearn.ensemble import RandomForestClassifier as RandomForestClassifier_original
25
+ from sklearn.ensemble import RandomForestRegressor as RandomForestRegressor_original
26
+ from sklearn.exceptions import DataConversionWarning
27
+ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
28
+ from sklearn.tree._tree import Tree
29
+ from sklearn.utils import check_array, check_random_state, deprecated
30
+ from sklearn.utils.validation import (
31
+ _num_samples,
32
+ check_consistent_length,
33
+ check_is_fitted,
34
+ )
35
+
36
+ import daal4py
37
+ from daal4py.sklearn._utils import (
38
+ PatchingConditionsChain,
39
+ check_tree_nodes,
40
+ daal_check_version,
41
+ getFPType,
42
+ sklearn_check_version,
43
+ )
44
+
45
+ from .._n_jobs_support import control_n_jobs
46
+ from ..utils.validation import _daal_num_features
47
+
48
+ if sklearn_check_version("1.2"):
49
+ from sklearn.utils._param_validation import Interval, StrOptions
50
+ if sklearn_check_version("1.4"):
51
+ from daal4py.sklearn.utils import _assert_all_finite
52
+
53
+
54
+ def _to_absolute_max_features(max_features, n_features, is_classification=False):
55
+ if max_features is None:
56
+ return n_features
57
+ if isinstance(max_features, str):
58
+ if max_features == "auto":
59
+ if not sklearn_check_version("1.3"):
60
+ if sklearn_check_version("1.1"):
61
+ warnings.warn(
62
+ "`max_features='auto'` has been deprecated in 1.1 "
63
+ "and will be removed in 1.3. To keep the past behaviour, "
64
+ "explicitly set `max_features=1.0` or remove this "
65
+ "parameter as it is also the default value for "
66
+ "RandomForestRegressors and ExtraTreesRegressors.",
67
+ FutureWarning,
68
+ )
69
+ return (
70
+ max(1, int(np.sqrt(n_features))) if is_classification else n_features
71
+ )
72
+ if max_features == "sqrt":
73
+ return max(1, int(np.sqrt(n_features)))
74
+ if max_features == "log2":
75
+ return max(1, int(np.log2(n_features)))
76
+ allowed_string_values = (
77
+ '"sqrt" or "log2"'
78
+ if sklearn_check_version("1.3")
79
+ else '"auto", "sqrt" or "log2"'
80
+ )
81
+ raise ValueError(
82
+ "Invalid value for max_features. Allowed string "
83
+ f"values are {allowed_string_values}."
84
+ )
85
+ if isinstance(max_features, (numbers.Integral, np.integer)):
86
+ return max_features
87
+ if max_features > 0.0:
88
+ return max(1, int(max_features * n_features))
89
+ return 0
90
+
91
+
92
+ def _get_n_samples_bootstrap(n_samples, max_samples):
93
+ if max_samples is None:
94
+ return 1.0
95
+
96
+ if isinstance(max_samples, numbers.Integral):
97
+ if not sklearn_check_version("1.2"):
98
+ if not (1 <= max_samples <= n_samples):
99
+ msg = "`max_samples` must be in range 1 to {} but got value {}"
100
+ raise ValueError(msg.format(n_samples, max_samples))
101
+ else:
102
+ if max_samples > n_samples:
103
+ msg = "`max_samples` must be <= n_samples={} but got value {}"
104
+ raise ValueError(msg.format(n_samples, max_samples))
105
+ return max(float(max_samples / n_samples), 1 / n_samples)
106
+
107
+ if isinstance(max_samples, numbers.Real):
108
+ if sklearn_check_version("1.2"):
109
+ pass
110
+ elif sklearn_check_version("1.0"):
111
+ if not (0 < float(max_samples) <= 1):
112
+ msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
113
+ raise ValueError(msg.format(max_samples))
114
+ else:
115
+ if not (0 < float(max_samples) < 1):
116
+ msg = "`max_samples` must be in range (0, 1) but got value {}"
117
+ raise ValueError(msg.format(max_samples))
118
+ return max(float(max_samples), 1 / n_samples)
119
+
120
+ msg = "`max_samples` should be int or float, but got type '{}'"
121
+ raise TypeError(msg.format(type(max_samples)))
122
+
123
+
124
def check_sample_weight(sample_weight, X, dtype=None):
    """Return ``sample_weight`` as a 1-D array with one entry per sample of ``X``.

    Parameters
    ----------
    sample_weight : None, number or array-like
        ``None`` becomes an array of ones, a scalar is broadcast to every
        sample, and anything else is validated with ``check_array``.
    X : array-like
        Data whose sample count (``_num_samples(X)``) fixes the expected length.
    dtype : dtype, default=None
        Requested dtype. Any value other than ``None``, ``np.float32`` or
        ``np.float64`` is replaced by ``np.float64``.

    Returns
    -------
    ndarray
        The weights, of shape ``(n_samples,)``.

    Raises
    ------
    ValueError
        If array-like weights are not one-dimensional or have the wrong length.
    """
    expected_len = _num_samples(X)

    if dtype is not None and dtype not in [np.float32, np.float64]:
        dtype = np.float64

    if sample_weight is None:
        return np.ones(expected_len, dtype=dtype)
    if isinstance(sample_weight, numbers.Number):
        return np.full(expected_len, sample_weight, dtype=dtype)

    # With no explicit dtype, accept either float width as it comes.
    accepted_dtype = [np.float64, np.float32] if dtype is None else dtype
    weights = check_array(
        sample_weight,
        accept_sparse=False,
        ensure_2d=False,
        dtype=accepted_dtype,
        order="C",
    )
    if weights.ndim != 1:
        raise ValueError("Sample weights must be 1D array or scalar")

    expected_shape = (expected_len,)
    if weights.shape != expected_shape:
        raise ValueError(
            "sample_weight.shape == {}, expected {}!".format(
                weights.shape, expected_shape
            )
        )
    return weights
150
+
151
+
152
class RandomForestBase:
    """Shared hyper-parameter validation for the random forest wrappers.

    ``fit`` and ``predict`` are placeholders in this class.
    ``_check_parameters`` reads the hyper-parameters from instance attributes,
    which the concrete estimator must have set beforehand.
    """

    def fit(self, X, y, sample_weight=None): ...

    def predict(self, X): ...

    def _check_parameters(self) -> None:
        """Validate the hyper-parameters stored on the instance.

        Attributes read: ``bootstrap``, ``max_samples``, ``min_samples_leaf``,
        ``min_samples_split``, ``min_weight_fraction_leaf``,
        ``min_impurity_split``, ``min_impurity_decrease``, ``max_leaf_nodes``,
        ``maxBins`` and ``minBinSize``.

        Raises
        ------
        ValueError
            If any of the values is outside its accepted range or has an
            unsupported type.

        Warns
        -----
        FutureWarning
            If ``min_impurity_split`` is not ``None``.
        """
        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )

        # Integers are absolute counts; anything else is checked as a fraction.
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    f"or in (0, 0.5], got {self.min_samples_leaf}"
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    f"or in (0, 0.5], got {self.min_samples_leaf}"
                )

        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    f"got the integer {self.min_samples_split}"
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                # The f-prefix is required so the offending value is
                # interpolated into the message.
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    f"got the float {self.min_samples_split}"
                )

        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")

        if self.min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            if self.min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than or equal to 0"
                )

        if self.min_impurity_decrease < 0.0:
            raise ValueError("min_impurity_decrease must be greater than or equal to 0")

        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    f"{self.max_leaf_nodes}"
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    f"max_leaf_nodes {self.max_leaf_nodes} must be either None "
                    "or larger than 1"
                )

        if isinstance(self.maxBins, numbers.Integral):
            if not 2 <= self.maxBins:
                raise ValueError(f"maxBins must be at least 2, got {self.maxBins}")
        else:
            raise ValueError(f"maxBins must be integral number but was {self.maxBins}")

        if isinstance(self.minBinSize, numbers.Integral):
            if not 1 <= self.minBinSize:
                raise ValueError(f"minBinSize must be at least 1, got {self.minBinSize}")
        else:
            raise ValueError(
                f"minBinSize must be integral number but was {self.minBinSize}"
            )
230
+
231
+
232
+ @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
233
+ class RandomForestClassifier(RandomForestClassifier_original, RandomForestBase):
234
+ __doc__ = RandomForestClassifier_original.__doc__
235
+
236
+ if sklearn_check_version("1.2"):
237
+ _parameter_constraints: dict = {
238
+ **RandomForestClassifier_original._parameter_constraints,
239
+ "maxBins": [Interval(numbers.Integral, 0, None, closed="left")],
240
+ "minBinSize": [Interval(numbers.Integral, 1, None, closed="left")],
241
+ "binningStrategy": [StrOptions({"quantiles", "averages"})],
242
+ }
243
+
244
+ if sklearn_check_version("1.4"):
245
+
246
def __init__(
    self,
    n_estimators=100,
    criterion="gini",
    max_depth=None,
    min_samples_split=2,
    min_samples_leaf=1,
    min_weight_fraction_leaf=0.0,
    max_features="sqrt",
    max_leaf_nodes=None,
    min_impurity_decrease=0.0,
    bootstrap=True,
    oob_score=False,
    n_jobs=None,
    random_state=None,
    verbose=0,
    warm_start=False,
    class_weight=None,
    ccp_alpha=0.0,
    max_samples=None,
    monotonic_cst=None,
    maxBins=256,
    minBinSize=1,
    binningStrategy="quantiles",
):
    """Constructor used when ``sklearn_check_version("1.4")`` holds.

    Most arguments are forwarded to the parent constructor. ``ccp_alpha``,
    ``max_samples``, ``maxBins``, ``minBinSize`` and ``binningStrategy`` are
    stored on the instance afterwards, and ``min_impurity_split`` is fixed
    to ``None``.
    """
    parent_kwargs = dict(
        n_estimators=n_estimators,
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        min_weight_fraction_leaf=min_weight_fraction_leaf,
        max_features=max_features,
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
        bootstrap=bootstrap,
        oob_score=oob_score,
        n_jobs=n_jobs,
        random_state=random_state,
        verbose=verbose,
        warm_start=warm_start,
        class_weight=class_weight,
        monotonic_cst=monotonic_cst,
    )
    super().__init__(**parent_kwargs)

    # Assigned after the parent constructor has run, not forwarded to it.
    self.ccp_alpha = ccp_alpha
    self.max_samples = max_samples
    self.monotonic_cst = monotonic_cst
    self.min_impurity_split = None

    # Parameters that are not passed to the parent constructor at all.
    self.maxBins = maxBins
    self.minBinSize = minBinSize
    self.binningStrategy = binningStrategy
297
+
298
+ elif sklearn_check_version("1.0"):
299
+
300
def __init__(
    self,
    n_estimators=100,
    criterion="gini",
    max_depth=None,
    min_samples_split=2,
    min_samples_leaf=1,
    min_weight_fraction_leaf=0.0,
    max_features="sqrt" if sklearn_check_version("1.1") else "auto",
    max_leaf_nodes=None,
    min_impurity_decrease=0.0,
    bootstrap=True,
    oob_score=False,
    n_jobs=None,
    random_state=None,
    verbose=0,
    warm_start=False,
    class_weight=None,
    ccp_alpha=0.0,
    max_samples=None,
    maxBins=256,
    minBinSize=1,
    binningStrategy="quantiles",
):
    """Constructor used when ``sklearn_check_version("1.0")`` holds but ``"1.4"`` does not.

    The ``max_features`` default is chosen once, when the function is
    defined, from ``sklearn_check_version("1.1")``. Most arguments are
    forwarded to the parent constructor. ``ccp_alpha``, ``max_samples``,
    ``maxBins``, ``minBinSize`` and ``binningStrategy`` are stored on the
    instance afterwards, and ``min_impurity_split`` is fixed to ``None``.
    """
    parent_kwargs = dict(
        n_estimators=n_estimators,
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        min_weight_fraction_leaf=min_weight_fraction_leaf,
        max_features=max_features,
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
        bootstrap=bootstrap,
        oob_score=oob_score,
        n_jobs=n_jobs,
        random_state=random_state,
        verbose=verbose,
        warm_start=warm_start,
        class_weight=class_weight,
    )
    super().__init__(**parent_kwargs)

    # Assigned after the parent constructor has run, not forwarded to it.
    self.ccp_alpha = ccp_alpha
    self.max_samples = max_samples
    self.min_impurity_split = None

    # Parameters that are not passed to the parent constructor at all.
    self.maxBins = maxBins
    self.minBinSize = minBinSize
    self.binningStrategy = binningStrategy
348
+
349
+ else:
350
+
351
def __init__(
    self,
    n_estimators=100,
    criterion="gini",
    max_depth=None,
    min_samples_split=2,
    min_samples_leaf=1,
    min_weight_fraction_leaf=0.0,
    max_features="auto",
    max_leaf_nodes=None,
    min_impurity_decrease=0.0,
    min_impurity_split=None,
    bootstrap=True,
    oob_score=False,
    n_jobs=None,
    random_state=None,
    verbose=0,
    warm_start=False,
    class_weight=None,
    ccp_alpha=0.0,
    max_samples=None,
    maxBins=256,
    minBinSize=1,
    binningStrategy="quantiles",
):
    """Constructor used when ``sklearn_check_version("1.0")`` does not hold.

    Every argument except ``maxBins``, ``minBinSize`` and ``binningStrategy``
    is forwarded to the parent constructor; those three are stored on the
    instance afterwards.
    """
    parent_kwargs = dict(
        n_estimators=n_estimators,
        criterion=criterion,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
        min_samples_leaf=min_samples_leaf,
        min_weight_fraction_leaf=min_weight_fraction_leaf,
        max_features=max_features,
        max_leaf_nodes=max_leaf_nodes,
        min_impurity_decrease=min_impurity_decrease,
        min_impurity_split=min_impurity_split,
        bootstrap=bootstrap,
        oob_score=oob_score,
        n_jobs=n_jobs,
        random_state=random_state,
        verbose=verbose,
        warm_start=warm_start,
        class_weight=class_weight,
        ccp_alpha=ccp_alpha,
        max_samples=max_samples,
    )
    super().__init__(**parent_kwargs)

    # Parameters that are not passed to the parent constructor.
    self.maxBins = maxBins
    self.minBinSize = minBinSize
    self.binningStrategy = binningStrategy
400
+
401
def fit(self, X, y, sample_weight=None):
    """
    Build a forest of trees from the training set (X, y).

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. Internally, its dtype will be converted
        to ``dtype=np.float32``. If a sparse matrix is provided, it will be
        converted into a sparse ``csc_matrix``.

    y : array-like of shape (n_samples,) or (n_samples, n_outputs)
        The target values (class labels in classification, real numbers in
        regression).

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. If None, then samples are equally weighted. Splits
        that would create child nodes with net zero or negative weight are
        ignored while searching for a split in each node. In the case of
        classification, splits are also ignored if they would result in any
        single class carrying a negative weight in either child node.

    Returns
    -------
    self : object

    Raises
    ------
    ValueError
        If ``y`` is sparse.

    Notes
    -----
    A chain of support conditions decides whether ``_daal_fit_classifier``
    is used. As soon as one condition fails, the parent class ``fit`` is
    called instead.
    """
    if sp.issparse(y):
        raise ValueError("sparse multilabel-indicator for y is not supported.")
    if sklearn_check_version("1.2"):
        self._validate_params()
    else:
        self._check_parameters()
    if sample_weight is not None:
        sample_weight = check_sample_weight(sample_weight, X)

    _patching_status = PatchingConditionsChain(
        "sklearn.ensemble.RandomForestClassifier.fit"
    )
    _dal_ready = _patching_status.and_conditions(
        [
            (
                self.oob_score
                and daal_check_version((2021, "P", 500))
                or not self.oob_score,
                "OOB score is only supported starting from 2021.5 version of oneDAL.",
            ),
            (self.warm_start is False, "Warm start is not supported."),
            (
                self.criterion == "gini",
                f"'{self.criterion}' criterion is not supported. "
                "Only 'gini' criterion is supported.",
            ),
            (
                self.ccp_alpha == 0.0,
                f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
            ),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
        ]
    )
    if _dal_ready and sklearn_check_version("1.4"):
        try:
            _assert_all_finite(X)
            input_is_finite = True
        except ValueError:
            input_is_finite = False
        # Keep the verdict, as the other two condition checks do. Otherwise
        # the preprocessing branch below still runs and hands a converted X
        # and a column-shaped y to the fallback `fit`.
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    input_is_finite,
                    "Non-finite input is not supported.",
                ),
                (
                    self.monotonic_cst is None,
                    "Monotonicity constraints are not supported.",
                ),
            ]
        )

    if _dal_ready:
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=True)
        X = check_array(
            X,
            dtype=[np.float32, np.float64],
            force_all_finite=not sklearn_check_version("1.4"),
        )
        y = np.asarray(y)
        y = np.atleast_1d(y)

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        check_consistent_length(X, y)

        if y.ndim == 1:
            # reshape is necessary to preserve the data contiguity against vs
            # [:, np.newaxis] that does not.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    self.n_outputs_ == 1,
                    f"Number of outputs ({self.n_outputs_}) is not 1.",
                )
            ]
        )

    _patching_status.write_log()
    if _dal_ready:
        self._daal_fit_classifier(X, y, sample_weight=sample_weight)

        if sklearn_check_version("1.2"):
            self._estimator = DecisionTreeClassifier()
        self.estimators_ = self._estimators_

        # Decapsulate classes_ attributes
        self.n_classes_ = self.n_classes_[0]
        self.classes_ = self.classes_[0]
        return self
    return super().fit(X, y, sample_weight=sample_weight)
529
+
530
def predict(self, X):
    """
    Predict class labels for X.

    The request is served by oneDAL when a oneDAL model has been trained,
    ``X`` is not sparse and the forest has exactly one output. In every
    other case the call is delegated to the parent class implementation.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. On the oneDAL path they are validated with
        ``check_array`` for the ``float64`` / ``float32`` dtypes.

    Returns
    -------
    y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
        The predicted classes.
    """
    status = PatchingConditionsChain(
        "sklearn.ensemble.RandomForestClassifier.predict"
    )
    base_conditions = [
        (hasattr(self, "daal_model_"), "oneDAL model was not trained."),
        (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
    ]
    use_onedal = status.and_conditions(base_conditions)

    # The output-count condition can only be evaluated once fit() stored it.
    if hasattr(self, "n_outputs_"):
        single_output = (
            self.n_outputs_ == 1,
            f"Number of outputs ({self.n_outputs_}) is not 1.",
        )
        use_onedal = status.and_conditions([single_output])

    status.write_log()

    if use_onedal:
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        validated = check_array(
            X, accept_sparse=["csr", "csc", "coo"], dtype=[np.float64, np.float32]
        )
        return self._daal_predict_classifier(validated)

    return super().predict(X)
580
+
581
def predict_proba(self, X):
    """
    Predict class probabilities for X.

    oneDAL computes the probabilities when a oneDAL model has been
    trained, ``X`` is not sparse, the oneDAL version check passes and the
    forest has exactly one output. Otherwise the parent class
    implementation is used.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. On the oneDAL path they are validated with
        ``check_array`` for the ``float64`` / ``float32`` dtypes.

    Returns
    -------
    p : ndarray of shape (n_samples, n_classes), or a list of n_outputs
        such arrays if n_outputs > 1.
        The class probabilities of the input samples. The order of the
        classes corresponds to that in the attribute :term:`classes_`.

    Raises
    ------
    ValueError
        If the estimator has ``n_features_in_`` and the feature count
        derived from ``X`` differs from it.
    """
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)

    if hasattr(self, "n_features_in_"):
        try:
            seen_features = _daal_num_features(X)
        except TypeError:
            # Fall back to the sample counter when the feature counter
            # rejects the input type.
            seen_features = _num_samples(X)
        expected_features = self.n_features_in_
        if seen_features != expected_features:
            raise ValueError(
                f"X has {seen_features} features, "
                f"but RandomForestClassifier is expecting "
                f"{expected_features} features as input"
            )

    status = PatchingConditionsChain(
        "sklearn.ensemble.RandomForestClassifier.predict_proba"
    )
    base_conditions = [
        (hasattr(self, "daal_model_"), "oneDAL model was not trained."),
        (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
        (
            daal_check_version((2021, "P", 400)),
            "oneDAL version is lower than 2021.4.",
        ),
    ]
    use_onedal = status.and_conditions(base_conditions)
    if hasattr(self, "n_outputs_"):
        single_output = (
            self.n_outputs_ == 1,
            f"Number of outputs ({self.n_outputs_}) is not 1.",
        )
        use_onedal = status.and_conditions([single_output])
    status.write_log()

    if use_onedal:
        validated = check_array(X, dtype=[np.float64, np.float32])
        check_is_fitted(self)
        self._check_n_features(validated, reset=False)
        return self._daal_predict_proba(validated)

    return super().predict_proba(X)
650
+
651
# The deprecated alias is only defined for scikit-learn 1.0 and newer.
if sklearn_check_version("1.0"):

    @deprecated(
        "Attribute `n_features_` was deprecated in version 1.0 "
        "and will be removed in 1.2. "
        "Use `n_features_in_` instead."
    )
    @property
    def n_features_(self):
        """Deprecated alias that returns ``n_features_in_``."""
        return self.n_features_in_
660
+
661
@property
def _estimators_(self):
    """
    Build (and cache) scikit-learn decision trees from the oneDAL model.

    A non-empty ``_cached_estimators_`` is returned as is. Otherwise one
    ``DecisionTreeClassifier`` per forest tree is created, its ``tree_``
    is filled from ``daal4py.getTreeState`` and the resulting list is
    stored in ``_cached_estimators_``.

    Returns
    -------
    list of DecisionTreeClassifier
    """
    cached = getattr(self, "_cached_estimators_", None)
    if cached:
        return cached

    check_is_fitted(self)
    # At this point classes_ / n_classes_ are still indexed per output.
    classes_ = self.classes_[0]
    n_classes_ = self.n_classes_[0]

    # Hyper-parameters mirrored onto every generated tree estimator.
    tree_params = dict(
        criterion=self.criterion,
        max_depth=self.max_depth,
        min_samples_split=self.min_samples_split,
        min_samples_leaf=self.min_samples_leaf,
        min_weight_fraction_leaf=self.min_weight_fraction_leaf,
        max_features=self.max_features,
        max_leaf_nodes=self.max_leaf_nodes,
        min_impurity_decrease=self.min_impurity_decrease,
        random_state=None,
    )
    if not sklearn_check_version("1.0"):
        tree_params["min_impurity_split"] = self.min_impurity_split
    template = DecisionTreeClassifier(**tree_params)

    # est.tree_ is populated with Trees constructed from the Intel(R)
    # oneAPI Data Analytics Library solution.
    trees = []
    rng = check_random_state(self.random_state)
    seed_bound = np.iinfo(np.int32).max
    for tree_idx in range(self.n_estimators):
        tree_est = clone(template)
        tree_est.set_params(random_state=rng.randint(seed_bound))
        if sklearn_check_version("1.0"):
            tree_est.n_features_in_ = self.n_features_in_
        else:
            tree_est.n_features_ = self.n_features_in_
        tree_est.n_outputs_ = self.n_outputs_
        tree_est.classes_ = classes_
        tree_est.n_classes_ = n_classes_

        # treeState members: 'class_count', 'leaf_count', 'max_depth',
        # 'node_ar', 'node_count', 'value_ar'
        onedal_state = daal4py.getTreeState(self.daal_model_, tree_idx, n_classes_)
        sklearn_state = {
            "max_depth": onedal_state.max_depth,
            "node_count": onedal_state.node_count,
            "nodes": check_tree_nodes(onedal_state.node_ar),
            "values": onedal_state.value_ar,
        }
        tree_est.tree_ = Tree(
            self.n_features_in_,
            np.array([n_classes_], dtype=np.intp),
            self.n_outputs_,
        )
        tree_est.tree_.__setstate__(sklearn_state)
        trees.append(tree_est)

    self._cached_estimators_ = trees
    return trees
728
+
729
def _daal_predict_proba(self, X):
    """
    Compute class probabilities for X with the trained oneDAL model.

    Parameters
    ----------
    X : array
        Input samples; the floating point type of the oneDAL algorithm is
        taken from ``getFPType(X)``.

    Returns
    -------
    The ``probabilities`` attribute of the oneDAL prediction result.
    """
    fptype = getFPType(X)
    predictor = daal4py.decision_forest_classification_prediction(
        nClasses=int(self.n_classes_),
        fptype=fptype,
        resultsToEvaluate="computeClassProbabilities",
    )
    return predictor.compute(X, self.daal_model_).probabilities
741
+
742
def _daal_fit_classifier(self, X, y, sample_weight=None):
    """
    Train the oneDAL decision forest classification model on (X, y).

    The trained model is stored in ``daal_model_``. When ``oob_score`` is
    set, ``oob_score_`` and ``oob_decision_function_`` are filled from the
    training result. Class weights returned by
    ``_validate_y_class_weight`` are folded into ``sample_weight``.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Training samples.
    y : array-like
        Target values.
    sample_weight : array-like or None, default=None
        Optional per-sample weights.

    Returns
    -------
    self : object

    Raises
    ------
    ValueError
        If fewer than two classes are present, or if ``max_samples`` or
        ``oob_score`` is set while ``bootstrap`` is disabled.
    """
    y = check_array(y, ensure_2d=False, dtype=None)
    y, class_weights = self._validate_y_class_weight(y)
    class_count = self.n_classes_[0]
    self.n_features_in_ = X.shape[1]
    if not sklearn_check_version("1.0"):
        self.n_features_ = self.n_features_in_

    if class_weights is not None:
        sample_weight = (
            class_weights if sample_weight is None else sample_weight * class_weights
        )
    if sample_weight is not None:
        # The weights are handed to compute() wrapped in a one-element list.
        sample_weight = [sample_weight]

    rng = check_random_state(self.random_state)
    seed = rng.randint(0, np.iinfo("i").max)

    if class_count < 2:
        raise ValueError("Training data only contain information about one class.")

    fptype = getFPType(X)
    engine = daal4py.engines_mt19937(seed=seed, fptype=fptype)

    n_rows = X.shape[0]
    features_per_node = _to_absolute_max_features(
        self.max_features, X.shape[1], is_classification=True
    )
    bootstrap_fraction = _get_n_samples_bootstrap(
        n_samples=n_rows, max_samples=self.max_samples
    )

    if not self.bootstrap:
        if self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )
        if self.oob_score:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

    # Fractional thresholds are converted to absolute observation counts.
    split_min = self.min_samples_split
    if not isinstance(split_min, numbers.Integral):
        split_min = ceil(split_min * n_rows)
    leaf_min = self.min_samples_leaf
    if not isinstance(leaf_min, numbers.Integral):
        leaf_min = ceil(leaf_min * n_rows)

    parameters = dict(
        bootstrap=bool(self.bootstrap),
        engine=engine,
        featuresPerNode=features_per_node,
        fptype=fptype,
        impurityThreshold=self.min_impurity_split or 0.0,
        maxBins=self.maxBins,
        maxLeafNodes=self.max_leaf_nodes or 0,
        maxTreeDepth=self.max_depth or 0,
        memorySavingMode=False,
        method="hist",
        minBinSize=self.minBinSize,
        minImpurityDecreaseInSplitNode=self.min_impurity_decrease,
        minWeightFractionInLeafNode=self.min_weight_fraction_leaf,
        nClasses=int(class_count),
        nTrees=self.n_estimators,
        observationsPerTreeFraction=bootstrap_fraction if self.bootstrap else 1.0,
        # oob_score without bootstrap was rejected above, so this branch
        # is only reachable with bootstrap enabled.
        resultsToCompute=(
            "computeOutOfBagErrorAccuracy|computeOutOfBagErrorDecisionFunction"
            if self.oob_score
            else ""
        ),
        varImportance="MDI",
        minObservationsInSplitNode=split_min,
        minObservationsInLeafNode=leaf_min,
    )
    if daal_check_version((2023, "P", 200)):
        parameters["binningStrategy"] = self.binningStrategy

    trainer = daal4py.decision_forest_classification_training(**parameters)
    # Invalidate tree estimators derived from a previous model.
    self._cached_estimators_ = None
    training_result = trainer.compute(X, y, sample_weight)

    self.daal_model_ = training_result.model

    if self.oob_score:
        self.oob_score_ = training_result.outOfBagErrorAccuracy[0][0]
        self.oob_decision_function_ = training_result.outOfBagErrorDecisionFunction
        if self.oob_decision_function_.shape[-1] == 1:
            self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1)

    return self
846
+
847
+ def _daal_predict_classifier(self, X):
848
+ X_fptype = getFPType(X)
849
+ dfc_algorithm = daal4py.decision_forest_classification_prediction(
850
+ nClasses=int(self.n_classes_),
851
+ fptype=X_fptype,
852
+ resultsToEvaluate="computeClassLabels",
853
+ )
854
+ if X.shape[1] != self.n_features_in_:
855
+ raise ValueError(
856
+ (
857
+ f"X has {X.shape[1]} features, "
858
+ f"but RandomForestClassifier is expecting "
859
+ f"{self.n_features_in_} features as input"
860
+ )
861
+ )
862
+ dfc_predictionResult = dfc_algorithm.compute(X, self.daal_model_)
863
+
864
+ pred = dfc_predictionResult.prediction
865
+
866
+ return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe"))
867
+
868
+
869
@control_n_jobs(decorated_methods=["fit", "predict"])
class RandomForestRegressor(RandomForestRegressor_original, RandomForestBase):
    # The public documentation is taken verbatim from the wrapped estimator.
    __doc__ = RandomForestRegressor_original.__doc__

    if sklearn_check_version("1.2"):
        # Parent constraints extended with the binning parameters that this
        # class adds to the constructor.
        _parameter_constraints: dict = {
            **RandomForestRegressor_original._parameter_constraints,
            "maxBins": [Interval(numbers.Integral, 0, None, closed="left")],
            "minBinSize": [Interval(numbers.Integral, 1, None, closed="left")],
            "binningStrategy": [StrOptions({"quantiles", "averages"})],
        }

    # One constructor is defined per supported scikit-learn version range;
    # they differ in the accepted keywords and in their defaults.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            maxBins=256,
            minBinSize=1,
            binningStrategy="quantiles",
        ):
            super().__init__(
                n_estimators=n_estimators,
                criterion=criterion,
                max_depth=max_depth,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples_leaf,
                min_weight_fraction_leaf=min_weight_fraction_leaf,
                max_features=max_features,
                max_leaf_nodes=max_leaf_nodes,
                min_impurity_decrease=min_impurity_decrease,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                monotonic_cst=monotonic_cst,
            )
            self.ccp_alpha = ccp_alpha
            self.max_samples = max_samples
            self.monotonic_cst = monotonic_cst
            self.maxBins = maxBins
            self.minBinSize = minBinSize
            # Not a constructor keyword here; kept because the training
            # code reads self.min_impurity_split.
            self.min_impurity_split = None
            self.binningStrategy = binningStrategy

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            maxBins=256,
            minBinSize=1,
            binningStrategy="quantiles",
        ):
            super().__init__(
                n_estimators=n_estimators,
                criterion=criterion,
                max_depth=max_depth,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples_leaf,
                min_weight_fraction_leaf=min_weight_fraction_leaf,
                max_features=max_features,
                max_leaf_nodes=max_leaf_nodes,
                min_impurity_decrease=min_impurity_decrease,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
            )
            self.ccp_alpha = ccp_alpha
            self.max_samples = max_samples
            self.maxBins = maxBins
            self.minBinSize = minBinSize
            # Not a constructor keyword here; kept because the training
            # code reads self.min_impurity_split.
            self.min_impurity_split = None
            self.binningStrategy = binningStrategy

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            maxBins=256,
            minBinSize=1,
            binningStrategy="quantiles",
        ):
            super().__init__(
                n_estimators=n_estimators,
                criterion=criterion,
                max_depth=max_depth,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples_leaf,
                min_weight_fraction_leaf=min_weight_fraction_leaf,
                max_features=max_features,
                max_leaf_nodes=max_leaf_nodes,
                min_impurity_decrease=min_impurity_decrease,
                min_impurity_split=min_impurity_split,
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                ccp_alpha=ccp_alpha,
                max_samples=max_samples,
            )
            self.maxBins = maxBins
            self.minBinSize = minBinSize
            self.binningStrategy = binningStrategy

    def fit(self, X, y, sample_weight=None):
        """
        Build a forest of trees from the training set (X, y).

        Training runs in oneDAL when every patching condition holds (see
        the condition list in the body); otherwise the parent class
        implementation is used with the arguments as received.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. On the oneDAL path they are
            validated with ``check_array`` for the ``float64`` /
            ``float32`` dtypes.

        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The target values. Sparse ``y`` is rejected.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted.

        Returns
        -------
        self : object

        Raises
        ------
        ValueError
            If ``y`` is sparse.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")
        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()
        if sample_weight is not None:
            sample_weight = check_sample_weight(sample_weight, X)

        if sklearn_check_version("1.0") and self.criterion == "mse":
            warnings.warn(
                "Criterion 'mse' was deprecated in v1.0 and will be "
                "removed in version 1.2. Use `criterion='squared_error'` "
                "which is equivalent.",
                FutureWarning,
            )

        _patching_status = PatchingConditionsChain(
            "sklearn.ensemble.RandomForestRegressor.fit"
        )
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    # The oneDAL version only matters when OOB is requested.
                    not self.oob_score or daal_check_version((2021, "P", 500)),
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion in ["mse", "squared_error"],
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'mse' and 'squared_error' criteria are supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            ]
        )
        if _dal_ready and sklearn_check_version("1.4"):
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            # Keep the outcome: when one of these checks fails, the
            # oneDAL-only conversion of X and y below must be skipped so
            # the fallback receives the arguments as they were passed in.
            _dal_ready = _patching_status.and_conditions(
                [
                    (
                        input_is_finite,
                        "Non-finite input is not supported.",
                    ),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if _dal_ready:
            if sklearn_check_version("1.0"):
                self._check_feature_names(X, reset=True)
            X = check_array(
                X,
                dtype=[np.float64, np.float32],
                # For scikit-learn >= 1.4 finiteness was already checked above.
                force_all_finite=not sklearn_check_version("1.4"),
            )
            y = np.asarray(y)
            y = np.atleast_1d(y)

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            y = check_array(y, ensure_2d=False, dtype=X.dtype)
            check_consistent_length(X, y)

            if y.ndim == 1:
                # reshape is necessary to preserve the data contiguity against vs
                # [:, np.newaxis] that does not.
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]
            _dal_ready = _patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

        _patching_status.write_log()
        if _dal_ready:
            self._daal_fit_regressor(X, y, sample_weight=sample_weight)

            if sklearn_check_version("1.2"):
                self._estimator = DecisionTreeRegressor()
            self.estimators_ = self._estimators_
            return self
        return super().fit(X, y, sample_weight=sample_weight)

    def predict(self, X):
        """
        Predict regression target for X.

        The prediction is computed by oneDAL when a oneDAL model has been
        trained, ``X`` is not sparse and the forest has exactly one
        output. Otherwise the parent class implementation is used.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. On the oneDAL path they are validated with
            ``check_array`` for the ``float64`` / ``float32`` dtypes.

        Returns
        -------
        y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            The predicted values.
        """
        _patching_status = PatchingConditionsChain(
            "sklearn.ensemble.RandomForestRegressor.predict"
        )
        _dal_ready = _patching_status.and_conditions(
            [
                (hasattr(self, "daal_model_"), "oneDAL model was not trained."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            ]
        )
        if hasattr(self, "n_outputs_"):
            _dal_ready = _patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

        _patching_status.write_log()
        if not _dal_ready:
            return super().predict(X)

        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        X = check_array(
            X, accept_sparse=["csr", "csc", "coo"], dtype=[np.float64, np.float32]
        )
        return self._daal_predict_regressor(X)

    if sklearn_check_version("1.0"):

        @deprecated(
            "Attribute `n_features_` was deprecated in version 1.0 and will be "
            "removed in 1.2. Use `n_features_in_` instead."
        )
        @property
        def n_features_(self):
            """Deprecated alias that returns ``n_features_in_``."""
            return self.n_features_in_

    @property
    def _estimators_(self):
        """
        Build (and cache) scikit-learn decision trees from the oneDAL model.

        A non-empty ``_cached_estimators_`` is returned as is. Otherwise
        one ``DecisionTreeRegressor`` per forest tree is created, its
        ``tree_`` is filled from ``daal4py.getTreeState`` and the list is
        stored in ``_cached_estimators_`` for later accesses.

        Returns
        -------
        list of DecisionTreeRegressor
        """
        if hasattr(self, "_cached_estimators_"):
            if self._cached_estimators_:
                return self._cached_estimators_
        check_is_fitted(self)
        # convert model to estimators
        params = {
            "criterion": self.criterion,
            "max_depth": self.max_depth,
            "min_samples_split": self.min_samples_split,
            "min_samples_leaf": self.min_samples_leaf,
            "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
            "max_features": self.max_features,
            "max_leaf_nodes": self.max_leaf_nodes,
            "min_impurity_decrease": self.min_impurity_decrease,
            "random_state": None,
        }
        if not sklearn_check_version("1.0"):
            params["min_impurity_split"] = self.min_impurity_split
        est = DecisionTreeRegressor(**params)

        # we need to set est.tree_ field with Trees constructed from Intel(R)
        # oneAPI Data Analytics Library solution
        estimators_ = []
        random_state_checked = check_random_state(self.random_state)
        for i in range(self.n_estimators):
            est_i = clone(est)
            est_i.set_params(
                random_state=random_state_checked.randint(np.iinfo(np.int32).max)
            )
            if sklearn_check_version("1.0"):
                est_i.n_features_in_ = self.n_features_in_
            else:
                est_i.n_features_ = self.n_features_in_
            est_i.n_outputs_ = self.n_outputs_

            tree_i_state_class = daal4py.getTreeState(self.daal_model_, i)
            tree_i_state_dict = {
                "max_depth": tree_i_state_class.max_depth,
                "node_count": tree_i_state_class.node_count,
                "nodes": check_tree_nodes(tree_i_state_class.node_ar),
                "values": tree_i_state_class.value_ar,
            }

            est_i.tree_ = Tree(
                self.n_features_in_, np.array([1], dtype=np.intp), self.n_outputs_
            )
            est_i.tree_.__setstate__(tree_i_state_dict)
            estimators_.append(est_i)

        # Populate the cache that the guard at the top reads; it is reset to
        # None by _daal_fit_regressor before a new model is trained.
        self._cached_estimators_ = estimators_
        return estimators_

    def _daal_fit_regressor(self, X, y, sample_weight=None):
        """
        Train the oneDAL decision forest regression model on (X, y).

        The trained model is stored in ``daal_model_``. When ``oob_score``
        is set, ``oob_score_`` and ``oob_prediction_`` are filled from the
        training result.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            Training samples.
        y : array
            Target values.
        sample_weight : array-like or None, default=None
            Optional per-sample weights. For array inputs, zero weights
            are replaced by 1.0 on a private copy before training; the
            caller's array is left untouched.

        Returns
        -------
        self : object

        Raises
        ------
        ValueError
            If ``max_samples`` or ``oob_score`` is set while ``bootstrap``
            is disabled.
        """
        self.n_features_in_ = X.shape[1]
        if not sklearn_check_version("1.0"):
            self.n_features_ = self.n_features_in_

        rs_ = check_random_state(self.random_state)

        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        seed_ = rs_.randint(0, np.iinfo("i").max)

        daal_engine = daal4py.engines_mt19937(seed=seed_, fptype=getFPType(X))

        features_per_node = _to_absolute_max_features(
            self.max_features, X.shape[1], is_classification=False
        )

        n_samples_bootstrap = _get_n_samples_bootstrap(
            n_samples=X.shape[0], max_samples=self.max_samples
        )

        if sample_weight is not None:
            if hasattr(sample_weight, "__array__"):
                # Copy first: the zero-weight substitution below must not
                # leak into the array object owned by the caller.
                sample_weight = np.array(sample_weight, copy=True)
                sample_weight[sample_weight == 0.0] = 1.0
            # The weights are handed to compute() wrapped in a one-element list.
            sample_weight = [sample_weight]

        parameters = {
            "bootstrap": bool(self.bootstrap),
            "engine": daal_engine,
            "featuresPerNode": features_per_node,
            "fptype": getFPType(X),
            "impurityThreshold": float(self.min_impurity_split or 0.0),
            "maxBins": self.maxBins,
            "maxLeafNodes": self.max_leaf_nodes or 0,
            "maxTreeDepth": self.max_depth or 0,
            "memorySavingMode": False,
            "method": "hist",
            "minBinSize": self.minBinSize,
            "minImpurityDecreaseInSplitNode": self.min_impurity_decrease,
            "minWeightFractionInLeafNode": self.min_weight_fraction_leaf,
            "nTrees": int(self.n_estimators),
            "observationsPerTreeFraction": 1.0,
            "resultsToCompute": "",
            "varImportance": "MDI",
        }

        # Fractional thresholds are converted to absolute observation counts.
        if isinstance(self.min_samples_split, numbers.Integral):
            parameters["minObservationsInSplitNode"] = self.min_samples_split
        else:
            parameters["minObservationsInSplitNode"] = ceil(
                self.min_samples_split * X.shape[0]
            )

        if isinstance(self.min_samples_leaf, numbers.Integral):
            parameters["minObservationsInLeafNode"] = self.min_samples_leaf
        else:
            parameters["minObservationsInLeafNode"] = ceil(
                self.min_samples_leaf * X.shape[0]
            )

        if self.bootstrap:
            parameters["observationsPerTreeFraction"] = n_samples_bootstrap
        # oob_score without bootstrap was rejected above.
        if self.oob_score:
            parameters["resultsToCompute"] = (
                "computeOutOfBagErrorR2|computeOutOfBagErrorPrediction"
            )

        if daal_check_version((2023, "P", 200)):
            parameters["binningStrategy"] = self.binningStrategy

        # create algorithm
        dfr_algorithm = daal4py.decision_forest_regression_training(**parameters)

        # Invalidate tree estimators derived from a previous model.
        self._cached_estimators_ = None

        dfr_trainingResult = dfr_algorithm.compute(X, y, sample_weight)

        # get resulting model
        model = dfr_trainingResult.model
        self.daal_model_ = model

        if self.oob_score:
            self.oob_score_ = dfr_trainingResult.outOfBagErrorR2[0][0]
            self.oob_prediction_ = dfr_trainingResult.outOfBagErrorPrediction.squeeze(
                axis=1
            )
            if self.oob_prediction_.shape[-1] == 1:
                self.oob_prediction_ = self.oob_prediction_.squeeze(axis=-1)

        return self

    def _daal_predict_regressor(self, X):
        """
        Predict regression targets for X with the trained oneDAL model.

        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            Input samples; the floating point type of the oneDAL algorithm
            is taken from ``getFPType(X)``.

        Returns
        -------
        ndarray
            The flattened oneDAL prediction.

        Raises
        ------
        ValueError
            If ``X.shape[1]`` differs from ``n_features_in_``.
        """
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                (
                    f"X has {X.shape[1]} features, "
                    f"but RandomForestRegressor is expecting "
                    f"{self.n_features_in_} features as input"
                )
            )
        X_fptype = getFPType(X)
        dfr_alg = daal4py.decision_forest_regression_prediction(fptype=X_fptype)
        dfr_predictionResult = dfr_alg.compute(X, self.daal_model_)

        pred = dfr_predictionResult.prediction

        return pred.ravel()