scikit-learn-intelex 2025.1.0__py312-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (280) hide show
  1. daal4py/__init__.py +73 -0
  2. daal4py/__main__.py +58 -0
  3. daal4py/_daal4py.cpython-312-x86_64-linux-gnu.so +0 -0
  4. daal4py/doc/third-party-programs.txt +424 -0
  5. daal4py/mb/__init__.py +19 -0
  6. daal4py/mb/model_builders.py +377 -0
  7. daal4py/mpi_transceiver.cpython-312-x86_64-linux-gnu.so +0 -0
  8. daal4py/sklearn/__init__.py +40 -0
  9. daal4py/sklearn/_n_jobs_support.py +248 -0
  10. daal4py/sklearn/_utils.py +245 -0
  11. daal4py/sklearn/cluster/__init__.py +20 -0
  12. daal4py/sklearn/cluster/dbscan.py +165 -0
  13. daal4py/sklearn/cluster/k_means.py +597 -0
  14. daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. daal4py/sklearn/decomposition/__init__.py +19 -0
  16. daal4py/sklearn/decomposition/_pca.py +524 -0
  17. daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. daal4py/sklearn/ensemble/__init__.py +27 -0
  20. daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. daal4py/sklearn/linear_model/__init__.py +29 -0
  23. daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. daal4py/sklearn/linear_model/_linear.py +272 -0
  25. daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  27. daal4py/sklearn/linear_model/linear.py +17 -0
  28. daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. daal4py/sklearn/linear_model/ridge.py +17 -0
  31. daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. daal4py/sklearn/manifold/__init__.py +19 -0
  34. daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. daal4py/sklearn/metrics/__init__.py +20 -0
  36. daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. daal4py/sklearn/metrics/_ranking.py +210 -0
  38. daal4py/sklearn/model_selection/__init__.py +19 -0
  39. daal4py/sklearn/model_selection/_split.py +309 -0
  40. daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
  44. daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  46. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  47. daal4py/sklearn/neighbors/__init__.py +21 -0
  48. daal4py/sklearn/neighbors/_base.py +503 -0
  49. daal4py/sklearn/neighbors/_classification.py +139 -0
  50. daal4py/sklearn/neighbors/_regression.py +74 -0
  51. daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. daal4py/sklearn/svm/__init__.py +19 -0
  54. daal4py/sklearn/svm/svm.py +734 -0
  55. daal4py/sklearn/utils/__init__.py +21 -0
  56. daal4py/sklearn/utils/base.py +75 -0
  57. daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. daal4py/sklearn/utils/validation.py +693 -0
  59. onedal/__init__.py +83 -0
  60. onedal/_config.py +54 -0
  61. onedal/_device_offload.py +222 -0
  62. onedal/_onedal_py_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
  63. onedal/_onedal_py_host.cpython-312-x86_64-linux-gnu.so +0 -0
  64. onedal/_onedal_py_spmd_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
  65. onedal/basic_statistics/__init__.py +20 -0
  66. onedal/basic_statistics/basic_statistics.py +107 -0
  67. onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  68. onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  69. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  70. onedal/cluster/__init__.py +27 -0
  71. onedal/cluster/dbscan.py +110 -0
  72. onedal/cluster/kmeans.py +564 -0
  73. onedal/cluster/kmeans_init.py +115 -0
  74. onedal/cluster/tests/test_dbscan.py +125 -0
  75. onedal/cluster/tests/test_kmeans.py +88 -0
  76. onedal/cluster/tests/test_kmeans_init.py +93 -0
  77. onedal/common/_base.py +38 -0
  78. onedal/common/_estimator_checks.py +47 -0
  79. onedal/common/_mixin.py +62 -0
  80. onedal/common/_policy.py +59 -0
  81. onedal/common/_spmd_policy.py +30 -0
  82. onedal/common/hyperparameters.py +125 -0
  83. onedal/common/tests/test_policy.py +76 -0
  84. onedal/covariance/__init__.py +20 -0
  85. onedal/covariance/covariance.py +125 -0
  86. onedal/covariance/incremental_covariance.py +146 -0
  87. onedal/covariance/tests/test_covariance.py +50 -0
  88. onedal/covariance/tests/test_incremental_covariance.py +122 -0
  89. onedal/datatypes/__init__.py +19 -0
  90. onedal/datatypes/_data_conversion.py +154 -0
  91. onedal/datatypes/tests/common.py +126 -0
  92. onedal/datatypes/tests/test_data.py +414 -0
  93. onedal/decomposition/__init__.py +20 -0
  94. onedal/decomposition/incremental_pca.py +204 -0
  95. onedal/decomposition/pca.py +186 -0
  96. onedal/decomposition/tests/test_incremental_pca.py +198 -0
  97. onedal/ensemble/__init__.py +29 -0
  98. onedal/ensemble/forest.py +727 -0
  99. onedal/ensemble/tests/test_random_forest.py +97 -0
  100. onedal/linear_model/__init__.py +27 -0
  101. onedal/linear_model/incremental_linear_model.py +258 -0
  102. onedal/linear_model/linear_model.py +329 -0
  103. onedal/linear_model/logistic_regression.py +249 -0
  104. onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  105. onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  106. onedal/linear_model/tests/test_linear_regression.py +250 -0
  107. onedal/linear_model/tests/test_logistic_regression.py +95 -0
  108. onedal/linear_model/tests/test_ridge.py +95 -0
  109. onedal/neighbors/__init__.py +19 -0
  110. onedal/neighbors/neighbors.py +767 -0
  111. onedal/neighbors/tests/test_knn_classification.py +49 -0
  112. onedal/primitives/__init__.py +27 -0
  113. onedal/primitives/get_tree.py +25 -0
  114. onedal/primitives/kernel_functions.py +153 -0
  115. onedal/primitives/tests/test_kernel_functions.py +159 -0
  116. onedal/spmd/__init__.py +25 -0
  117. onedal/spmd/_base.py +30 -0
  118. onedal/spmd/basic_statistics/__init__.py +20 -0
  119. onedal/spmd/basic_statistics/basic_statistics.py +30 -0
  120. onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
  121. onedal/spmd/cluster/__init__.py +28 -0
  122. onedal/spmd/cluster/dbscan.py +23 -0
  123. onedal/spmd/cluster/kmeans.py +56 -0
  124. onedal/spmd/covariance/__init__.py +20 -0
  125. onedal/spmd/covariance/covariance.py +26 -0
  126. onedal/spmd/covariance/incremental_covariance.py +82 -0
  127. onedal/spmd/decomposition/__init__.py +20 -0
  128. onedal/spmd/decomposition/incremental_pca.py +117 -0
  129. onedal/spmd/decomposition/pca.py +26 -0
  130. onedal/spmd/ensemble/__init__.py +19 -0
  131. onedal/spmd/ensemble/forest.py +28 -0
  132. onedal/spmd/linear_model/__init__.py +21 -0
  133. onedal/spmd/linear_model/incremental_linear_model.py +97 -0
  134. onedal/spmd/linear_model/linear_model.py +30 -0
  135. onedal/spmd/linear_model/logistic_regression.py +38 -0
  136. onedal/spmd/neighbors/__init__.py +19 -0
  137. onedal/spmd/neighbors/neighbors.py +75 -0
  138. onedal/svm/__init__.py +19 -0
  139. onedal/svm/svm.py +556 -0
  140. onedal/svm/tests/test_csr_svm.py +351 -0
  141. onedal/svm/tests/test_nusvc.py +204 -0
  142. onedal/svm/tests/test_nusvr.py +210 -0
  143. onedal/svm/tests/test_svc.py +176 -0
  144. onedal/svm/tests/test_svr.py +243 -0
  145. onedal/tests/test_common.py +57 -0
  146. onedal/tests/utils/_dataframes_support.py +162 -0
  147. onedal/tests/utils/_device_selection.py +102 -0
  148. onedal/utils/__init__.py +49 -0
  149. onedal/utils/_array_api.py +81 -0
  150. onedal/utils/_dpep_helpers.py +56 -0
  151. onedal/utils/validation.py +440 -0
  152. scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
  153. scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
  154. scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
  155. scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
  156. scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
  157. sklearnex/__init__.py +66 -0
  158. sklearnex/__main__.py +58 -0
  159. sklearnex/_config.py +116 -0
  160. sklearnex/_device_offload.py +126 -0
  161. sklearnex/_utils.py +132 -0
  162. sklearnex/basic_statistics/__init__.py +20 -0
  163. sklearnex/basic_statistics/basic_statistics.py +230 -0
  164. sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  165. sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  166. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  167. sklearnex/cluster/__init__.py +20 -0
  168. sklearnex/cluster/dbscan.py +197 -0
  169. sklearnex/cluster/k_means.py +395 -0
  170. sklearnex/cluster/tests/test_dbscan.py +38 -0
  171. sklearnex/cluster/tests/test_kmeans.py +159 -0
  172. sklearnex/conftest.py +82 -0
  173. sklearnex/covariance/__init__.py +19 -0
  174. sklearnex/covariance/incremental_covariance.py +398 -0
  175. sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  176. sklearnex/decomposition/__init__.py +19 -0
  177. sklearnex/decomposition/pca.py +425 -0
  178. sklearnex/decomposition/tests/test_pca.py +58 -0
  179. sklearnex/dispatcher.py +543 -0
  180. sklearnex/doc/third-party-programs.txt +424 -0
  181. sklearnex/ensemble/__init__.py +29 -0
  182. sklearnex/ensemble/_forest.py +2029 -0
  183. sklearnex/ensemble/tests/test_forest.py +135 -0
  184. sklearnex/glob/__main__.py +72 -0
  185. sklearnex/glob/dispatcher.py +101 -0
  186. sklearnex/linear_model/__init__.py +32 -0
  187. sklearnex/linear_model/coordinate_descent.py +30 -0
  188. sklearnex/linear_model/incremental_linear.py +482 -0
  189. sklearnex/linear_model/incremental_ridge.py +425 -0
  190. sklearnex/linear_model/linear.py +341 -0
  191. sklearnex/linear_model/logistic_regression.py +413 -0
  192. sklearnex/linear_model/ridge.py +24 -0
  193. sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  194. sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  195. sklearnex/linear_model/tests/test_linear.py +167 -0
  196. sklearnex/linear_model/tests/test_logreg.py +134 -0
  197. sklearnex/manifold/__init__.py +19 -0
  198. sklearnex/manifold/t_sne.py +21 -0
  199. sklearnex/manifold/tests/test_tsne.py +26 -0
  200. sklearnex/metrics/__init__.py +23 -0
  201. sklearnex/metrics/pairwise.py +22 -0
  202. sklearnex/metrics/ranking.py +20 -0
  203. sklearnex/metrics/tests/test_metrics.py +39 -0
  204. sklearnex/model_selection/__init__.py +21 -0
  205. sklearnex/model_selection/split.py +22 -0
  206. sklearnex/model_selection/tests/test_model_selection.py +34 -0
  207. sklearnex/neighbors/__init__.py +27 -0
  208. sklearnex/neighbors/_lof.py +236 -0
  209. sklearnex/neighbors/common.py +310 -0
  210. sklearnex/neighbors/knn_classification.py +231 -0
  211. sklearnex/neighbors/knn_regression.py +207 -0
  212. sklearnex/neighbors/knn_unsupervised.py +178 -0
  213. sklearnex/neighbors/tests/test_neighbors.py +82 -0
  214. sklearnex/preview/__init__.py +17 -0
  215. sklearnex/preview/covariance/__init__.py +19 -0
  216. sklearnex/preview/covariance/covariance.py +138 -0
  217. sklearnex/preview/covariance/tests/test_covariance.py +66 -0
  218. sklearnex/preview/decomposition/__init__.py +19 -0
  219. sklearnex/preview/decomposition/incremental_pca.py +233 -0
  220. sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  221. sklearnex/preview/linear_model/__init__.py +19 -0
  222. sklearnex/preview/linear_model/ridge.py +424 -0
  223. sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  224. sklearnex/spmd/__init__.py +25 -0
  225. sklearnex/spmd/basic_statistics/__init__.py +20 -0
  226. sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  227. sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  228. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  229. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  230. sklearnex/spmd/cluster/__init__.py +30 -0
  231. sklearnex/spmd/cluster/dbscan.py +50 -0
  232. sklearnex/spmd/cluster/kmeans.py +21 -0
  233. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  234. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  235. sklearnex/spmd/covariance/__init__.py +20 -0
  236. sklearnex/spmd/covariance/covariance.py +21 -0
  237. sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  238. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  239. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  240. sklearnex/spmd/decomposition/__init__.py +20 -0
  241. sklearnex/spmd/decomposition/incremental_pca.py +30 -0
  242. sklearnex/spmd/decomposition/pca.py +21 -0
  243. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  244. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  245. sklearnex/spmd/ensemble/__init__.py +19 -0
  246. sklearnex/spmd/ensemble/forest.py +71 -0
  247. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  248. sklearnex/spmd/linear_model/__init__.py +21 -0
  249. sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
  250. sklearnex/spmd/linear_model/linear_model.py +21 -0
  251. sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  252. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  253. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  254. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  255. sklearnex/spmd/neighbors/__init__.py +19 -0
  256. sklearnex/spmd/neighbors/neighbors.py +25 -0
  257. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  258. sklearnex/svm/__init__.py +29 -0
  259. sklearnex/svm/_common.py +339 -0
  260. sklearnex/svm/nusvc.py +371 -0
  261. sklearnex/svm/nusvr.py +170 -0
  262. sklearnex/svm/svc.py +399 -0
  263. sklearnex/svm/svr.py +167 -0
  264. sklearnex/svm/tests/test_svm.py +93 -0
  265. sklearnex/tests/test_common.py +390 -0
  266. sklearnex/tests/test_config.py +123 -0
  267. sklearnex/tests/test_memory_usage.py +379 -0
  268. sklearnex/tests/test_monkeypatch.py +276 -0
  269. sklearnex/tests/test_n_jobs_support.py +108 -0
  270. sklearnex/tests/test_parallel.py +48 -0
  271. sklearnex/tests/test_patching.py +385 -0
  272. sklearnex/tests/test_run_to_run_stability.py +321 -0
  273. sklearnex/tests/utils/__init__.py +44 -0
  274. sklearnex/tests/utils/base.py +371 -0
  275. sklearnex/tests/utils/spmd.py +198 -0
  276. sklearnex/utils/__init__.py +19 -0
  277. sklearnex/utils/_array_api.py +82 -0
  278. sklearnex/utils/parallel.py +59 -0
  279. sklearnex/utils/tests/test_finite.py +89 -0
  280. sklearnex/utils/validation.py +17 -0
