scikit-learn-intelex 2025.0.0__py312-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the package's registry page for more details.

Files changed (278)
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-312-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-312-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +242 -0
  10. daal4py/sklearn/_utils.py +241 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +155 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +53 -0
  61. onedal/_device_offload.py +229 -0
  62. onedal/_onedal_py_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-312-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +560 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +116 -0
  83. onedal/common/tests/test_policy.py +75 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +95 -0
  91. onedal/datatypes/tests/test_data.py +235 -0
  92. onedal/decomposition/__init__.py +20 -0
  93. onedal/decomposition/incremental_pca.py +204 -0
  94. onedal/decomposition/pca.py +186 -0
  95. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  96. onedal/ensemble/__init__.py +29 -0
  97. onedal/ensemble/forest.py +720 -0
  98. onedal/ensemble/tests/test_random_forest.py +97 -0
  99. onedal/linear_model/__init__.py +27 -0
  100. onedal/linear_model/incremental_linear_model.py +258 -0
  101. onedal/linear_model/linear_model.py +329 -0
  102. onedal/linear_model/logistic_regression.py +249 -0
  103. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  104. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  105. onedal/linear_model/tests/test_linear_regression.py +149 -0
  106. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  107. onedal/linear_model/tests/test_ridge.py +95 -0
  108. onedal/neighbors/__init__.py +19 -0
  109. onedal/neighbors/neighbors.py +778 -0
  110. onedal/neighbors/tests/test_knn_classification.py +49 -0
  111. onedal/primitives/__init__.py +27 -0
  112. onedal/primitives/get_tree.py +25 -0
  113. onedal/primitives/kernel_functions.py +153 -0
  114. onedal/primitives/tests/test_kernel_functions.py +159 -0
  115. onedal/spmd/__init__.py +25 -0
  116. onedal/spmd/_base.py +30 -0
  117. onedal/spmd/basic_statistics/__init__.py +20 -0
  118. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  119. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  120. onedal/spmd/cluster/__init__.py +28 -0
  121. onedal/spmd/cluster/dbscan.py +23 -0
  122. onedal/spmd/cluster/kmeans.py +56 -0
  123. onedal/spmd/covariance/__init__.py +20 -0
  124. onedal/spmd/covariance/covariance.py +26 -0
  125. onedal/spmd/covariance/incremental_covariance.py +82 -0
  126. onedal/spmd/decomposition/__init__.py +20 -0
  127. onedal/spmd/decomposition/incremental_pca.py +117 -0
  128. onedal/spmd/decomposition/pca.py +26 -0
  129. onedal/spmd/ensemble/__init__.py +19 -0
  130. onedal/spmd/ensemble/forest.py +28 -0
  131. onedal/spmd/linear_model/__init__.py +21 -0
  132. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  133. onedal/spmd/linear_model/linear_model.py +30 -0
  134. onedal/spmd/linear_model/logistic_regression.py +38 -0
  135. onedal/spmd/neighbors/__init__.py +19 -0
  136. onedal/spmd/neighbors/neighbors.py +75 -0
  137. onedal/svm/__init__.py +19 -0
  138. onedal/svm/svm.py +556 -0
  139. onedal/svm/tests/test_csr_svm.py +351 -0
  140. onedal/svm/tests/test_nusvc.py +204 -0
  141. onedal/svm/tests/test_nusvr.py +210 -0
  142. onedal/svm/tests/test_svc.py +168 -0
  143. onedal/svm/tests/test_svr.py +243 -0
  144. onedal/tests/test_common.py +41 -0
  145. onedal/tests/utils/_dataframes_support.py +168 -0
  146. onedal/tests/utils/_device_selection.py +107 -0
  147. onedal/utils/__init__.py +49 -0
  148. onedal/utils/_array_api.py +91 -0
  149. onedal/utils/validation.py +432 -0
  150. scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
  151. scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
  152. scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
  153. scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
  154. scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
  155. sklearnex/__init__.py +65 -0
  156. sklearnex/__main__.py +58 -0
  157. sklearnex/_config.py +98 -0
  158. sklearnex/_device_offload.py +121 -0
  159. sklearnex/_utils.py +109 -0
  160. sklearnex/basic_statistics/__init__.py +20 -0
  161. sklearnex/basic_statistics/basic_statistics.py +140 -0
  162. sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  163. sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
  164. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
  165. sklearnex/cluster/__init__.py +20 -0
  166. sklearnex/cluster/dbscan.py +192 -0
  167. sklearnex/cluster/k_means.py +383 -0
  168. sklearnex/cluster/tests/test_dbscan.py +38 -0
  169. sklearnex/cluster/tests/test_kmeans.py +153 -0
  170. sklearnex/conftest.py +73 -0
  171. sklearnex/covariance/__init__.py +19 -0
  172. sklearnex/covariance/incremental_covariance.py +368 -0
  173. sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
  174. sklearnex/decomposition/__init__.py +19 -0
  175. sklearnex/decomposition/pca.py +414 -0
  176. sklearnex/decomposition/tests/test_pca.py +58 -0
  177. sklearnex/dispatcher.py +543 -0
  178. sklearnex/doc/third-party-programs.txt +424 -0
  179. sklearnex/ensemble/__init__.py +29 -0
  180. sklearnex/ensemble/_forest.py +2016 -0
  181. sklearnex/ensemble/tests/test_forest.py +120 -0
  182. sklearnex/glob/__main__.py +72 -0
  183. sklearnex/glob/dispatcher.py +101 -0
  184. sklearnex/linear_model/__init__.py +32 -0
  185. sklearnex/linear_model/coordinate_descent.py +30 -0
  186. sklearnex/linear_model/incremental_linear.py +463 -0
  187. sklearnex/linear_model/incremental_ridge.py +418 -0
  188. sklearnex/linear_model/linear.py +302 -0
  189. sklearnex/linear_model/logistic_path.py +17 -0
  190. sklearnex/linear_model/logistic_regression.py +403 -0
  191. sklearnex/linear_model/ridge.py +24 -0
  192. sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
  193. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  194. sklearnex/linear_model/tests/test_linear.py +142 -0
  195. sklearnex/linear_model/tests/test_logreg.py +134 -0
  196. sklearnex/manifold/__init__.py +19 -0
  197. sklearnex/manifold/t_sne.py +21 -0
  198. sklearnex/manifold/tests/test_tsne.py +26 -0
  199. sklearnex/metrics/__init__.py +23 -0
  200. sklearnex/metrics/pairwise.py +22 -0
  201. sklearnex/metrics/ranking.py +20 -0
  202. sklearnex/metrics/tests/test_metrics.py +39 -0
  203. sklearnex/model_selection/__init__.py +21 -0
  204. sklearnex/model_selection/split.py +22 -0
  205. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  206. sklearnex/neighbors/__init__.py +27 -0
  207. sklearnex/neighbors/_lof.py +231 -0
  208. sklearnex/neighbors/common.py +310 -0
  209. sklearnex/neighbors/knn_classification.py +226 -0
  210. sklearnex/neighbors/knn_regression.py +203 -0
  211. sklearnex/neighbors/knn_unsupervised.py +170 -0
  212. sklearnex/neighbors/tests/test_neighbors.py +80 -0
  213. sklearnex/preview/__init__.py +17 -0
  214. sklearnex/preview/covariance/__init__.py +19 -0
  215. sklearnex/preview/covariance/covariance.py +133 -0
  216. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  217. sklearnex/preview/decomposition/__init__.py +19 -0
  218. sklearnex/preview/decomposition/incremental_pca.py +228 -0
  219. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  220. sklearnex/preview/linear_model/__init__.py +19 -0
  221. sklearnex/preview/linear_model/ridge.py +419 -0
  222. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  223. sklearnex/spmd/__init__.py +25 -0
  224. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  225. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  226. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  227. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  228. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  229. sklearnex/spmd/cluster/__init__.py +30 -0
  230. sklearnex/spmd/cluster/dbscan.py +50 -0
  231. sklearnex/spmd/cluster/kmeans.py +21 -0
  232. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  233. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  234. sklearnex/spmd/covariance/__init__.py +20 -0
  235. sklearnex/spmd/covariance/covariance.py +21 -0
  236. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  237. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  238. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  239. sklearnex/spmd/decomposition/__init__.py +20 -0
  240. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  241. sklearnex/spmd/decomposition/pca.py +21 -0
  242. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  243. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  244. sklearnex/spmd/ensemble/__init__.py +19 -0
  245. sklearnex/spmd/ensemble/forest.py +71 -0
  246. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  247. sklearnex/spmd/linear_model/__init__.py +21 -0
  248. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  249. sklearnex/spmd/linear_model/linear_model.py +21 -0
  250. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  251. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  252. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  253. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
  254. sklearnex/spmd/neighbors/__init__.py +19 -0
  255. sklearnex/spmd/neighbors/neighbors.py +25 -0
  256. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  257. sklearnex/svm/__init__.py +29 -0
  258. sklearnex/svm/_common.py +328 -0
  259. sklearnex/svm/nusvc.py +332 -0
  260. sklearnex/svm/nusvr.py +148 -0
  261. sklearnex/svm/svc.py +360 -0
  262. sklearnex/svm/svr.py +149 -0
  263. sklearnex/svm/tests/test_svm.py +93 -0
  264. sklearnex/tests/_utils.py +328 -0
  265. sklearnex/tests/_utils_spmd.py +198 -0
  266. sklearnex/tests/test_common.py +54 -0
  267. sklearnex/tests/test_config.py +43 -0
  268. sklearnex/tests/test_memory_usage.py +291 -0
  269. sklearnex/tests/test_monkeypatch.py +276 -0
  270. sklearnex/tests/test_n_jobs_support.py +103 -0
  271. sklearnex/tests/test_parallel.py +48 -0
  272. sklearnex/tests/test_patching.py +385 -0
  273. sklearnex/tests/test_run_to_run_stability.py +296 -0
  274. sklearnex/utils/__init__.py +19 -0
  275. sklearnex/utils/_array_api.py +82 -0
  276. sklearnex/utils/parallel.py +59 -0
  277. sklearnex/utils/tests/test_finite.py +89 -0
  278. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,720 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+ import warnings
