scikit-learn-intelex 2024.0.0__py311-none-win_amd64.whl → 2024.0.1__py311-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. (The original page linked to further details about this release.)

Files changed (99):
  1. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_utils.py +2 -0
  2. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +70 -77
  3. {scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +6 -2
  4. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +960 -494
  5. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +18 -15
  6. {scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +59 -12
  7. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +15 -4
  8. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  9. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +3 -1
  10. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +2 -6
  11. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -14
  12. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +8 -5
  13. {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
  14. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  15. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
  16. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
  17. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -54
  18. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -17
  19. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1557
  20. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +0 -20
  21. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
  22. scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py +0 -47
  23. scikit_learn_intelex-2024.0.0.dist-info/RECORD +0 -98
  24. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
  25. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  26. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  27. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  28. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  29. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  30. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  31. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -0
  32. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  33. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  34. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  35. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  36. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -0
  37. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -0
  38. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  39. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  40. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  41. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  42. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  43. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  44. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  45. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -0
  46. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  47. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  48. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  49. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  50. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  51. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  52. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  53. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  54. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  55. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  56. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  57. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  58. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -0
  59. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -0
  60. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +0 -0
  61. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -0
  62. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -0
  63. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  64. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  65. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -0
  66. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -0
  67. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -0
  68. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  69. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  70. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  71. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  72. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  73. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  74. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  75. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  76. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  77. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  78. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  79. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  80. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  81. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
  83. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +0 -0
  84. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +0 -0
  85. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +0 -0
  86. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +0 -0
  87. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  88. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -0
  89. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  90. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
  91. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -0
  92. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  93. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -0
  94. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  96. {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  97. {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
  98. {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
  99. {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
@@ -24,9 +24,16 @@ from scipy import sparse as sp
24
24
  from sklearn.base import clone
25
25
  from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifier
26
26
  from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
27
+ from sklearn.ensemble import RandomForestClassifier as sklearn_RandomForestClassifier
28
+ from sklearn.ensemble import RandomForestRegressor as sklearn_RandomForestRegressor
29
+ from sklearn.ensemble._forest import _get_n_samples_bootstrap
27
30
  from sklearn.exceptions import DataConversionWarning
28
- from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
29
- from sklearn.tree import ExtraTreeClassifier, ExtraTreeRegressor
31
+ from sklearn.tree import (
32
+ DecisionTreeClassifier,
33
+ DecisionTreeRegressor,
34
+ ExtraTreeClassifier,
35
+ ExtraTreeRegressor,
36
+ )
30
37
  from sklearn.tree._tree import Tree
31
38
  from sklearn.utils import check_random_state, deprecated
32
39
  from sklearn.utils.validation import (
@@ -53,18 +60,117 @@ try:
53
60
  except ModuleNotFoundError:
54
61
  from sklearn.ensemble.forest import ForestClassifier as sklearn_ForestClassifier
55
62
  from sklearn.ensemble.forest import ForestRegressor as sklearn_ForestRegressor
63
+
56
64
  from onedal.primitives import get_tree_state_cls, get_tree_state_reg
57
65
  from onedal.utils import _num_features, _num_samples
58
66
 
59
- from ..._config import get_config
60
- from ..._device_offload import dispatch, wrap_output_data
61
- from ..._utils import PatchingConditionsChain
67
+ from .._config import get_config
68
+ from .._device_offload import dispatch, wrap_output_data
69
+ from .._utils import PatchingConditionsChain
62
70
 
63
71
  if sklearn_check_version("1.2"):
64
72
  from sklearn.utils._param_validation import Interval
73
+ if sklearn_check_version("1.4"):
74
+ from daal4py.sklearn.utils import _assert_all_finite
75
+
76
+
77
+ class BaseForest(ABC):
78
+ _onedal_factory = None
79
+
80
+ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
81
+ if sklearn_check_version("0.24"):
82
+ X, y = self._validate_data(
83
+ X,
84
+ y,
85
+ multi_output=False,
86
+ accept_sparse=False,
87
+ dtype=[np.float64, np.float32],
88
+ force_all_finite=False,
89
+ )
90
+ else:
91
+ X, y = check_X_y(
92
+ X,
93
+ y,
94
+ accept_sparse=False,
95
+ dtype=[np.float64, np.float32],
96
+ multi_output=False,
97
+ force_all_finite=False,
98
+ )
99
+
100
+ if sample_weight is not None:
101
+ sample_weight = self.check_sample_weight(sample_weight, X)
102
+
103
+ if y.ndim == 2 and y.shape[1] == 1:
104
+ warnings.warn(
105
+ "A column-vector y was passed when a 1d array was"
106
+ " expected. Please change the shape of y to "
107
+ "(n_samples,), for example using ravel().",
108
+ DataConversionWarning,
109
+ stacklevel=2,
110
+ )
111
+
112
+ if y.ndim == 1:
113
+ # reshape is necessary to preserve the data contiguity against vs
114
+ # [:, np.newaxis] that does not.
115
+ y = np.reshape(y, (-1, 1))
116
+
117
+ y, expanded_class_weight = self._validate_y_class_weight(y)
118
+
119
+ self.n_features_in_ = X.shape[1]
120
+
121
+ if expanded_class_weight is not None:
122
+ if sample_weight is not None:
123
+ sample_weight = sample_weight * expanded_class_weight
124
+ else:
125
+ sample_weight = expanded_class_weight
126
+ if sample_weight is not None:
127
+ sample_weight = [sample_weight]
128
+
129
+ onedal_params = {
130
+ "n_estimators": self.n_estimators,
131
+ "criterion": self.criterion,
132
+ "max_depth": self.max_depth,
133
+ "min_samples_split": self.min_samples_split,
134
+ "min_samples_leaf": self.min_samples_leaf,
135
+ "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
136
+ "max_features": self.max_features,
137
+ "max_leaf_nodes": self.max_leaf_nodes,
138
+ "min_impurity_decrease": self.min_impurity_decrease,
139
+ "bootstrap": self.bootstrap,
140
+ "oob_score": self.oob_score,
141
+ "n_jobs": self.n_jobs,
142
+ "random_state": self.random_state,
143
+ "verbose": self.verbose,
144
+ "warm_start": self.warm_start,
145
+ "error_metric_mode": self._err if self.oob_score else "none",
146
+ "variable_importance_mode": "mdi",
147
+ "class_weight": self.class_weight,
148
+ "max_bins": self.max_bins,
149
+ "min_bin_size": self.min_bin_size,
150
+ "max_samples": self.max_samples,
151
+ }
152
+
153
+ if not sklearn_check_version("1.0"):
154
+ onedal_params["min_impurity_split"] = self.min_impurity_split
155
+ else:
156
+ onedal_params["min_impurity_split"] = None
157
+
158
+ # Lazy evaluation of estimators_
159
+ self._cached_estimators_ = None
160
+
161
+ # Compute
162
+ self._onedal_estimator = self._onedal_factory(**onedal_params)
163
+ self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)
164
+
165
+ self._save_attributes()
166
+
167
+ # Decapsulate classes_ attributes
168
+ if hasattr(self, "classes_") and self.n_outputs_ == 1:
169
+ self.n_classes_ = self.n_classes_[0]
170
+ self.classes_ = self.classes_[0]
65
171
 
172
+ return self
66
173
 
67
- class BaseTree(ABC):
68
174
  def _fit_proba(self, X, y, sample_weight=None, queue=None):
69
175
  params = self.get_params()
70
176
  self.__class__(**params)
@@ -75,8 +181,6 @@ class BaseTree(ABC):
75
181
  cfg["target_offload"] = queue
76
182
 
77
183
  def _save_attributes(self):
78
- self._onedal_model = self._onedal_estimator._onedal_model
79
-
80
184
  if self.oob_score:
81
185
  self.oob_score_ = self._onedal_estimator.oob_score_
82
186
  if hasattr(self._onedal_estimator, "oob_prediction_"):
@@ -85,6 +189,8 @@ class BaseTree(ABC):
85
189
  self.oob_decision_function_ = (
86
190
  self._onedal_estimator.oob_decision_function_
87
191
  )
192
+
193
+ self._validate_estimator()
88
194
  return self
89
195
 
90
196
  # TODO:
@@ -183,6 +289,7 @@ class BaseTree(ABC):
183
289
  ensure_2d=False,
184
290
  dtype=dtype,
185
291
  order="C",
292
+ force_all_finite=False,
186
293
  )
187
294
  if sample_weight.ndim != 1:
188
295
  raise ValueError("Sample weights must be 1D array or scalar")
@@ -198,7 +305,7 @@ class BaseTree(ABC):
198
305
  @property
199
306
  def estimators_(self):
200
307
  if hasattr(self, "_cached_estimators_"):
201
- if self._cached_estimators_ is None and self._onedal_model:
308
+ if self._cached_estimators_ is None:
202
309
  self._estimators_()
203
310
  return self._cached_estimators_
204
311
  else:
@@ -211,13 +318,99 @@ class BaseTree(ABC):
211
318
  # Needed to allow for proper sklearn operation in fallback mode
212
319
  self._cached_estimators_ = estimators
213
320
 
321
+ def _estimators_(self):
322
+ # _estimators_ should only be called if _onedal_estimator exists
323
+ check_is_fitted(self, "_onedal_estimator")
324
+ if hasattr(self, "n_classes_"):
325
+ n_classes_ = (
326
+ self.n_classes_
327
+ if isinstance(self.n_classes_, int)
328
+ else self.n_classes_[0]
329
+ )
330
+ else:
331
+ n_classes_ = 1
332
+
333
+ # convert model to estimators
334
+ params = {
335
+ "criterion": self._onedal_estimator.criterion,
336
+ "max_depth": self._onedal_estimator.max_depth,
337
+ "min_samples_split": self._onedal_estimator.min_samples_split,
338
+ "min_samples_leaf": self._onedal_estimator.min_samples_leaf,
339
+ "min_weight_fraction_leaf": self._onedal_estimator.min_weight_fraction_leaf,
340
+ "max_features": self._onedal_estimator.max_features,
341
+ "max_leaf_nodes": self._onedal_estimator.max_leaf_nodes,
342
+ "min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
343
+ "random_state": None,
344
+ }
345
+ if not sklearn_check_version("1.0"):
346
+ params["min_impurity_split"] = self._onedal_estimator.min_impurity_split
347
+ est = self.estimator.__class__(**params)
348
+ # we need to set est.tree_ field with Trees constructed from Intel(R)
349
+ # oneAPI Data Analytics Library solution
350
+ estimators_ = []
351
+
352
+ random_state_checked = check_random_state(self.random_state)
353
+
354
+ for i in range(self._onedal_estimator.n_estimators):
355
+ est_i = clone(est)
356
+ est_i.set_params(
357
+ random_state=random_state_checked.randint(np.iinfo(np.int32).max)
358
+ )
359
+ if sklearn_check_version("1.0"):
360
+ est_i.n_features_in_ = self.n_features_in_
361
+ else:
362
+ est_i.n_features_ = self.n_features_in_
363
+ est_i.n_outputs_ = self.n_outputs_
364
+ est_i.n_classes_ = n_classes_
365
+ tree_i_state_class = self._get_tree_state(
366
+ self._onedal_estimator._onedal_model, i, n_classes_
367
+ )
368
+ tree_i_state_dict = {
369
+ "max_depth": tree_i_state_class.max_depth,
370
+ "node_count": tree_i_state_class.node_count,
371
+ "nodes": check_tree_nodes(tree_i_state_class.node_ar),
372
+ "values": tree_i_state_class.value_ar,
373
+ }
374
+ est_i.tree_ = Tree(
375
+ self.n_features_in_,
376
+ np.array([n_classes_], dtype=np.intp),
377
+ self.n_outputs_,
378
+ )
379
+ est_i.tree_.__setstate__(tree_i_state_dict)
380
+ estimators_.append(est_i)
381
+
382
+ self._cached_estimators_ = estimators_
383
+
384
+ if sklearn_check_version("1.0"):
385
+
386
+ @deprecated(
387
+ "Attribute `n_features_` was deprecated in version 1.0 and will be "
388
+ "removed in 1.2. Use `n_features_in_` instead."
389
+ )
390
+ @property
391
+ def n_features_(self):
392
+ return self.n_features_in_
393
+
394
+ if not sklearn_check_version("1.2"):
395
+
396
+ @property
397
+ def base_estimator(self):
398
+ return self.estimator
399
+
400
+ @base_estimator.setter
401
+ def base_estimator(self, estimator):
402
+ self.estimator = estimator
403
+
214
404
 
215
- class ForestClassifier(sklearn_ForestClassifier, BaseTree):
405
+ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
216
406
  # Surprisingly, even though scikit-learn warns against using
217
407
  # their ForestClassifier directly, it actually has a more stable
218
408
  # API than the user-facing objects (over time). If they change it
219
409
  # significantly at some point then this may need to be versioned.
220
410
 
411
+ _err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
412
+ _get_tree_state = staticmethod(get_tree_state_cls)
413
+
221
414
  def __init__(
222
415
  self,
223
416
  estimator,
@@ -247,16 +440,27 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
247
440
  max_samples=max_samples,
248
441
  )
249
442
 
250
- # The splitter is recognized here for proper dispatching.
251
- self._estimator = estimator # TODO: Verify if this is done in older verions
252
- if self._estimator.__class__ == DecisionTreeClassifier:
253
- self._onedal_classifier = onedal_RandomForestClassifier
254
- elif self._estimator.__class__ == ExtraTreeClassifier:
255
- self._onedal_classifier = onedal_ExtraTreesClassifier
256
- else:
257
- raise TypeError(
258
- f"{estimator.__class__.__name__} is not a supported tree classifier"
259
- )
443
+ # The estimator is checked against the class attribute for conformance.
444
+ # This should only trigger if the user uses this class directly.
445
+ if (
446
+ self.estimator.__class__ == DecisionTreeClassifier
447
+ and self._onedal_factory != onedal_RandomForestClassifier
448
+ ):
449
+ self._onedal_factory = onedal_RandomForestClassifier
450
+ elif (
451
+ self.estimator.__class__ == ExtraTreeClassifier
452
+ and self._onedal_factory != onedal_ExtraTreesClassifier
453
+ ):
454
+ self._onedal_factory = onedal_ExtraTreesClassifier
455
+
456
+ if self._onedal_factory is None:
457
+ raise TypeError(f" oneDAL estimator has not been set.")
458
+
459
+ def _estimators_(self):
460
+ super()._estimators_()
461
+ classes_ = self.classes_[0]
462
+ for est in self._cached_estimators_:
463
+ est.classes_ = classes_
260
464
 
261
465
  def fit(self, X, y, sample_weight=None):
262
466
  dispatch(
@@ -292,17 +496,17 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
292
496
  or not self.oob_score,
293
497
  "OOB score is only supported starting from 2021.5 version of oneDAL.",
294
498
  ),
295
- (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
296
- (
297
- self.ccp_alpha == 0.0,
298
- f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
299
- ),
499
+ (self.warm_start is False, "Warm start is not supported."),
300
500
  (
301
501
  self.criterion == "gini",
302
502
  f"'{self.criterion}' criterion is not supported. "
303
503
  "Only 'gini' criterion is supported.",
304
504
  ),
305
- (self.warm_start is False, "Warm start is not supported."),
505
+ (
506
+ self.ccp_alpha == 0.0,
507
+ f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
508
+ ),
509
+ (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
306
510
  (
307
511
  self.n_estimators <= 6024,
308
512
  "More than 6024 estimators is not supported.",
@@ -310,12 +514,46 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
310
514
  ]
311
515
  )
312
516
 
517
+ if self.bootstrap:
518
+ patching_status.and_conditions(
519
+ [
520
+ (
521
+ self.class_weight != "balanced_subsample",
522
+ "'balanced_subsample' for class_weight is not supported",
523
+ )
524
+ ]
525
+ )
526
+
527
+ if patching_status.get_status() and sklearn_check_version("1.4"):
528
+ try:
529
+ _assert_all_finite(X)
530
+ input_is_finite = True
531
+ except ValueError:
532
+ input_is_finite = False
533
+ patching_status.and_conditions(
534
+ [
535
+ (input_is_finite, "Non-finite input is not supported."),
536
+ (
537
+ self.monotonic_cst is None,
538
+ "Monotonicity constraints are not supported.",
539
+ ),
540
+ ]
541
+ )
542
+
313
543
  if patching_status.get_status():
314
- if sklearn_check_version("1.0"):
315
- self._check_feature_names(X, reset=True)
316
- X = check_array(X, dtype=[np.float32, np.float64])
317
- y = np.asarray(y)
318
- y = np.atleast_1d(y)
544
+ if sklearn_check_version("0.24"):
545
+ X, y = self._validate_data(
546
+ X,
547
+ y,
548
+ multi_output=True,
549
+ accept_sparse=True,
550
+ dtype=[np.float64, np.float32],
551
+ force_all_finite=False,
552
+ )
553
+ else:
554
+ X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
555
+ y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
556
+
319
557
  if y.ndim == 2 and y.shape[1] == 1:
320
558
  warnings.warn(
321
559
  "A column-vector y was passed when a 1d array was"
@@ -324,11 +562,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
324
562
  DataConversionWarning,
325
563
  stacklevel=2,
326
564
  )
327
- check_consistent_length(X, y)
328
565
 
329
566
  if y.ndim == 1:
330
567
  y = np.reshape(y, (-1, 1))
568
+
331
569
  self.n_outputs_ = y.shape[1]
570
+
332
571
  patching_status.and_conditions(
333
572
  [
334
573
  (
@@ -343,30 +582,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
343
582
  )
344
583
  # TODO: Fix to support integers as input
345
584
 
346
- n_samples = X.shape[0]
347
- if isinstance(self.max_samples, numbers.Integral):
348
- if not sklearn_check_version("1.2"):
349
- if not (1 <= self.max_samples <= n_samples):
350
- msg = "`max_samples` must be in range 1 to {} but got value {}"
351
- raise ValueError(msg.format(n_samples, self.max_samples))
352
- else:
353
- if self.max_samples > n_samples:
354
- msg = "`max_samples` must be <= n_samples={} but got value {}"
355
- raise ValueError(msg.format(n_samples, self.max_samples))
356
- elif isinstance(self.max_samples, numbers.Real):
357
- if sklearn_check_version("1.2"):
358
- pass
359
- elif sklearn_check_version("1.0"):
360
- if not (0 < float(self.max_samples) <= 1):
361
- msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
362
- raise ValueError(msg.format(self.max_samples))
363
- else:
364
- if not (0 < float(self.max_samples) < 1):
365
- msg = "`max_samples` must be in range (0, 1) but got value {}"
366
- raise ValueError(msg.format(self.max_samples))
367
- elif self.max_samples is not None:
368
- msg = "`max_samples` should be int or float, but got type '{}'"
369
- raise TypeError(msg.format(type(self.max_samples)))
585
+ _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
370
586
 
371
587
  if not self.bootstrap and self.max_samples is not None:
372
588
  raise ValueError(
@@ -375,6 +591,17 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
375
591
  "`max_sample=None`."
376
592
  )
377
593
 
594
+ if (
595
+ patching_status.get_status()
596
+ and (self.random_state is not None)
597
+ and (not daal_check_version((2024, "P", 0)))
598
+ ):
599
+ warnings.warn(
600
+ "Setting 'random_state' value is not supported. "
601
+ "State set by oneDAL to default value (777).",
602
+ RuntimeWarning,
603
+ )
604
+
378
605
  return patching_status, X, y, sample_weight
379
606
 
380
607
  @wrap_output_data
@@ -423,124 +650,57 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
423
650
  predict.__doc__ = sklearn_ForestClassifier.predict.__doc__
424
651
  predict_proba.__doc__ = sklearn_ForestClassifier.predict_proba.__doc__
425
652
 
426
- if sklearn_check_version("1.0"):
427
-
428
- @deprecated(
429
- "Attribute `n_features_` was deprecated in version 1.0 and will be "
430
- "removed in 1.2. Use `n_features_in_` instead."
431
- )
432
- @property
433
- def n_features_(self):
434
- return self.n_features_in_
435
-
436
- def _estimators_(self):
437
- # _estimators_ should only be called if _onedal_model exists
438
- check_is_fitted(self, "_onedal_model")
439
- classes_ = self.classes_[0]
440
- n_classes_ = (
441
- self.n_classes_ if isinstance(self.n_classes_, int) else self.n_classes_[0]
653
+ def _onedal_cpu_supported(self, method_name, *data):
654
+ class_name = self.__class__.__name__
655
+ patching_status = PatchingConditionsChain(
656
+ f"sklearn.ensemble.{class_name}.{method_name}"
442
657
  )
443
- # convert model to estimators
444
- params = {
445
- "criterion": self.criterion,
446
- "max_depth": self.max_depth,
447
- "min_samples_split": self.min_samples_split,
448
- "min_samples_leaf": self.min_samples_leaf,
449
- "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
450
- "max_features": self.max_features,
451
- "max_leaf_nodes": self.max_leaf_nodes,
452
- "min_impurity_decrease": self.min_impurity_decrease,
453
- "random_state": None,
454
- }
455
- if not sklearn_check_version("1.0"):
456
- params["min_impurity_split"] = self.min_impurity_split
457
- est = self._estimator.__class__(**params)
458
- # we need to set est.tree_ field with Trees constructed from Intel(R)
459
- # oneAPI Data Analytics Library solution
460
- estimators_ = []
461
658
 
462
- random_state_checked = check_random_state(self.random_state)
659
+ if method_name == "fit":
660
+ patching_status, X, y, sample_weight = self._onedal_fit_ready(
661
+ patching_status, *data
662
+ )
463
663
 
464
- for i in range(self.n_estimators):
465
- est_i = clone(est)
466
- est_i.set_params(
467
- random_state=random_state_checked.randint(np.iinfo(np.int32).max)
664
+ patching_status.and_conditions(
665
+ [
666
+ (
667
+ daal_check_version((2023, "P", 200))
668
+ or self.estimator.__class__ == DecisionTreeClassifier,
669
+ "ExtraTrees only supported starting from oneDAL version 2023.2",
670
+ ),
671
+ (
672
+ not sp.issparse(sample_weight),
673
+ "sample_weight is sparse. " "Sparse input is not supported.",
674
+ ),
675
+ ]
468
676
  )
469
- if sklearn_check_version("1.0"):
470
- est_i.n_features_in_ = self.n_features_in_
471
- else:
472
- est_i.n_features_ = self.n_features_in_
473
- est_i.n_outputs_ = self.n_outputs_
474
- est_i.classes_ = classes_
475
- est_i.n_classes_ = n_classes_
476
- tree_i_state_class = get_tree_state_cls(self._onedal_model, i, n_classes_)
477
- tree_i_state_dict = {
478
- "max_depth": tree_i_state_class.max_depth,
479
- "node_count": tree_i_state_class.node_count,
480
- "nodes": check_tree_nodes(tree_i_state_class.node_ar),
481
- "values": tree_i_state_class.value_ar,
482
- }
483
- est_i.tree_ = Tree(
484
- self.n_features_in_,
485
- np.array([n_classes_], dtype=np.intp),
486
- self.n_outputs_,
487
- )
488
- est_i.tree_.__setstate__(tree_i_state_dict)
489
- estimators_.append(est_i)
490
-
491
- self._cached_estimators_ = estimators_
492
-
493
- def _onedal_cpu_supported(self, method_name, *data):
494
- class_name = self.__class__.__name__
495
- patching_status = PatchingConditionsChain(
496
- f"sklearn.ensemble.{class_name}.{method_name}"
497
- )
498
-
499
- if method_name == "fit":
500
- patching_status, X, y, sample_weight = self._onedal_fit_ready(
501
- patching_status, *data
502
- )
503
-
504
- patching_status.and_conditions(
505
- [
506
- (
507
- daal_check_version((2023, "P", 200))
508
- or self._estimator.__class__ == DecisionTreeClassifier,
509
- "ExtraTrees only supported starting from oneDAL version 2023.2",
510
- ),
511
- (
512
- not sp.issparse(sample_weight),
513
- "sample_weight is sparse. " "Sparse input is not supported.",
514
- ),
515
- ]
516
- )
517
-
518
- if (
519
- patching_status.get_status()
520
- and (self.random_state is not None)
521
- and (not daal_check_version((2024, "P", 0)))
522
- ):
523
- warnings.warn(
524
- "Setting 'random_state' value is not supported. "
525
- "State set by oneDAL to default value (777).",
526
- RuntimeWarning,
527
- )
528
677
 
529
678
  elif method_name in ["predict", "predict_proba"]:
530
679
  X = data[0]
531
680
 
532
681
  patching_status.and_conditions(
533
682
  [
534
- (hasattr(self, "_onedal_model"), "oneDAL model was not trained."),
683
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
535
684
  (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
536
685
  (self.warm_start is False, "Warm start is not supported."),
537
686
  (
538
687
  daal_check_version((2023, "P", 100))
539
- or self._estimator.__class__ == DecisionTreeClassifier,
688
+ or self.estimator.__class__ == DecisionTreeClassifier,
540
689
  "ExtraTrees only supported starting from oneDAL version 2023.2",
541
690
  ),
542
691
  ]
543
692
  )
693
+
694
+ if method_name == "predict_proba":
695
+ patching_status.and_conditions(
696
+ [
697
+ (
698
+ daal_check_version((2021, "P", 400)),
699
+ "oneDAL version is lower than 2021.4.",
700
+ )
701
+ ]
702
+ )
703
+
544
704
  if hasattr(self, "n_outputs_"):
545
705
  patching_status.and_conditions(
546
706
  [
@@ -573,24 +733,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
573
733
  [
574
734
  (
575
735
  daal_check_version((2023, "P", 100))
576
- or self._estimator.__class__ == DecisionTreeClassifier,
736
+ or self.estimator.__class__ == DecisionTreeClassifier,
577
737
  "ExtraTrees only supported starting from oneDAL version 2023.1",
578
738
  ),
579
739
  (sample_weight is not None, "sample_weight is not supported."),
580
740
  ]
581
741
  )
582
742
 
583
- if (
584
- patching_status.get_status()
585
- and (self.random_state is not None)
586
- and (not daal_check_version((2024, "P", 0)))
587
- ):
588
- warnings.warn(
589
- "Setting 'random_state' value is not supported. "
590
- "State set by oneDAL to default value (777).",
591
- RuntimeWarning,
592
- )
593
-
594
743
  elif method_name in ["predict", "predict_proba"]:
595
744
  X = data[0]
596
745
 
@@ -625,113 +774,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
625
774
 
626
775
  return patching_status
627
776
 
628
- def _onedal_fit(self, X, y, sample_weight=None, queue=None):
629
- if sklearn_check_version("1.2"):
630
- X, y = self._validate_data(
631
- X,
632
- y,
633
- multi_output=False,
634
- accept_sparse=False,
635
- dtype=[np.float64, np.float32],
636
- )
637
- else:
638
- X, y = check_X_y(
639
- X,
640
- y,
641
- accept_sparse=False,
642
- dtype=[np.float64, np.float32],
643
- multi_output=False,
644
- )
645
-
646
- if sample_weight is not None:
647
- sample_weight = self.check_sample_weight(sample_weight, X)
648
-
649
- y = np.atleast_1d(y)
650
- if y.ndim == 2 and y.shape[1] == 1:
651
- warnings.warn(
652
- "A column-vector y was passed when a 1d array was"
653
- " expected. Please change the shape of y to "
654
- "(n_samples,), for example using ravel().",
655
- DataConversionWarning,
656
- stacklevel=2,
657
- )
658
- if y.ndim == 1:
659
- # reshape is necessary to preserve the data contiguity against vs
660
- # [:, np.newaxis] that does not.
661
- y = np.reshape(y, (-1, 1))
662
-
663
- y, expanded_class_weight = self._validate_y_class_weight(y)
664
-
665
- n_classes_ = self.n_classes_[0]
666
- self.n_features_in_ = X.shape[1]
667
- if not sklearn_check_version("1.0"):
668
- self.n_features_ = self.n_features_in_
669
-
670
- if expanded_class_weight is not None:
671
- if sample_weight is not None:
672
- sample_weight = sample_weight * expanded_class_weight
673
- else:
674
- sample_weight = expanded_class_weight
675
- if sample_weight is not None:
676
- sample_weight = [sample_weight]
677
-
678
- if n_classes_ < 2:
679
- raise ValueError("Training data only contain information about one class.")
680
-
681
- if self.oob_score:
682
- err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
683
- else:
684
- err = "none"
685
-
686
- onedal_params = {
687
- "n_estimators": self.n_estimators,
688
- "criterion": self.criterion,
689
- "max_depth": self.max_depth,
690
- "min_samples_split": self.min_samples_split,
691
- "min_samples_leaf": self.min_samples_leaf,
692
- "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
693
- "max_features": self.max_features,
694
- "max_leaf_nodes": self.max_leaf_nodes,
695
- "min_impurity_decrease": self.min_impurity_decrease,
696
- "bootstrap": self.bootstrap,
697
- "oob_score": self.oob_score,
698
- "n_jobs": self.n_jobs,
699
- "random_state": self.random_state,
700
- "verbose": self.verbose,
701
- "warm_start": self.warm_start,
702
- "error_metric_mode": err,
703
- "variable_importance_mode": "mdi",
704
- "class_weight": self.class_weight,
705
- "max_bins": self.max_bins,
706
- "min_bin_size": self.min_bin_size,
707
- "max_samples": self.max_samples,
708
- }
709
- if daal_check_version((2023, "P", 101)):
710
- onedal_params["splitter_mode"] = "random"
711
- if not sklearn_check_version("1.0"):
712
- onedal_params["min_impurity_split"] = self.min_impurity_split
713
- else:
714
- onedal_params["min_impurity_split"] = None
715
-
716
- # Lazy evaluation of estimators_
717
- self._cached_estimators_ = None
718
-
719
- # Compute
720
- self._onedal_estimator = self._onedal_classifier(**onedal_params)
721
- self._onedal_estimator.fit(X, np.squeeze(y), sample_weight, queue=queue)
722
-
723
- self._save_attributes()
724
- if sklearn_check_version("1.2"):
725
- self._estimator = ExtraTreeClassifier()
726
-
727
- # Decapsulate classes_ attributes
728
- self.n_classes_ = self.n_classes_[0]
729
- self.classes_ = self.classes_[0]
730
- return self
731
-
732
777
  def _onedal_predict(self, X, queue=None):
733
- X = check_array(X, dtype=[np.float32, np.float64])
734
- check_is_fitted(self, "_onedal_model")
778
+ X = check_array(
779
+ X,
780
+ dtype=[np.float64, np.float32],
781
+ force_all_finite=False,
782
+ ) # Warning, order of dtype matters
783
+ check_is_fitted(self, "_onedal_estimator")
735
784
 
736
785
  if sklearn_check_version("1.0"):
737
786
  self._check_feature_names(X, reset=False)
@@ -740,8 +789,8 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
740
789
  return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
741
790
 
742
791
  def _onedal_predict_proba(self, X, queue=None):
743
- X = check_array(X, dtype=[np.float64, np.float32])
744
- check_is_fitted(self, "_onedal_model")
792
+ X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
793
+ check_is_fitted(self, "_onedal_estimator")
745
794
 
746
795
  if sklearn_check_version("0.23"):
747
796
  self._check_n_features(X, reset=False)
@@ -750,7 +799,10 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
750
799
  return self._onedal_estimator.predict_proba(X, queue=queue)
751
800
 
752
801
 
753
- class ForestRegressor(sklearn_ForestRegressor, BaseTree):
802
+ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
803
+ _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
804
+ _get_tree_state = staticmethod(get_tree_state_reg)
805
+
754
806
  def __init__(
755
807
  self,
756
808
  estimator,
@@ -778,66 +830,21 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
778
830
  max_samples=max_samples,
779
831
  )
780
832
 
781
- # The splitter is recognized here for proper dispatching.
782
- self._estimator = estimator # TODO: Verify if this is done in older verions
783
- if self._estimator.__class__ == DecisionTreeRegressor:
784
- self._onedal_regressor = onedal_RandomForestRegressor
785
- elif self._estimator.__class__ == ExtraTreeRegressor:
786
- self._onedal_regressor = onedal_ExtraTreesRegressor
787
- else:
788
- raise TypeError(
789
- f"{estimator.__class__.__name__} is not a supported tree regressor"
790
- )
791
-
792
- def _estimators_(self):
793
- # _estimators_ should only be called if _onedal_model exists
794
- check_is_fitted(self, "_onedal_model")
795
- # convert model to estimators
796
- params = {
797
- "criterion": self.criterion,
798
- "max_depth": self.max_depth,
799
- "min_samples_split": self.min_samples_split,
800
- "min_samples_leaf": self.min_samples_leaf,
801
- "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
802
- "max_features": self.max_features,
803
- "max_leaf_nodes": self.max_leaf_nodes,
804
- "min_impurity_decrease": self.min_impurity_decrease,
805
- "random_state": None,
806
- }
807
- if not sklearn_check_version("1.0"):
808
- params["min_impurity_split"] = self.min_impurity_split
809
- est = self._estimator.__class__(**params)
810
- # we need to set est.tree_ field with Trees constructed from Intel(R)
811
- # oneAPI Data Analytics Library solution
812
- estimators_ = []
813
- random_state_checked = check_random_state(self.random_state)
814
-
815
- for i in range(self.n_estimators):
816
- est_i = clone(est)
817
- est_i.set_params(
818
- random_state=random_state_checked.randint(np.iinfo(np.int32).max)
819
- )
820
- if sklearn_check_version("1.0"):
821
- est_i.n_features_in_ = self.n_features_in_
822
- else:
823
- est_i.n_features_ = self.n_features_in_
824
- est_i.n_classes_ = 1
825
- est_i.n_outputs_ = self.n_outputs_
826
- tree_i_state_class = get_tree_state_reg(self._onedal_model, i)
827
- tree_i_state_dict = {
828
- "max_depth": tree_i_state_class.max_depth,
829
- "node_count": tree_i_state_class.node_count,
830
- "nodes": check_tree_nodes(tree_i_state_class.node_ar),
831
- "values": tree_i_state_class.value_ar,
832
- }
833
-
834
- est_i.tree_ = Tree(
835
- self.n_features_in_, np.array([1], dtype=np.intp), self.n_outputs_
836
- )
837
- est_i.tree_.__setstate__(tree_i_state_dict)
838
- estimators_.append(est_i)
833
+ # The splitter is checked against the class attribute for conformance
834
+ # This should only trigger if the user uses this class directly.
835
+ if (
836
+ self.estimator.__class__ == DecisionTreeRegressor
837
+ and self._onedal_factory != onedal_RandomForestRegressor
838
+ ):
839
+ self._onedal_factory = onedal_RandomForestRegressor
840
+ elif (
841
+ self.estimator.__class__ == ExtraTreeRegressor
842
+ and self._onedal_factory != onedal_ExtraTreesRegressor
843
+ ):
844
+ self._onedal_factory = onedal_ExtraTreesRegressor
839
845
 
840
- self._cached_estimators_ = estimators_
846
+ if self._onedal_factory is None:
847
+ raise TypeError(f" oneDAL estimator has not been set.")
841
848
 
842
849
  def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
843
850
  if sp.issparse(y):
@@ -885,12 +892,35 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
885
892
  ]
886
893
  )
887
894
 
895
+ if patching_status.get_status() and sklearn_check_version("1.4"):
896
+ try:
897
+ _assert_all_finite(X)
898
+ input_is_finite = True
899
+ except ValueError:
900
+ input_is_finite = False
901
+ patching_status.and_conditions(
902
+ [
903
+ (input_is_finite, "Non-finite input is not supported."),
904
+ (
905
+ self.monotonic_cst is None,
906
+ "Monotonicity constraints are not supported.",
907
+ ),
908
+ ]
909
+ )
910
+
888
911
  if patching_status.get_status():
889
- if sklearn_check_version("1.0"):
890
- self._check_feature_names(X, reset=True)
891
- X = check_array(X, dtype=[np.float64, np.float32])
892
- y = np.asarray(y)
893
- y = np.atleast_1d(y)
912
+ if sklearn_check_version("0.24"):
913
+ X, y = self._validate_data(
914
+ X,
915
+ y,
916
+ multi_output=True,
917
+ accept_sparse=True,
918
+ dtype=[np.float64, np.float32],
919
+ force_all_finite=False,
920
+ )
921
+ else:
922
+ X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
923
+ y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
894
924
 
895
925
  if y.ndim == 2 and y.shape[1] == 1:
896
926
  warnings.warn(
@@ -901,15 +931,13 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
901
931
  stacklevel=2,
902
932
  )
903
933
 
904
- y = check_array(y, ensure_2d=False, dtype=X.dtype)
905
- check_consistent_length(X, y)
906
-
907
934
  if y.ndim == 1:
908
935
  # reshape is necessary to preserve the data contiguity against vs
909
936
  # [:, np.newaxis] that does not.
910
937
  y = np.reshape(y, (-1, 1))
911
938
 
912
939
  self.n_outputs_ = y.shape[1]
940
+
913
941
  patching_status.and_conditions(
914
942
  [
915
943
  (
@@ -919,30 +947,8 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
919
947
  ]
920
948
  )
921
949
 
922
- n_samples = X.shape[0]
923
- if isinstance(self.max_samples, numbers.Integral):
924
- if not sklearn_check_version("1.2"):
925
- if not (1 <= self.max_samples <= n_samples):
926
- msg = "`max_samples` must be in range 1 to {} but got value {}"
927
- raise ValueError(msg.format(n_samples, self.max_samples))
928
- else:
929
- if self.max_samples > n_samples:
930
- msg = "`max_samples` must be <= n_samples={} but got value {}"
931
- raise ValueError(msg.format(n_samples, self.max_samples))
932
- elif isinstance(self.max_samples, numbers.Real):
933
- if sklearn_check_version("1.2"):
934
- pass
935
- elif sklearn_check_version("1.0"):
936
- if not (0 < float(self.max_samples) <= 1):
937
- msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
938
- raise ValueError(msg.format(self.max_samples))
939
- else:
940
- if not (0 < float(self.max_samples) < 1):
941
- msg = "`max_samples` must be in range (0, 1) but got value {}"
942
- raise ValueError(msg.format(self.max_samples))
943
- elif self.max_samples is not None:
944
- msg = "`max_samples` should be int or float, but got type '{}'"
945
- raise TypeError(msg.format(type(self.max_samples)))
950
+ # Sklearn function used for doing checks on max_samples attribute
951
+ _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
946
952
 
947
953
  if not self.bootstrap and self.max_samples is not None:
948
954
  raise ValueError(
@@ -951,6 +957,17 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
951
957
  "`max_sample=None`."
952
958
  )
953
959
 
960
+ if (
961
+ patching_status.get_status()
962
+ and (self.random_state is not None)
963
+ and (not daal_check_version((2024, "P", 0)))
964
+ ):
965
+ warnings.warn(
966
+ "Setting 'random_state' value is not supported. "
967
+ "State set by oneDAL to default value (777).",
968
+ RuntimeWarning,
969
+ )
970
+
954
971
  return patching_status, X, y, sample_weight
955
972
 
956
973
  def _onedal_cpu_supported(self, method_name, *data):
@@ -968,7 +985,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
968
985
  [
969
986
  (
970
987
  daal_check_version((2023, "P", 200))
971
- or self._estimator.__class__ == DecisionTreeClassifier,
988
+ or self.estimator.__class__ == DecisionTreeClassifier,
972
989
  "ExtraTrees only supported starting from oneDAL version 2023.2",
973
990
  ),
974
991
  (
@@ -978,28 +995,17 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
978
995
  ]
979
996
  )
980
997
 
981
- if (
982
- patching_status.get_status()
983
- and (self.random_state is not None)
984
- and (not daal_check_version((2024, "P", 0)))
985
- ):
986
- warnings.warn(
987
- "Setting 'random_state' value is not supported. "
988
- "State set by oneDAL to default value (777).",
989
- RuntimeWarning,
990
- )
991
-
992
- elif method_name in ["predict", "predict_proba"]:
998
+ elif method_name == "predict":
993
999
  X = data[0]
994
1000
 
995
1001
  patching_status.and_conditions(
996
1002
  [
997
- (hasattr(self, "_onedal_model"), "oneDAL model was not trained."),
1003
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
998
1004
  (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
999
1005
  (self.warm_start is False, "Warm start is not supported."),
1000
1006
  (
1001
1007
  daal_check_version((2023, "P", 200))
1002
- or self._estimator.__class__ == DecisionTreeClassifier,
1008
+ or self.estimator.__class__ == DecisionTreeClassifier,
1003
1009
  "ExtraTrees only supported starting from oneDAL version 2023.2",
1004
1010
  ),
1005
1011
  ]
@@ -1013,8 +1019,6 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
1013
1019
  ),
1014
1020
  ]
1015
1021
  )
1016
- else:
1017
- dal_ready = False
1018
1022
 
1019
1023
  else:
1020
1024
  raise RuntimeError(
@@ -1038,35 +1042,24 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
1038
1042
  [
1039
1043
  (
1040
1044
  daal_check_version((2023, "P", 100))
1041
- or self._estimator.__class__ == DecisionTreeClassifier,
1045
+ or self.estimator.__class__ == DecisionTreeClassifier,
1042
1046
  "ExtraTrees only supported starting from oneDAL version 2023.1",
1043
1047
  ),
1044
1048
  (sample_weight is not None, "sample_weight is not supported."),
1045
1049
  ]
1046
1050
  )
1047
1051
 
1048
- if (
1049
- patching_status.get_status()
1050
- and (self.random_state is not None)
1051
- and (not daal_check_version((2024, "P", 0)))
1052
- ):
1053
- warnings.warn(
1054
- "Setting 'random_state' value is not supported. "
1055
- "State set by oneDAL to default value (777).",
1056
- RuntimeWarning,
1057
- )
1058
-
1059
1052
  elif method_name == "predict":
1060
1053
  X = data[0]
1061
1054
 
1062
1055
  patching_status.and_conditions(
1063
1056
  [
1064
- (hasattr(self, "_onedal_model"), "oneDAL model was not trained."),
1057
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
1065
1058
  (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
1066
1059
  (self.warm_start is False, "Warm start is not supported."),
1067
1060
  (
1068
1061
  daal_check_version((2023, "P", 100))
1069
- or self._estimator.__class__ == DecisionTreeClassifier,
1062
+ or self.estimator.__class__ == DecisionTreeClassifier,
1070
1063
  "ExtraTrees only supported starting from oneDAL version 2023.1",
1071
1064
  ),
1072
1065
  ]
@@ -1088,76 +1081,11 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
1088
1081
 
1089
1082
  return patching_status
1090
1083
 
1091
- def _onedal_fit(self, X, y, sample_weight=None, queue=None):
1092
- if sp.issparse(y):
1093
- raise ValueError("sparse multilabel-indicator for y is not supported.")
1094
- if sklearn_check_version("1.2"):
1095
- self._validate_params()
1096
- else:
1097
- self._check_parameters()
1098
- if sample_weight is not None:
1099
- sample_weight = self.check_sample_weight(sample_weight, X)
1100
- if sklearn_check_version("1.0"):
1101
- self._check_feature_names(X, reset=True)
1102
- X = check_array(X, dtype=[np.float64, np.float32])
1103
- y = np.atleast_1d(np.asarray(y))
1104
- if y.ndim == 2 and y.shape[1] == 1:
1105
- warnings.warn(
1106
- "A column-vector y was passed when a 1d array was"
1107
- " expected. Please change the shape of y to "
1108
- "(n_samples,), for example using ravel().",
1109
- DataConversionWarning,
1110
- stacklevel=2,
1111
- )
1112
- y = check_array(y, ensure_2d=False, dtype=X.dtype)
1113
- check_consistent_length(X, y)
1114
- self.n_features_in_ = X.shape[1]
1115
- if not sklearn_check_version("1.0"):
1116
- self.n_features_ = self.n_features_in_
1117
-
1118
- if self.oob_score:
1119
- err = "out_of_bag_error_r2|out_of_bag_error_prediction"
1120
- else:
1121
- err = "none"
1122
-
1123
- onedal_params = {
1124
- "n_estimators": self.n_estimators,
1125
- "criterion": self.criterion,
1126
- "max_depth": self.max_depth,
1127
- "min_samples_split": self.min_samples_split,
1128
- "min_samples_leaf": self.min_samples_leaf,
1129
- "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
1130
- "max_features": self.max_features,
1131
- "max_leaf_nodes": self.max_leaf_nodes,
1132
- "min_impurity_decrease": self.min_impurity_decrease,
1133
- "bootstrap": self.bootstrap,
1134
- "oob_score": self.oob_score,
1135
- "n_jobs": self.n_jobs,
1136
- "random_state": self.random_state,
1137
- "verbose": self.verbose,
1138
- "warm_start": self.warm_start,
1139
- "error_metric_mode": err,
1140
- "variable_importance_mode": "mdi",
1141
- "max_samples": self.max_samples,
1142
- }
1143
- if daal_check_version((2023, "P", 101)):
1144
- onedal_params["splitter_mode"] = "random"
1145
-
1146
- # Lazy evaluation of estimators_
1147
- self._cached_estimators_ = None
1148
-
1149
- self._onedal_estimator = self._onedal_regressor(**onedal_params)
1150
- self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
1151
-
1152
- self._save_attributes()
1153
- if sklearn_check_version("1.2"):
1154
- self._estimator = ExtraTreeRegressor()
1155
-
1156
- return self
1157
-
1158
1084
  def _onedal_predict(self, X, queue=None):
1159
- X = check_array(X, dtype=[np.float32, np.float64])
1160
- check_is_fitted(self, "_onedal_model")
1085
+ X = check_array(
1086
+ X, dtype=[np.float64, np.float32], force_all_finite=False
1087
+ ) # Warning, order of dtype matters
1088
+ check_is_fitted(self, "_onedal_estimator")
1161
1089
 
1162
1090
  if sklearn_check_version("1.0"):
1163
1091
  self._check_feature_names(X, reset=False)
@@ -1193,28 +1121,85 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
1193
1121
  fit.__doc__ = sklearn_ForestRegressor.fit.__doc__
1194
1122
  predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
1195
1123
 
1196
- if sklearn_check_version("1.0"):
1197
-
1198
- @deprecated(
1199
- "Attribute `n_features_` was deprecated in version 1.0 and will be "
1200
- "removed in 1.2. Use `n_features_in_` instead."
1201
- )
1202
- @property
1203
- def n_features_(self):
1204
- return self.n_features_in_
1205
-
1206
1124
 
1207
- class ExtraTreesClassifier(ForestClassifier):
1208
- __doc__ = sklearn_ExtraTreesClassifier.__doc__
1125
+ class RandomForestClassifier(ForestClassifier):
1126
+ __doc__ = sklearn_RandomForestClassifier.__doc__
1127
+ _onedal_factory = onedal_RandomForestClassifier
1209
1128
 
1210
1129
  if sklearn_check_version("1.2"):
1211
1130
  _parameter_constraints: dict = {
1212
- **sklearn_ExtraTreesClassifier._parameter_constraints,
1131
+ **sklearn_RandomForestClassifier._parameter_constraints,
1213
1132
  "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
1214
1133
  "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
1215
1134
  }
1216
1135
 
1217
- if sklearn_check_version("1.0"):
1136
+ if sklearn_check_version("1.4"):
1137
+
1138
+ def __init__(
1139
+ self,
1140
+ n_estimators=100,
1141
+ *,
1142
+ criterion="gini",
1143
+ max_depth=None,
1144
+ min_samples_split=2,
1145
+ min_samples_leaf=1,
1146
+ min_weight_fraction_leaf=0.0,
1147
+ max_features="sqrt",
1148
+ max_leaf_nodes=None,
1149
+ min_impurity_decrease=0.0,
1150
+ bootstrap=True,
1151
+ oob_score=False,
1152
+ n_jobs=None,
1153
+ random_state=None,
1154
+ verbose=0,
1155
+ warm_start=False,
1156
+ class_weight=None,
1157
+ ccp_alpha=0.0,
1158
+ max_samples=None,
1159
+ monotonic_cst=None,
1160
+ max_bins=256,
1161
+ min_bin_size=1,
1162
+ ):
1163
+ super().__init__(
1164
+ DecisionTreeClassifier(),
1165
+ n_estimators,
1166
+ estimator_params=(
1167
+ "criterion",
1168
+ "max_depth",
1169
+ "min_samples_split",
1170
+ "min_samples_leaf",
1171
+ "min_weight_fraction_leaf",
1172
+ "max_features",
1173
+ "max_leaf_nodes",
1174
+ "min_impurity_decrease",
1175
+ "random_state",
1176
+ "ccp_alpha",
1177
+ "monotonic_cst",
1178
+ ),
1179
+ bootstrap=bootstrap,
1180
+ oob_score=oob_score,
1181
+ n_jobs=n_jobs,
1182
+ random_state=random_state,
1183
+ verbose=verbose,
1184
+ warm_start=warm_start,
1185
+ class_weight=class_weight,
1186
+ max_samples=max_samples,
1187
+ )
1188
+
1189
+ self.criterion = criterion
1190
+ self.max_depth = max_depth
1191
+ self.min_samples_split = min_samples_split
1192
+ self.min_samples_leaf = min_samples_leaf
1193
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1194
+ self.max_features = max_features
1195
+ self.max_leaf_nodes = max_leaf_nodes
1196
+ self.min_impurity_decrease = min_impurity_decrease
1197
+ self.ccp_alpha = ccp_alpha
1198
+ self.max_bins = max_bins
1199
+ self.min_bin_size = min_bin_size
1200
+ self.monotonic_cst = monotonic_cst
1201
+
1202
+ elif sklearn_check_version("1.0"):
1218
1203
 
1219
1204
  def __init__(
1220
1205
  self,
@@ -1228,7 +1213,7 @@ class ExtraTreesClassifier(ForestClassifier):
1228
1213
  max_features="sqrt" if sklearn_check_version("1.1") else "auto",
1229
1214
  max_leaf_nodes=None,
1230
1215
  min_impurity_decrease=0.0,
1231
- bootstrap=False,
1216
+ bootstrap=True,
1232
1217
  oob_score=False,
1233
1218
  n_jobs=None,
1234
1219
  random_state=None,
@@ -1241,7 +1226,7 @@ class ExtraTreesClassifier(ForestClassifier):
1241
1226
  min_bin_size=1,
1242
1227
  ):
1243
1228
  super().__init__(
1244
- ExtraTreeClassifier(),
1229
+ DecisionTreeClassifier(),
1245
1230
  n_estimators,
1246
1231
  estimator_params=(
1247
1232
  "criterion",
@@ -1292,7 +1277,7 @@ class ExtraTreesClassifier(ForestClassifier):
1292
1277
  max_leaf_nodes=None,
1293
1278
  min_impurity_decrease=0.0,
1294
1279
  min_impurity_split=None,
1295
- bootstrap=False,
1280
+ bootstrap=True,
1296
1281
  oob_score=False,
1297
1282
  n_jobs=None,
1298
1283
  random_state=None,
@@ -1305,7 +1290,7 @@ class ExtraTreesClassifier(ForestClassifier):
1305
1290
  min_bin_size=1,
1306
1291
  ):
1307
1292
  super().__init__(
1308
- ExtraTreeClassifier(),
1293
+ DecisionTreeClassifier(),
1309
1294
  n_estimators,
1310
1295
  estimator_params=(
1311
1296
  "criterion",
@@ -1346,17 +1331,82 @@ class ExtraTreesClassifier(ForestClassifier):
1346
1331
  self.min_bin_size = min_bin_size
1347
1332
 
1348
1333
 
1349
- class ExtraTreesRegressor(ForestRegressor):
1350
- __doc__ = sklearn_ExtraTreesRegressor.__doc__
1334
+ class RandomForestRegressor(ForestRegressor):
1335
+ __doc__ = sklearn_RandomForestRegressor.__doc__
1336
+ _onedal_factory = onedal_RandomForestRegressor
1351
1337
 
1352
1338
  if sklearn_check_version("1.2"):
1353
1339
  _parameter_constraints: dict = {
1354
- **sklearn_ExtraTreesRegressor._parameter_constraints,
1340
+ **sklearn_RandomForestRegressor._parameter_constraints,
1355
1341
  "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
1356
1342
  "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
1357
1343
  }
1358
1344
 
1359
- if sklearn_check_version("1.0"):
1345
+ if sklearn_check_version("1.4"):
1346
+
1347
+ def __init__(
1348
+ self,
1349
+ n_estimators=100,
1350
+ *,
1351
+ criterion="squared_error",
1352
+ max_depth=None,
1353
+ min_samples_split=2,
1354
+ min_samples_leaf=1,
1355
+ min_weight_fraction_leaf=0.0,
1356
+ max_features=1.0,
1357
+ max_leaf_nodes=None,
1358
+ min_impurity_decrease=0.0,
1359
+ bootstrap=True,
1360
+ oob_score=False,
1361
+ n_jobs=None,
1362
+ random_state=None,
1363
+ verbose=0,
1364
+ warm_start=False,
1365
+ ccp_alpha=0.0,
1366
+ max_samples=None,
1367
+ monotonic_cst=None,
1368
+ max_bins=256,
1369
+ min_bin_size=1,
1370
+ ):
1371
+ super().__init__(
1372
+ DecisionTreeRegressor(),
1373
+ n_estimators=n_estimators,
1374
+ estimator_params=(
1375
+ "criterion",
1376
+ "max_depth",
1377
+ "min_samples_split",
1378
+ "min_samples_leaf",
1379
+ "min_weight_fraction_leaf",
1380
+ "max_features",
1381
+ "max_leaf_nodes",
1382
+ "min_impurity_decrease",
1383
+ "random_state",
1384
+ "ccp_alpha",
1385
+ "monotonic_cst",
1386
+ ),
1387
+ bootstrap=bootstrap,
1388
+ oob_score=oob_score,
1389
+ n_jobs=n_jobs,
1390
+ random_state=random_state,
1391
+ verbose=verbose,
1392
+ warm_start=warm_start,
1393
+ max_samples=max_samples,
1394
+ )
1395
+
1396
+ self.criterion = criterion
1397
+ self.max_depth = max_depth
1398
+ self.min_samples_split = min_samples_split
1399
+ self.min_samples_leaf = min_samples_leaf
1400
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1401
+ self.max_features = max_features
1402
+ self.max_leaf_nodes = max_leaf_nodes
1403
+ self.min_impurity_decrease = min_impurity_decrease
1404
+ self.ccp_alpha = ccp_alpha
1405
+ self.max_bins = max_bins
1406
+ self.min_bin_size = min_bin_size
1407
+ self.monotonic_cst = monotonic_cst
1408
+
1409
+ elif sklearn_check_version("1.0"):
1360
1410
 
1361
1411
  def __init__(
1362
1412
  self,
@@ -1370,7 +1420,7 @@ class ExtraTreesRegressor(ForestRegressor):
1370
1420
  max_features=1.0 if sklearn_check_version("1.1") else "auto",
1371
1421
  max_leaf_nodes=None,
1372
1422
  min_impurity_decrease=0.0,
1373
- bootstrap=False,
1423
+ bootstrap=True,
1374
1424
  oob_score=False,
1375
1425
  n_jobs=None,
1376
1426
  random_state=None,
@@ -1382,7 +1432,7 @@ class ExtraTreesRegressor(ForestRegressor):
1382
1432
  min_bin_size=1,
1383
1433
  ):
1384
1434
  super().__init__(
1385
- estimator=ExtraTreeRegressor(),
1435
+ DecisionTreeRegressor(),
1386
1436
  n_estimators=n_estimators,
1387
1437
  estimator_params=(
1388
1438
  "criterion",
@@ -1432,7 +1482,7 @@ class ExtraTreesRegressor(ForestRegressor):
1432
1482
  max_leaf_nodes=None,
1433
1483
  min_impurity_decrease=0.0,
1434
1484
  min_impurity_split=None,
1435
- bootstrap=False,
1485
+ bootstrap=True,
1436
1486
  oob_score=False,
1437
1487
  n_jobs=None,
1438
1488
  random_state=None,
@@ -1444,7 +1494,7 @@ class ExtraTreesRegressor(ForestRegressor):
1444
1494
  min_bin_size=1,
1445
1495
  ):
1446
1496
  super().__init__(
1447
- estimator=ExtraTreeRegressor(),
1497
+ DecisionTreeRegressor(),
1448
1498
  n_estimators=n_estimators,
1449
1499
  estimator_params=(
1450
1500
  "criterion",
@@ -1479,3 +1529,419 @@ class ExtraTreesRegressor(ForestRegressor):
1479
1529
  self.ccp_alpha = ccp_alpha
1480
1530
  self.max_bins = max_bins
1481
1531
  self.min_bin_size = min_bin_size
1532
+
1533
+
1534
class ExtraTreesClassifier(ForestClassifier):
    # Reuse scikit-learn's documentation verbatim so help() matches upstream.
    __doc__ = sklearn_ExtraTreesClassifier.__doc__
    # oneDAL estimator class backing this wrapper.
    # NOTE(review): presumably consumed by ForestClassifier, which is outside
    # this view; confirm.
    _onedal_factory = onedal_ExtraTreesClassifier

    if sklearn_check_version("1.2"):
        # scikit-learn's declarative parameter validation, extended with the
        # two oneDAL histogram parameters (max_bins >= 2, min_bin_size >= 1).
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # One __init__ per scikit-learn API generation. Each signature mirrors the
    # upstream ExtraTreesClassifier of that generation plus ``max_bins`` and
    # ``min_bin_size``. Every argument is stored unmodified, as scikit-learn's
    # get_params/set_params require. ``estimator_params`` lists the attributes
    # copied from ``self`` onto each ExtraTreeClassifier; the histogram
    # settings are not among them. The template estimator is passed
    # positionally because its keyword name differs across releases.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # scikit-learn 1.1 replaced the "auto" default with "sqrt".
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:
        # Oldest supported API: "auto" max_features, plus the
        # ``min_impurity_split`` parameter that only this branch accepts.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            # Assigned once (the previous version repeated these two lines).
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1741
+
1742
+
1743
class ExtraTreesRegressor(ForestRegressor):
    # Reuse scikit-learn's documentation verbatim so help() matches upstream.
    __doc__ = sklearn_ExtraTreesRegressor.__doc__
    # oneDAL estimator class backing this wrapper.
    # NOTE(review): presumably consumed by ForestRegressor, which is outside
    # this view; confirm.
    _onedal_factory = onedal_ExtraTreesRegressor

    if sklearn_check_version("1.2"):
        # scikit-learn's declarative parameter validation, extended with the
        # two oneDAL histogram parameters (max_bins >= 2, min_bin_size >= 1).
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # One __init__ per scikit-learn API generation. Each signature mirrors the
    # upstream ExtraTreesRegressor of that generation plus ``max_bins`` and
    # ``min_bin_size``. Every argument is stored unmodified, as scikit-learn's
    # get_params/set_params require. ``estimator_params`` lists the attributes
    # copied from ``self`` onto each ExtraTreeRegressor, so every entry must
    # name a real attribute. The template estimator is passed positionally
    # because its keyword name differs across releases.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # scikit-learn 1.1 replaced the "auto" default with 1.0.
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:
        # Oldest supported API: "mse" criterion, "auto" max_features, plus the
        # ``min_impurity_split`` parameter that only this branch accepts.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # Keep these as separate tuple elements: without the comma,
                    # Python concatenates adjacent string literals into the
                    # bogus name "min_impurity_splitrandom_state".
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
1941
+
1942
+
1943
# Register each sklearnex forest as a virtual subclass of its scikit-learn
# counterpart via ABCMeta.register. isinstance/issubclass checks against the
# sklearn classes then succeed without changing the sklearnex class hierarchy.
for _sklearn_base, _sklearnex_impl in (
    (sklearn_RandomForestClassifier, RandomForestClassifier),
    (sklearn_RandomForestRegressor, RandomForestRegressor),
    (sklearn_ExtraTreesClassifier, ExtraTreesClassifier),
    (sklearn_ExtraTreesRegressor, ExtraTreesRegressor),
):
    _sklearn_base.register(_sklearnex_impl)
# Drop the loop variables so no extra names leak into the module namespace.
del _sklearn_base, _sklearnex_impl