@@ -0,0 +1,727 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numbers
18
+ import warnings
19
+ from abc import ABCMeta, abstractmethod
20
+ from math import ceil
21
+
22
+ import numpy as np
23
+ from sklearn.ensemble import BaseEnsemble
24
+ from sklearn.utils import check_random_state
25
+
26
+ from daal4py.sklearn._utils import daal_check_version
27
+ from sklearnex import get_hyperparameters
28
+
29
+ from ..common._base import BaseEstimator
30
+ from ..common._estimator_checks import _check_is_fitted
31
+ from ..common._mixin import ClassifierMixin, RegressorMixin
32
+ from ..datatypes import _convert_to_supported, from_table, to_table
33
+ from ..utils import (
34
+ _check_array,
35
+ _check_n_features,
36
+ _check_X_y,
37
+ _column_or_1d,
38
+ _validate_targets,
39
+ )
40
+
41
+
42
class BaseForest(BaseEstimator, BaseEnsemble, metaclass=ABCMeta):
    """Shared implementation behind the forest estimators of this module.

    The class stores the hyper-parameters, validates them, converts them into
    the parameter dictionary handed to a backend ``module`` and drives that
    module's ``train`` / ``infer`` entry points.  Concrete estimators supply
    the backend module and the public ``fit`` / ``predict`` methods.
    """

    @abstractmethod
    def __init__(
        self,
        n_estimators,
        criterion,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        min_weight_fraction_leaf,
        max_features,
        max_leaf_nodes,
        min_impurity_decrease,
        min_impurity_split,
        bootstrap,
        oob_score,
        random_state,
        warm_start,
        class_weight,
        ccp_alpha,
        max_samples,
        max_bins,
        min_bin_size,
        infer_mode,
        splitter_mode,
        voting_mode,
        error_metric_mode,
        variable_importance_mode,
        algorithm,
        **kwargs,
    ):
        """Store every named argument unchanged on the instance.

        No validation happens here (see ``_check_parameters``).  Extra keyword
        arguments are accepted and ignored.
        """
        # Tree shape and split criteria.
        self.n_estimators = n_estimators
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split
        # Sampling, scoring and bookkeeping options.
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.random_state = random_state
        self.warm_start = warm_start
        self.class_weight = class_weight
        self.ccp_alpha = ccp_alpha
        self.max_samples = max_samples
        # Options forwarded to the backend parameter dictionary.
        self.max_bins = max_bins
        self.min_bin_size = min_bin_size
        self.infer_mode = infer_mode
        self.splitter_mode = splitter_mode
        self.voting_mode = voting_mode
        self.error_metric_mode = error_metric_mode
        self.variable_importance_mode = variable_importance_mode
        self.algorithm = algorithm

    def _to_absolute_max_features(self, n_features):
        """Translate ``self.max_features`` into a per-node feature count.

        ``None`` means all ``n_features``; a string names a NumPy function that
        is applied to ``n_features``; an integer is returned as is; any other
        positive value is treated as a fraction of ``n_features``.  Strings and
        fractions are floored at 1.  Non-positive fractions yield 0.
        """
        limit = self.max_features
        if limit is None:
            return n_features
        if isinstance(limit, str):
            # e.g. "sqrt" -> np.sqrt(n_features)
            transform = getattr(np, limit)
            return max(1, int(transform(n_features)))
        if isinstance(limit, (numbers.Integral, np.integer)):
            return limit
        if limit > 0.0:
            return max(1, int(limit * n_features))
        return 0

    def _get_observations_per_tree_fraction(self, n_samples, max_samples):
        """Return the fraction of the ``n_samples`` rows used for each tree.

        Raises
        ------
        ValueError
            If ``max_samples`` is an integer outside ``[1, n_samples]``.
        TypeError
            If ``max_samples`` is neither ``None``, an integer nor a real number.
        """
        if max_samples is None:
            return 1.0

        if isinstance(max_samples, numbers.Integral):
            if max_samples < 1 or max_samples > n_samples:
                template = "`max_samples` must be in range 1 to {} but got value {}"
                raise ValueError(template.format(n_samples, max_samples))
            # Absolute count -> fraction, never below a single observation.
            return max(float(max_samples / n_samples), 1 / n_samples)

        if isinstance(max_samples, numbers.Real):
            # Already a fraction; only the lower bound is enforced here.
            return max(float(max_samples), 1 / n_samples)

        template = "`max_samples` should be int or float, but got type '{}'"
        raise TypeError(template.format(type(max_samples)))

    def _get_onedal_params(self, data):
        """Build the parameter dictionary passed to the backend module.

        ``data`` only needs ``shape`` (two-dimensional) and ``dtype``.

        Side effects: stores ``self.observations_per_tree_fraction`` and draws
        one integer from the generator obtained from ``self.random_state``.

        Raises
        ------
        ValueError
            If ``bootstrap`` is falsy while ``max_samples`` or ``oob_score``
            is set.
        """
        n_samples, n_features = data.shape

        fraction = self._get_observations_per_tree_fraction(
            n_samples=n_samples, max_samples=self.max_samples
        )
        if not bool(self.bootstrap):
            # Without bootstrapping every tree is given the full data set.
            fraction = 1.0
        self.observations_per_tree_fraction = fraction

        if not self.bootstrap:
            if self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )
            if self.oob_score:
                raise ValueError("Out of bag estimation only available if bootstrap=True")

        def _as_sample_count(value):
            # Integers are absolute counts; anything else is a fraction of
            # n_samples, rounded up.
            if isinstance(value, numbers.Integral):
                return value
            return int(ceil(value * n_samples))

        leaf_minimum = _as_sample_count(self.min_samples_leaf)
        split_minimum = _as_sample_count(self.min_samples_split)

        generator = check_random_state(self.random_state)
        seed = generator.randint(0, np.iinfo("i").max)

        onedal_params = {
            "fptype": "float" if data.dtype == np.float32 else "double",
            "method": self.algorithm,
            "infer_mode": self.infer_mode,
            "voting_mode": self.voting_mode,
            "observations_per_tree_fraction": fraction,
            "impurity_threshold": float(
                0.0 if self.min_impurity_split is None else self.min_impurity_split
            ),
            "min_weight_fraction_in_leaf_node": self.min_weight_fraction_leaf,
            "min_impurity_decrease_in_split_node": self.min_impurity_decrease,
            "tree_count": int(self.n_estimators),
            "features_per_node": self._to_absolute_max_features(n_features),
            # The value 0 stands in for an unset depth / leaf-node limit.
            "max_tree_depth": int(0 if self.max_depth is None else self.max_depth),
            "min_observations_in_leaf_node": leaf_minimum,
            "min_observations_in_split_node": split_minimum,
            "max_leaf_nodes": 0 if self.max_leaf_nodes is None else self.max_leaf_nodes,
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "seed": seed,
            "memory_saving_mode": False,
            "bootstrap": bool(self.bootstrap),
            "error_metric_mode": self.error_metric_mode,
            "variable_importance_mode": self.variable_importance_mode,
        }

        if isinstance(self, ClassifierMixin):
            known_classes = self.classes_
            onedal_params["class_count"] = (
                0 if known_classes is None else len(known_classes)
            )
        # splitter_mode is only passed when the installed backend is recent enough.
        if daal_check_version((2023, "P", 101)):
            onedal_params["splitter_mode"] = self.splitter_mode
        return onedal_params

    def _check_parameters(self):
        """Validate the stored hyper-parameters.

        Raises ``ValueError`` for the first out-of-range or wrongly typed value
        and emits a ``FutureWarning`` when ``min_impurity_split`` is set.
        """
        leaf = self.min_samples_leaf
        if isinstance(leaf, numbers.Integral):
            leaf_ok = 1 <= leaf
        else:  # float
            leaf_ok = 0.0 < leaf <= 0.5
        if not leaf_ok:
            raise ValueError(
                "min_samples_leaf must be at least 1 or in (0, 0.5], got %s" % leaf
            )

        split = self.min_samples_split
        if isinstance(split, numbers.Integral):
            if split < 2:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the integer %s" % split
                )
        elif not (0.0 < split <= 1.0):  # float
            raise ValueError(
                "min_samples_split must be an integer "
                "greater than 1 or a float in (0.0, 1.0]; "
                "got the float %s" % split
            )

        if not (0 <= self.min_weight_fraction_leaf <= 0.5):
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")

        if self.min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )
            # Range-checked only when supplied: None cannot be compared to 0.0.
            if self.min_impurity_split < 0.0:
                raise ValueError("min_impurity_split must be greater than or equal to 0")

        if self.min_impurity_decrease < 0.0:
            raise ValueError("min_impurity_decrease must be greater than or equal to 0")

        leaf_nodes = self.max_leaf_nodes
        if leaf_nodes is not None:
            if not isinstance(leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was %r" % leaf_nodes
                )
            if leaf_nodes < 2:
                raise ValueError(
                    "max_leaf_nodes {0} must be either None or larger than 1".format(
                        leaf_nodes
                    )
                )

        if not isinstance(self.max_bins, numbers.Integral):
            raise ValueError(
                "max_bins must be integral number but was %r" % self.max_bins
            )
        if self.max_bins < 2:
            raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)

        if not isinstance(self.min_bin_size, numbers.Integral):
            raise ValueError(
                "min_bin_size must be integral number but was %r" % self.min_bin_size
            )
        if self.min_bin_size < 1:
            raise ValueError(
                "min_bin_size must be at least 1, got %s" % self.min_bin_size
            )

    def _validate_targets(self, y, dtype):
        """Return ``y`` as a 1-D array of ``dtype``.

        This base version resets ``class_weight_`` and ``classes_`` to ``None``.
        """
        self.class_weight_ = None
        self.classes_ = None
        flat_targets = _column_or_1d(y, warn=True)
        return flat_targets.astype(dtype, copy=False)

    def _get_sample_weight(self, sample_weight, X):
        """Return ``sample_weight`` as a flat array of ``X.dtype``.

        Raises ``ValueError`` when the number of weights differs from the
        number of rows in ``X``.
        """
        flattened = np.asarray(sample_weight, dtype=X.dtype).ravel()
        weights = _check_array(
            flattened,
            accept_sparse=False,
            ensure_2d=False,
            dtype=X.dtype,
            order="C",
        )

        if weights.size != X.shape[0]:
            raise ValueError(
                "sample_weight and X have incompatible shapes: %r vs %r\n"
                "Note: Sparse matrices cannot be indexed w/"
                "boolean masks (use `indices=True` in CV)." % (weights.shape, X.shape)
            )
        return weights

    def _fit(self, X, y, sample_weight, module, queue):
        """Validate the inputs, train through ``module`` and keep the model.

        ``module`` must expose ``train(policy, params, *tables)``.  When
        ``oob_score`` is set, the out-of-bag attributes are read from the
        training result.  Returns ``self``.
        """
        X, y = _check_X_y(
            X,
            y,
            dtype=[np.float64, np.float32],
            force_all_finite=True,
            accept_sparse="csr",
        )
        y = self._validate_targets(y, X.dtype)
        self.n_features_in_ = X.shape[1]

        # Weights are passed on only when some were actually given.
        if sample_weight is not None and len(sample_weight) > 0:
            data = (X, y, self._get_sample_weight(sample_weight, X))
        else:
            data = (X, y)

        policy = self._get_policy(queue, *data)
        data = _convert_to_supported(policy, *data)
        params = self._get_onedal_params(data[0])
        train_result = module.train(policy, params, *to_table(*data))
        self._onedal_model = train_result.model

        if not self.oob_score:
            return self

        if isinstance(self, ClassifierMixin):
            self.oob_score_ = from_table(train_result.oob_err_accuracy).item()
            self.oob_decision_function_ = from_table(
                train_result.oob_err_decision_function
            )
            has_zero_entries = np.any(self.oob_decision_function_ == 0)
        else:
            self.oob_score_ = from_table(train_result.oob_err_r2).item()
            oob_prediction = from_table(train_result.oob_err_prediction)
            self.oob_prediction_ = oob_prediction.reshape(-1)
            has_zero_entries = np.any(self.oob_prediction_ == 0)

        # Any exact zero in the OOB output triggers the warning.
        if has_zero_entries:
            warnings.warn(
                "Some inputs do not have OOB scores. This probably means "
                "too few trees were used to compute any reliable OOB "
                "estimates.",
                UserWarning,
            )
        return self

    def _create_model(self, module):
        """Not supported: always raises ``NotImplementedError``."""
        # TODO:
        # update error msg.
        raise NotImplementedError("Creating model is not supported.")

    def _predict(self, X, module, queue, hparams=None):
        """Run inference through ``module`` and return the responses.

        ``hparams.backend`` is passed to ``module.infer`` only when ``hparams``
        is given and is not flagged as default.
        """
        _check_is_fitted(self)
        X = _check_array(
            X,
            dtype=[np.float64, np.float32],
            force_all_finite=True,
            accept_sparse=False,
        )
        _check_n_features(self, X, False)

        policy = self._get_policy(queue, X)
        X = _convert_to_supported(policy, X)
        infer_args = [policy, self._get_onedal_params(X)]
        if hparams is not None and not hparams.is_default:
            infer_args.append(hparams.backend)

        result = module.infer(*infer_args, self._onedal_model, to_table(X))
        return from_table(result.responses)

    def _predict_proba(self, X, module, queue):
        """Run inference in ``"class_probabilities"`` mode and return the result."""
        _check_is_fitted(self)
        X = _check_array(
            X,
            dtype=[np.float64, np.float32],
            force_all_finite=True,
            accept_sparse=False,
        )
        _check_n_features(self, X, False)

        policy = self._get_policy(queue, X)
        X = _convert_to_supported(policy, X)
        params = self._get_onedal_params(X)
        # Override the stored inference mode for this call only.
        params["infer_mode"] = "class_probabilities"

        result = module.infer(policy, params, self._onedal_model, to_table(X))
        return from_table(result.probabilities)