19
+ from abc import ABCMeta, abstractmethod
20
+ from math import ceil
21
+ from numbers import Number
22
+
23
+ import numpy as np
24
+ from sklearn.ensemble import BaseEnsemble
25
+ from sklearn.utils import check_random_state
26
+
27
+ from daal4py.sklearn._utils import daal_check_version
28
+ from onedal import _backend
29
+
30
+ from ..common._base import BaseEstimator
31
+ from ..common._estimator_checks import _check_is_fitted
32
+ from ..common._mixin import ClassifierMixin, RegressorMixin
33
+ from ..datatypes import _convert_to_supported, from_table, to_table
34
+ from ..utils import (
35
+ _check_array,
36
+ _check_n_features,
37
+ _check_X_y,
38
+ _column_or_1d,
39
+ _validate_targets,
40
+ )
41
+
42
+
43
class BaseForest(BaseEstimator, BaseEnsemble, metaclass=ABCMeta):
    """Common base for the oneDAL decision-forest estimators.

    Stores scikit-learn style hyperparameters, converts them into the
    parameter dictionary consumed by the oneDAL ``decision_forest`` backend
    module, and implements the shared training (``_fit``) and inference
    (``_predict`` / ``_predict_proba``) paths used by the concrete
    classifier and regressor classes.
    """

    @abstractmethod
    def __init__(
        self,
        n_estimators,
        criterion,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        min_weight_fraction_leaf,
        max_features,
        max_leaf_nodes,
        min_impurity_decrease,
        min_impurity_split,
        bootstrap,
        oob_score,
        random_state,
        warm_start,
        class_weight,
        ccp_alpha,
        max_samples,
        max_bins,
        min_bin_size,
        infer_mode,
        splitter_mode,
        voting_mode,
        error_metric_mode,
        variable_importance_mode,
        algorithm,
        **kwargs,
    ):
        self.n_estimators = n_estimators
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.random_state = random_state
        self.warm_start = warm_start
        self.class_weight = class_weight
        self.max_samples = max_samples
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split
        self.ccp_alpha = ccp_alpha
        self.max_bins = max_bins
        self.min_bin_size = min_bin_size
        self.infer_mode = infer_mode
        self.splitter_mode = splitter_mode
        self.voting_mode = voting_mode
        self.error_metric_mode = error_metric_mode
        self.variable_importance_mode = variable_importance_mode
        self.algorithm = algorithm

    def _to_absolute_max_features(self, n_features):
        """Convert ``max_features`` into an absolute feature count.

        ``None`` selects all ``n_features``. A string names a NumPy function
        (e.g. ``"sqrt"`` -> ``np.sqrt``) that is applied to ``n_features``.
        An integer is used as-is. A positive float is a fraction of
        ``n_features``. The string and float forms are truncated to ``int``
        and clamped to at least 1. A non-positive float yields 0.
        """
        if self.max_features is None:
            return n_features
        elif isinstance(self.max_features, str):
            return max(1, int(getattr(np, self.max_features)(n_features)))
        elif isinstance(self.max_features, (numbers.Integral, np.integer)):
            return self.max_features
        elif self.max_features > 0.0:
            return max(1, int(self.max_features * n_features))
        return 0

    def _get_observations_per_tree_fraction(self, n_samples, max_samples):
        """Return the fraction of ``n_samples`` drawn for each tree.

        Parameters
        ----------
        n_samples : int
            Number of training rows.
        max_samples : None, int or float
            ``None`` means all rows (fraction 1.0). An int is an absolute row
            count in ``[1, n_samples]``. A float is a fraction in
            ``(0.0, 1.0]``. The result is never below ``1 / n_samples``.

        Raises
        ------
        ValueError
            If ``max_samples`` is outside its valid range.
        TypeError
            If ``max_samples`` is neither an int nor a float.
        """
        if max_samples is None:
            return 1.0

        if isinstance(max_samples, numbers.Integral):
            if not (1 <= max_samples <= n_samples):
                msg = "`max_samples` must be in range 1 to {} but got value {}"
                raise ValueError(msg.format(n_samples, max_samples))
            return max(float(max_samples / n_samples), 1 / n_samples)

        if isinstance(max_samples, numbers.Real):
            # Mirror the integer branch: a fraction outside (0, 1] is
            # meaningless and must not be forwarded to the backend.
            if not (0.0 < max_samples <= 1.0):
                msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
                raise ValueError(msg.format(max_samples))
            return max(float(max_samples), 1 / n_samples)

        msg = "`max_samples` should be int or float, but got type '{}'"
        raise TypeError(msg.format(type(max_samples)))

    def _get_onedal_params(self, data):
        """Build the oneDAL training parameter dictionary for ``data``.

        ``data`` is a 2-D array; its shape supplies the sample and feature
        counts used to resolve fractional hyperparameters. Side effects: this
        sets ``self.observations_per_tree_fraction`` and draws a seed from
        ``check_random_state(self.random_state)``.

        Raises
        ------
        ValueError
            If ``max_samples`` or ``oob_score`` is set while
            ``bootstrap=False``, or if ``max_samples`` is out of range.
        """
        n_samples, n_features = data.shape

        # Check bootstrap conflicts first, so the error names the real problem
        # rather than a range error on a value that is not even used.
        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )
        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        # Without bootstrap every tree sees the full training set.
        if bool(self.bootstrap):
            self.observations_per_tree_fraction = (
                self._get_observations_per_tree_fraction(
                    n_samples=n_samples, max_samples=self.max_samples
                )
            )
        else:
            self.observations_per_tree_fraction = 1.0

        # Fractional leaf/split sizes are resolved against n_samples here.
        min_observations_in_leaf_node = (
            self.min_samples_leaf
            if isinstance(self.min_samples_leaf, numbers.Integral)
            else int(ceil(self.min_samples_leaf * n_samples))
        )

        min_observations_in_split_node = (
            self.min_samples_split
            if isinstance(self.min_samples_split, numbers.Integral)
            else int(ceil(self.min_samples_split * n_samples))
        )

        rs = check_random_state(self.random_state)
        seed = rs.randint(0, np.iinfo("i").max)

        onedal_params = {
            "fptype": "float" if data.dtype == np.float32 else "double",
            "method": self.algorithm,
            "infer_mode": self.infer_mode,
            "voting_mode": self.voting_mode,
            "observations_per_tree_fraction": self.observations_per_tree_fraction,
            "impurity_threshold": float(
                0.0 if self.min_impurity_split is None else self.min_impurity_split
            ),
            "min_weight_fraction_in_leaf_node": self.min_weight_fraction_leaf,
            "min_impurity_decrease_in_split_node": self.min_impurity_decrease,
            "tree_count": int(self.n_estimators),
            "features_per_node": self._to_absolute_max_features(n_features),
            "max_tree_depth": int(0 if self.max_depth is None else self.max_depth),
            "min_observations_in_leaf_node": min_observations_in_leaf_node,
            "min_observations_in_split_node": min_observations_in_split_node,
            "max_leaf_nodes": (0 if self.max_leaf_nodes is None else self.max_leaf_nodes),
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "seed": seed,
            "memory_saving_mode": False,
            "bootstrap": bool(self.bootstrap),
            "error_metric_mode": self.error_metric_mode,
            "variable_importance_mode": self.variable_importance_mode,
        }
        if isinstance(self, ClassifierMixin):
            onedal_params["class_count"] = (
                0 if self.classes_ is None else len(self.classes_)
            )
        # splitter_mode is only sent to oneDAL versions that accept it.
        if daal_check_version((2023, "P", 101)):
            onedal_params["splitter_mode"] = self.splitter_mode
        return onedal_params

    def _get_inference_params(self, X):
        """Build the oneDAL parameter dictionary for an inference call on ``X``.

        Reuses the parameters recorded by ``_fit``. This avoids recomputing
        sample-count dependent values from the prediction batch: an integer
        ``max_samples`` larger than the batch would otherwise raise. It also
        avoids consuming the random state again. Only the floating-point type
        and the inference modes are refreshed. If no recorded parameters
        exist, fall back to ``_get_onedal_params``.
        """
        fit_params = getattr(self, "_fit_onedal_params", None)
        if fit_params is None:
            return self._get_onedal_params(X)
        params = dict(fit_params)
        params["fptype"] = "float" if X.dtype == np.float32 else "double"
        params["infer_mode"] = self.infer_mode
        params["voting_mode"] = self.voting_mode
        return params

    def _check_parameters(self):
        """Validate tree-growing hyperparameters.

        Raises ``ValueError`` on the first invalid value. Emits a
        ``FutureWarning`` when the deprecated ``min_impurity_split`` is set.
        """
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the integer %s" % self.min_samples_split
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the float %s" % self.min_samples_split
                )
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
        if self.min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            # Compared only when set: `None < 0.0` would raise TypeError.
            if self.min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than " "or equal to 0"
                )
        if self.min_impurity_decrease < 0.0:
            raise ValueError(
                "min_impurity_decrease must be greater than " "or equal to 0"
            )
        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" % self.max_leaf_nodes
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
                        self.max_leaf_nodes
                    )
                )
        if isinstance(self.max_bins, numbers.Integral):
            if not 2 <= self.max_bins:
                raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
        else:
            raise ValueError(
                "max_bins must be integral number but was " "%r" % self.max_bins
            )
        if isinstance(self.min_bin_size, numbers.Integral):
            if not 1 <= self.min_bin_size:
                raise ValueError(
                    "min_bin_size must be at least 1, got %s" % self.min_bin_size
                )
        else:
            raise ValueError(
                "min_bin_size must be integral number but was " "%r" % self.min_bin_size
            )

    def _validate_targets(self, y, dtype):
        """Default (regression) target handling: 1-D ``y`` cast to ``dtype``.

        Clears ``class_weight_`` and ``classes_``, since there are no classes.
        """
        self.class_weight_ = None
        self.classes_ = None
        return _column_or_1d(y, warn=True).astype(dtype, copy=False)

    def _get_sample_weight(self, sample_weight, X):
        """Return ``sample_weight`` as a flat C-ordered array of ``X.dtype``.

        Raises ``ValueError`` if the number of weights differs from the
        number of rows in ``X``.
        """
        sample_weight = np.asarray(sample_weight, dtype=X.dtype).ravel()

        sample_weight = _check_array(
            sample_weight, accept_sparse=False, ensure_2d=False, dtype=X.dtype, order="C"
        )

        if sample_weight.size != X.shape[0]:
            raise ValueError(
                "sample_weight and X have incompatible shapes: "
                "%r vs %r\n"
                "Note: Sparse matrices cannot be indexed w/"
                "boolean masks (use `indices=True` in CV)."
                % (sample_weight.shape, X.shape)
            )

        return sample_weight

    def _fit(self, X, y, sample_weight, module, queue):
        """Validate inputs, train a forest via ``module.train`` and store it.

        Sets ``n_features_in_`` and ``_onedal_model``. It also records the
        training parameters in ``_fit_onedal_params`` for later inference.
        When ``oob_score`` is set, it stores the out-of-bag score and the
        per-sample OOB outputs. Returns ``self``.
        """
        X, y = _check_X_y(
            X,
            y,
            dtype=[np.float64, np.float32],
            force_all_finite=True,
            accept_sparse="csr",
        )
        y = self._validate_targets(y, X.dtype)

        self.n_features_in_ = X.shape[1]

        if sample_weight is not None and len(sample_weight) > 0:
            sample_weight = self._get_sample_weight(sample_weight, X)
            data = (X, y, sample_weight)
        else:
            data = (X, y)
        policy = self._get_policy(queue, *data)
        data = _convert_to_supported(policy, *data)
        params = self._get_onedal_params(data[0])
        train_result = module.train(policy, params, *to_table(*data))

        self._onedal_model = train_result.model
        # Keep the exact training parameters; see _get_inference_params.
        self._fit_onedal_params = params

        if self.oob_score:
            if isinstance(self, ClassifierMixin):
                self.oob_score_ = from_table(train_result.oob_err_accuracy).item()
                self.oob_decision_function_ = from_table(
                    train_result.oob_err_decision_function
                )
                # A zero entry in one class column is normal (that class got no
                # OOB votes). A sample lacks OOB estimates only when its whole
                # row is zero. NOTE(review): assumes oneDAL leaves such rows
                # all-zero rather than NaN; confirm.
                missing_oob = self.oob_decision_function_ == 0
                if missing_oob.ndim > 1:
                    missing_oob = missing_oob.all(axis=1)
                if np.any(missing_oob):
                    warnings.warn(
                        "Some inputs do not have OOB scores. This probably means "
                        "too few trees were used to compute any reliable OOB "
                        "estimates.",
                        UserWarning,
                    )
            else:
                self.oob_score_ = from_table(train_result.oob_err_r2).item()
                self.oob_prediction_ = from_table(
                    train_result.oob_err_prediction
                ).reshape(-1)
                if np.any(self.oob_prediction_ == 0):
                    warnings.warn(
                        "Some inputs do not have OOB scores. This probably means "
                        "too few trees were used to compute any reliable OOB "
                        "estimates.",
                        UserWarning,
                    )

        return self

    def _create_model(self, module):
        # TODO: update the error message.
        raise NotImplementedError("Creating model is not supported.")

    def _predict(self, X, module, queue):
        """Return the raw ``responses`` from ``module.infer`` on ``X``."""
        _check_is_fitted(self)
        X = _check_array(
            X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False
        )
        _check_n_features(self, X, False)
        policy = self._get_policy(queue, X)

        model = self._onedal_model
        X = _convert_to_supported(policy, X)
        params = self._get_inference_params(X)
        result = module.infer(policy, params, model, to_table(X))
        y = from_table(result.responses)
        return y

    def _predict_proba(self, X, module, queue):
        """Return ``probabilities`` from ``module.infer`` on ``X``.

        Uses ``infer_mode="class_probabilities"``.
        """
        _check_is_fitted(self)
        X = _check_array(
            X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False
        )
        _check_n_features(self, X, False)
        policy = self._get_policy(queue, X)
        X = _convert_to_supported(policy, X)
        params = self._get_inference_params(X)
        params["infer_mode"] = "class_probabilities"

        model = self._onedal_model
        result = module.infer(policy, params, model, to_table(X))
        y = from_table(result.probabilities)
        return y
379
+
380
+
381
class RandomForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """oneDAL-backed random forest classifier.

    Specialises ``BaseForest`` for classification. Targets are encoded via
    the shared ``_validate_targets`` utility. Training and inference go
    through the ``decision_forest`` / ``classification`` backend module.
    Predicted class indices are mapped back to the original labels in
    ``classes_``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="gini",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="best",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # Every hyperparameter is forwarded by keyword; extra **kwargs are
        # accepted for signature compatibility but not passed on.
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def _validate_targets(self, y, dtype):
        """Encode ``y`` and record ``class_weight_`` and ``classes_``.

        The bare name below resolves to the module-level ``_validate_targets``
        utility, not to this method.
        """
        encoded = _validate_targets(y, self.class_weight, dtype)
        y, self.class_weight_, self.classes_ = encoded
        # TODO: align the `n_classes_` and `classes_` attributes with the
        # daal4py implementation.
        return y

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the classifier on ``(X, y)`` and return ``self``."""
        backend = self._get_backend("decision_forest", "classification", None)
        return self._fit(X, y, sample_weight, backend, queue)

    def predict(self, X, queue=None):
        """Return the predicted class label for each row of ``X``."""
        backend = self._get_backend("decision_forest", "classification", None)
        raw = super()._predict(X, backend, queue)
        # The backend returns class indices; map them back to the labels.
        indices = raw.ravel().astype(np.int64, casting="unsafe")
        return np.take(self.classes_, indices)

    def predict_proba(self, X, queue=None):
        """Return per-class probabilities for each row of ``X``."""
        backend = self._get_backend("decision_forest", "classification", None)
        return super()._predict_proba(X, backend, queue)