382
+
383
+
384
class RandomForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """Forest classifier using the ``"decision_forest"`` / ``"classification"`` backend."""

    def __init__(
        self,
        n_estimators=100,
        criterion="gini",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="best",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        """Forward all named arguments to the parent constructor.

        Extra keyword arguments are accepted and not forwarded.
        """
        super().__init__(
            # Tree shape and split criteria.
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            # Sampling, scoring and bookkeeping options.
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            # Options forwarded to the backend parameter dictionary.
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def _validate_targets(self, y, dtype):
        """Validate ``y`` and record ``class_weight_`` and ``classes_``.

        Delegates to the module-level ``_validate_targets`` helper and returns
        the first element of its result.
        """
        checked = _validate_targets(y, self.class_weight, dtype)
        y, self.class_weight_, self.classes_ = checked

        # TODO:
        # align the `n_classes_` and `classes_` attributes with the daal4py
        # implementations.
        return y

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the classifier on ``X`` / ``y`` and return ``self``."""
        backend = self._get_backend("decision_forest", "classification", None)
        return self._fit(X, y, sample_weight, backend, queue)

    def predict(self, X, queue=None):
        """Return, for each row of ``X``, the matching entry of ``classes_``."""
        hparams = get_hyperparameters("decision_forest", "infer")
        backend = self._get_backend("decision_forest", "classification", None)
        pred = super()._predict(X, backend, queue, hparams)

        # The backend output is used as integer positions into ``classes_``.
        class_positions = pred.ravel().astype(np.int64, casting="unsafe")
        return np.take(self.classes_, class_positions)

    def predict_proba(self, X, queue=None):
        """Return the probabilities computed by the backend for ``X``."""
        backend = self._get_backend("decision_forest", "classification", None)
        return super()._predict_proba(X, backend, queue)
478
+
479
+
480
class RandomForestRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Forest regressor backed by the ``("decision_forest", "regression")`` backend.

    Every named constructor argument is forwarded verbatim to the parent
    initializer, which defines its meaning. Training and inference are
    delegated to the parent's ``_fit`` and ``_predict``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="best",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # NOTE: extra keyword arguments are accepted for signature
        # compatibility but are not forwarded to the parent initializer.
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the regressor and return the result of the parent ``_fit``.

        Parameters
        ----------
        X, y
            Training data and targets, passed through to ``_fit``.
        sample_weight : optional
            Per-sample weights. When the object exposes ``__array__``, zero
            entries are replaced by ``1.0`` on a private copy, so the
            caller's array is never modified. The weights are then wrapped
            in a one-element list before being handed to ``_fit``.
        queue : optional
            Passed through to ``_fit``.
        """
        if sample_weight is not None:
            if hasattr(sample_weight, "__array__"):
                # Copy before patching: assigning into the caller's array
                # would silently alter user data and fails outright on
                # read-only arrays.
                sample_weight = np.array(sample_weight, copy=True)
                # NOTE(review): zero weights are promoted to 1.0 rather
                # than excluding the sample; the rationale is not visible
                # here -- confirm against the backend's requirements.
                sample_weight[sample_weight == 0.0] = 1.0
            sample_weight = [sample_weight]
        return super()._fit(
            X,
            y,
            sample_weight,
            self._get_backend("decision_forest", "regression", None),
            queue,
        )

    def predict(self, X, queue=None):
        """Return the flattened (``ravel``) result of the parent ``_predict``."""
        backend = self._get_backend("decision_forest", "regression", None)
        return super()._predict(X, backend, queue).ravel()
557
+
558
+
559
class ExtraTreesClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """Forest classifier backed by the ``("decision_forest", "classification")`` backend.

    Every named constructor argument is forwarded verbatim to the parent
    initializer. Fitting, prediction and probability estimation are
    delegated to the parent's ``_fit``, ``_predict`` and ``_predict_proba``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="gini",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # Collect the named settings once, then hand them to the parent.
        # Extra keyword arguments are accepted but not forwarded.
        forwarded = dict(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )
        super().__init__(**forwarded)

    def _validate_targets(self, y, dtype):
        """Validate the targets and record class information on the estimator.

        Delegates to the module-level ``_validate_targets`` helper with
        ``y``, ``self.class_weight`` and ``dtype``. The first result is
        returned; the second and third are stored as ``self.class_weight_``
        and ``self.classes_``.
        """
        validated = _validate_targets(y, self.class_weight, dtype)
        y, self.class_weight_, self.classes_ = validated

        # TODO: align the `n_classes_` and `classes_` attributes with the
        # daal4py implementations (an earlier sketch derived `n_classes_`
        # from `classes_` whenever `classes_` was present).
        return y

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train via ``self._fit`` with the classification backend; return its result."""
        backend = self._get_backend("decision_forest", "classification", None)
        return self._fit(X, y, sample_weight, backend, queue)

    def predict(self, X, queue=None):
        """Return the entries of ``self.classes_`` selected by the backend output.

        The parent ``_predict`` result is flattened, cast to ``int64``
        (unsafe casting) and used as indices into ``self.classes_``.
        """
        backend = self._get_backend("decision_forest", "classification", None)
        raw = super()._predict(X, backend, queue)

        labels = self.classes_
        class_idx = raw.ravel().astype(np.int64, casting="unsafe")
        return np.take(labels, class_idx)

    def predict_proba(self, X, queue=None):
        """Return the parent ``_predict_proba`` result for ``X`` unchanged."""
        backend = self._get_backend("decision_forest", "classification", None)
        return super()._predict_proba(X, backend, queue)
649
+
650
+
651
class ExtraTreesRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Forest regressor backed by the ``("decision_forest", "regression")`` backend.

    Every named constructor argument is forwarded verbatim to the parent
    initializer, which defines its meaning. Training and inference are
    delegated to the parent's ``_fit`` and ``_predict``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        # NOTE: extra keyword arguments are accepted for signature
        # compatibility but are not forwarded to the parent initializer.
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the regressor and return the result of the parent ``_fit``.

        Parameters
        ----------
        X, y
            Training data and targets, passed through to ``_fit``.
        sample_weight : optional
            Per-sample weights. When the object exposes ``__array__``, zero
            entries are replaced by ``1.0`` on a private copy, so the
            caller's array is never modified. The weights are then wrapped
            in a one-element list before being handed to ``_fit``.
        queue : optional
            Passed through to ``_fit``.
        """
        if sample_weight is not None:
            if hasattr(sample_weight, "__array__"):
                # Copy before patching: assigning into the caller's array
                # would silently alter user data and fails outright on
                # read-only arrays.
                sample_weight = np.array(sample_weight, copy=True)
                # NOTE(review): zero weights are promoted to 1.0 rather
                # than excluding the sample; the rationale is not visible
                # here -- confirm against the backend's requirements.
                sample_weight[sample_weight == 0.0] = 1.0
            sample_weight = [sample_weight]
        return super()._fit(
            X,
            y,
            sample_weight,
            self._get_backend("decision_forest", "regression", None),
            queue,
        )

    def predict(self, X, queue=None):
        """Return the flattened (``ravel``) result of the parent ``_predict``."""
        backend = self._get_backend("decision_forest", "regression", None)
        return super()._predict(X, backend, queue).ravel()