471
+
472
+
473
class RandomForestRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Random-forest regressor using the ``decision_forest``/``regression`` backend.

    Every named constructor argument is forwarded by keyword to
    ``BaseForest.__init__``. Extra ``**kwargs`` are accepted but NOT
    forwarded, so they have no effect.

    NOTE(review): several defaults (``infer_mode="class_responses"``,
    ``voting_mode="weighted"``, ``class_weight``) look shared with the
    classifiers; confirm they are meaningful for regression.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="best",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def fit(self, X, y, sample_weight=None, queue=None):
        """Fit the regressor through the regression backend.

        If ``sample_weight`` is array-like (has ``__array__``), zero entries
        are replaced by 1.0 in a copy. The caller's object is never modified.
        Non-``None`` weights are then wrapped in a one-element list and
        forwarded, with ``X``, ``y`` and ``queue``, to ``BaseForest._fit``.
        The backend passed along is
        ``self._get_backend("decision_forest", "regression", None)``.
        """
        if sample_weight is not None:
            if hasattr(sample_weight, "__array__"):
                # Work on a copy. Writing into the caller's array would
                # silently alter user data, and it fails on read-only arrays.
                if hasattr(sample_weight, "copy"):
                    sample_weight = sample_weight.copy()
                else:
                    sample_weight = np.array(sample_weight, copy=True)
                # NOTE(review): the reason zero weights are mapped to 1.0 is
                # not visible here (presumably a backend limitation) — confirm.
                sample_weight[sample_weight == 0.0] = 1.0
            # NOTE(review): the regressors wrap the weights in a list while the
            # classifiers pass them as-is; confirm BaseForest._fit expects this.
            sample_weight = [sample_weight]
        return super()._fit(
            X,
            y,
            sample_weight,
            self._get_backend("decision_forest", "regression", None),
            queue,
        )

    def predict(self, X, queue=None):
        """Return regression-backend predictions for ``X`` as a flattened array."""
        return (
            super()
            ._predict(X, self._get_backend("decision_forest", "regression", None), queue)
            .ravel()
        )
550
+
551
+
552
class ExtraTreesClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """Extra-trees classifier using the ``decision_forest``/``classification`` backend.

    Defaults differ from a classic random forest in ``bootstrap=False`` and
    ``splitter_mode="random"``. Every named constructor argument is forwarded
    by keyword to ``BaseForest.__init__``. Extra ``**kwargs`` are accepted
    but not forwarded, so they have no effect.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="gini",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        super().__init__(
            # tree shape and split constraints
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            ccp_alpha=ccp_alpha,
            # sampling, weighting and reproducibility
            bootstrap=bootstrap,
            max_samples=max_samples,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            # bin parameters
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            # mode switches
            algorithm=algorithm,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
        )

    def _validate_targets(self, y, dtype):
        """Validate targets via the module-level helper and store class metadata.

        The helper's 3-tuple result is split into the validated targets
        (returned), ``class_weight_`` and ``classes_``.
        """
        # TODO: align the `n_classes_` and `classes_` attributes with the
        # daal4py implementations.
        validated_y, weights, labels = _validate_targets(y, self.class_weight, dtype)
        self.class_weight_ = weights
        self.classes_ = labels
        return validated_y

    def fit(self, X, y, sample_weight=None, queue=None):
        """Forward all arguments plus the classification backend to ``self._fit``."""
        backend = self._get_backend("decision_forest", "classification", None)
        return self._fit(X, y, sample_weight, backend, queue)

    def predict(self, X, queue=None):
        """Predict labels by indexing ``classes_`` with the int64-cast backend output."""
        backend = self._get_backend("decision_forest", "classification", None)
        raw = super()._predict(X, backend, queue)
        indices = raw.ravel().astype(np.int64, casting="unsafe")
        return np.take(self.classes_, indices)

    def predict_proba(self, X, queue=None):
        """Return the class-probability output of the classification backend."""
        backend = self._get_backend("decision_forest", "classification", None)
        return super()._predict_proba(X, backend, queue)
642
+
643
+
644
class ExtraTreesRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Extra-trees regressor using the ``decision_forest``/``regression`` backend.

    Defaults use ``bootstrap=False`` and ``splitter_mode="random"``. Every
    named constructor argument is forwarded by keyword to
    ``BaseForest.__init__``. Extra ``**kwargs`` are accepted but NOT
    forwarded, so they have no effect.

    NOTE(review): several defaults (``infer_mode="class_responses"``,
    ``voting_mode="weighted"``, ``class_weight``) look shared with the
    classifiers; confirm they are meaningful for regression.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def fit(self, X, y, sample_weight=None, queue=None):
        """Fit the regressor through the regression backend.

        If ``sample_weight`` is array-like (has ``__array__``), zero entries
        are replaced by 1.0 in a copy. The caller's object is never modified.
        Non-``None`` weights are then wrapped in a one-element list and
        forwarded, with ``X``, ``y`` and ``queue``, to ``BaseForest._fit``.
        The backend passed along is
        ``self._get_backend("decision_forest", "regression", None)``.
        """
        if sample_weight is not None:
            if hasattr(sample_weight, "__array__"):
                # Work on a copy. Writing into the caller's array would
                # silently alter user data, and it fails on read-only arrays.
                if hasattr(sample_weight, "copy"):
                    sample_weight = sample_weight.copy()
                else:
                    sample_weight = np.array(sample_weight, copy=True)
                # NOTE(review): the reason zero weights are mapped to 1.0 is
                # not visible here (presumably a backend limitation) — confirm.
                sample_weight[sample_weight == 0.0] = 1.0
            # NOTE(review): the regressors wrap the weights in a list while the
            # classifiers pass them as-is; confirm BaseForest._fit expects this.
            sample_weight = [sample_weight]
        return super()._fit(
            X,
            y,
            sample_weight,
            self._get_backend("decision_forest", "regression", None),
            queue,
        )

    def predict(self, X, queue=None):
        """Return regression-backend predictions for ``X`` as a flattened array."""
        return (
            super()
            ._predict(X, self._get_backend("decision_forest", "regression", None), queue)
            .ravel()
        )