scikit-learn-intelex 2024.0.0__py38-none-win_amd64.whl → 2024.0.1__py38-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_utils.py +2 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +70 -77
- {scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +6 -2
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +960 -494
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +18 -15
- {scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +59 -12
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +15 -4
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +3 -1
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +2 -6
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -14
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +8 -5
- {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
- scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -54
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -17
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1557
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +0 -20
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
- scikit_learn_intelex-2024.0.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py +0 -47
- scikit_learn_intelex-2024.0.0.dist-info/RECORD +0 -98
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
- {scikit_learn_intelex-2024.0.0.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
- {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
- {scikit_learn_intelex-2024.0.0.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
|
@@ -24,9 +24,16 @@ from scipy import sparse as sp
|
|
|
24
24
|
from sklearn.base import clone
|
|
25
25
|
from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifier
|
|
26
26
|
from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
|
|
27
|
+
from sklearn.ensemble import RandomForestClassifier as sklearn_RandomForestClassifier
|
|
28
|
+
from sklearn.ensemble import RandomForestRegressor as sklearn_RandomForestRegressor
|
|
29
|
+
from sklearn.ensemble._forest import _get_n_samples_bootstrap
|
|
27
30
|
from sklearn.exceptions import DataConversionWarning
|
|
28
|
-
from sklearn.tree import
|
|
29
|
-
|
|
31
|
+
from sklearn.tree import (
|
|
32
|
+
DecisionTreeClassifier,
|
|
33
|
+
DecisionTreeRegressor,
|
|
34
|
+
ExtraTreeClassifier,
|
|
35
|
+
ExtraTreeRegressor,
|
|
36
|
+
)
|
|
30
37
|
from sklearn.tree._tree import Tree
|
|
31
38
|
from sklearn.utils import check_random_state, deprecated
|
|
32
39
|
from sklearn.utils.validation import (
|
|
@@ -53,18 +60,117 @@ try:
|
|
|
53
60
|
except ModuleNotFoundError:
|
|
54
61
|
from sklearn.ensemble.forest import ForestClassifier as sklearn_ForestClassifier
|
|
55
62
|
from sklearn.ensemble.forest import ForestRegressor as sklearn_ForestRegressor
|
|
63
|
+
|
|
56
64
|
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
57
65
|
from onedal.utils import _num_features, _num_samples
|
|
58
66
|
|
|
59
|
-
from
|
|
60
|
-
from
|
|
61
|
-
from
|
|
67
|
+
from .._config import get_config
|
|
68
|
+
from .._device_offload import dispatch, wrap_output_data
|
|
69
|
+
from .._utils import PatchingConditionsChain
|
|
62
70
|
|
|
63
71
|
if sklearn_check_version("1.2"):
|
|
64
72
|
from sklearn.utils._param_validation import Interval
|
|
73
|
+
if sklearn_check_version("1.4"):
|
|
74
|
+
from daal4py.sklearn.utils import _assert_all_finite
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class BaseForest(ABC):
|
|
78
|
+
_onedal_factory = None
|
|
79
|
+
|
|
80
|
+
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
81
|
+
if sklearn_check_version("0.24"):
|
|
82
|
+
X, y = self._validate_data(
|
|
83
|
+
X,
|
|
84
|
+
y,
|
|
85
|
+
multi_output=False,
|
|
86
|
+
accept_sparse=False,
|
|
87
|
+
dtype=[np.float64, np.float32],
|
|
88
|
+
force_all_finite=False,
|
|
89
|
+
)
|
|
90
|
+
else:
|
|
91
|
+
X, y = check_X_y(
|
|
92
|
+
X,
|
|
93
|
+
y,
|
|
94
|
+
accept_sparse=False,
|
|
95
|
+
dtype=[np.float64, np.float32],
|
|
96
|
+
multi_output=False,
|
|
97
|
+
force_all_finite=False,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
if sample_weight is not None:
|
|
101
|
+
sample_weight = self.check_sample_weight(sample_weight, X)
|
|
102
|
+
|
|
103
|
+
if y.ndim == 2 and y.shape[1] == 1:
|
|
104
|
+
warnings.warn(
|
|
105
|
+
"A column-vector y was passed when a 1d array was"
|
|
106
|
+
" expected. Please change the shape of y to "
|
|
107
|
+
"(n_samples,), for example using ravel().",
|
|
108
|
+
DataConversionWarning,
|
|
109
|
+
stacklevel=2,
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
if y.ndim == 1:
|
|
113
|
+
# reshape is necessary to preserve the data contiguity against vs
|
|
114
|
+
# [:, np.newaxis] that does not.
|
|
115
|
+
y = np.reshape(y, (-1, 1))
|
|
116
|
+
|
|
117
|
+
y, expanded_class_weight = self._validate_y_class_weight(y)
|
|
118
|
+
|
|
119
|
+
self.n_features_in_ = X.shape[1]
|
|
120
|
+
|
|
121
|
+
if expanded_class_weight is not None:
|
|
122
|
+
if sample_weight is not None:
|
|
123
|
+
sample_weight = sample_weight * expanded_class_weight
|
|
124
|
+
else:
|
|
125
|
+
sample_weight = expanded_class_weight
|
|
126
|
+
if sample_weight is not None:
|
|
127
|
+
sample_weight = [sample_weight]
|
|
128
|
+
|
|
129
|
+
onedal_params = {
|
|
130
|
+
"n_estimators": self.n_estimators,
|
|
131
|
+
"criterion": self.criterion,
|
|
132
|
+
"max_depth": self.max_depth,
|
|
133
|
+
"min_samples_split": self.min_samples_split,
|
|
134
|
+
"min_samples_leaf": self.min_samples_leaf,
|
|
135
|
+
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
136
|
+
"max_features": self.max_features,
|
|
137
|
+
"max_leaf_nodes": self.max_leaf_nodes,
|
|
138
|
+
"min_impurity_decrease": self.min_impurity_decrease,
|
|
139
|
+
"bootstrap": self.bootstrap,
|
|
140
|
+
"oob_score": self.oob_score,
|
|
141
|
+
"n_jobs": self.n_jobs,
|
|
142
|
+
"random_state": self.random_state,
|
|
143
|
+
"verbose": self.verbose,
|
|
144
|
+
"warm_start": self.warm_start,
|
|
145
|
+
"error_metric_mode": self._err if self.oob_score else "none",
|
|
146
|
+
"variable_importance_mode": "mdi",
|
|
147
|
+
"class_weight": self.class_weight,
|
|
148
|
+
"max_bins": self.max_bins,
|
|
149
|
+
"min_bin_size": self.min_bin_size,
|
|
150
|
+
"max_samples": self.max_samples,
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
if not sklearn_check_version("1.0"):
|
|
154
|
+
onedal_params["min_impurity_split"] = self.min_impurity_split
|
|
155
|
+
else:
|
|
156
|
+
onedal_params["min_impurity_split"] = None
|
|
157
|
+
|
|
158
|
+
# Lazy evaluation of estimators_
|
|
159
|
+
self._cached_estimators_ = None
|
|
160
|
+
|
|
161
|
+
# Compute
|
|
162
|
+
self._onedal_estimator = self._onedal_factory(**onedal_params)
|
|
163
|
+
self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)
|
|
164
|
+
|
|
165
|
+
self._save_attributes()
|
|
166
|
+
|
|
167
|
+
# Decapsulate classes_ attributes
|
|
168
|
+
if hasattr(self, "classes_") and self.n_outputs_ == 1:
|
|
169
|
+
self.n_classes_ = self.n_classes_[0]
|
|
170
|
+
self.classes_ = self.classes_[0]
|
|
65
171
|
|
|
172
|
+
return self
|
|
66
173
|
|
|
67
|
-
class BaseTree(ABC):
|
|
68
174
|
def _fit_proba(self, X, y, sample_weight=None, queue=None):
|
|
69
175
|
params = self.get_params()
|
|
70
176
|
self.__class__(**params)
|
|
@@ -75,8 +181,6 @@ class BaseTree(ABC):
|
|
|
75
181
|
cfg["target_offload"] = queue
|
|
76
182
|
|
|
77
183
|
def _save_attributes(self):
|
|
78
|
-
self._onedal_model = self._onedal_estimator._onedal_model
|
|
79
|
-
|
|
80
184
|
if self.oob_score:
|
|
81
185
|
self.oob_score_ = self._onedal_estimator.oob_score_
|
|
82
186
|
if hasattr(self._onedal_estimator, "oob_prediction_"):
|
|
@@ -85,6 +189,8 @@ class BaseTree(ABC):
|
|
|
85
189
|
self.oob_decision_function_ = (
|
|
86
190
|
self._onedal_estimator.oob_decision_function_
|
|
87
191
|
)
|
|
192
|
+
|
|
193
|
+
self._validate_estimator()
|
|
88
194
|
return self
|
|
89
195
|
|
|
90
196
|
# TODO:
|
|
@@ -183,6 +289,7 @@ class BaseTree(ABC):
|
|
|
183
289
|
ensure_2d=False,
|
|
184
290
|
dtype=dtype,
|
|
185
291
|
order="C",
|
|
292
|
+
force_all_finite=False,
|
|
186
293
|
)
|
|
187
294
|
if sample_weight.ndim != 1:
|
|
188
295
|
raise ValueError("Sample weights must be 1D array or scalar")
|
|
@@ -198,7 +305,7 @@ class BaseTree(ABC):
|
|
|
198
305
|
@property
|
|
199
306
|
def estimators_(self):
|
|
200
307
|
if hasattr(self, "_cached_estimators_"):
|
|
201
|
-
if self._cached_estimators_ is None
|
|
308
|
+
if self._cached_estimators_ is None:
|
|
202
309
|
self._estimators_()
|
|
203
310
|
return self._cached_estimators_
|
|
204
311
|
else:
|
|
@@ -211,13 +318,99 @@ class BaseTree(ABC):
|
|
|
211
318
|
# Needed to allow for proper sklearn operation in fallback mode
|
|
212
319
|
self._cached_estimators_ = estimators
|
|
213
320
|
|
|
321
|
+
def _estimators_(self):
|
|
322
|
+
# _estimators_ should only be called if _onedal_estimator exists
|
|
323
|
+
check_is_fitted(self, "_onedal_estimator")
|
|
324
|
+
if hasattr(self, "n_classes_"):
|
|
325
|
+
n_classes_ = (
|
|
326
|
+
self.n_classes_
|
|
327
|
+
if isinstance(self.n_classes_, int)
|
|
328
|
+
else self.n_classes_[0]
|
|
329
|
+
)
|
|
330
|
+
else:
|
|
331
|
+
n_classes_ = 1
|
|
332
|
+
|
|
333
|
+
# convert model to estimators
|
|
334
|
+
params = {
|
|
335
|
+
"criterion": self._onedal_estimator.criterion,
|
|
336
|
+
"max_depth": self._onedal_estimator.max_depth,
|
|
337
|
+
"min_samples_split": self._onedal_estimator.min_samples_split,
|
|
338
|
+
"min_samples_leaf": self._onedal_estimator.min_samples_leaf,
|
|
339
|
+
"min_weight_fraction_leaf": self._onedal_estimator.min_weight_fraction_leaf,
|
|
340
|
+
"max_features": self._onedal_estimator.max_features,
|
|
341
|
+
"max_leaf_nodes": self._onedal_estimator.max_leaf_nodes,
|
|
342
|
+
"min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
|
|
343
|
+
"random_state": None,
|
|
344
|
+
}
|
|
345
|
+
if not sklearn_check_version("1.0"):
|
|
346
|
+
params["min_impurity_split"] = self._onedal_estimator.min_impurity_split
|
|
347
|
+
est = self.estimator.__class__(**params)
|
|
348
|
+
# we need to set est.tree_ field with Trees constructed from Intel(R)
|
|
349
|
+
# oneAPI Data Analytics Library solution
|
|
350
|
+
estimators_ = []
|
|
351
|
+
|
|
352
|
+
random_state_checked = check_random_state(self.random_state)
|
|
353
|
+
|
|
354
|
+
for i in range(self._onedal_estimator.n_estimators):
|
|
355
|
+
est_i = clone(est)
|
|
356
|
+
est_i.set_params(
|
|
357
|
+
random_state=random_state_checked.randint(np.iinfo(np.int32).max)
|
|
358
|
+
)
|
|
359
|
+
if sklearn_check_version("1.0"):
|
|
360
|
+
est_i.n_features_in_ = self.n_features_in_
|
|
361
|
+
else:
|
|
362
|
+
est_i.n_features_ = self.n_features_in_
|
|
363
|
+
est_i.n_outputs_ = self.n_outputs_
|
|
364
|
+
est_i.n_classes_ = n_classes_
|
|
365
|
+
tree_i_state_class = self._get_tree_state(
|
|
366
|
+
self._onedal_estimator._onedal_model, i, n_classes_
|
|
367
|
+
)
|
|
368
|
+
tree_i_state_dict = {
|
|
369
|
+
"max_depth": tree_i_state_class.max_depth,
|
|
370
|
+
"node_count": tree_i_state_class.node_count,
|
|
371
|
+
"nodes": check_tree_nodes(tree_i_state_class.node_ar),
|
|
372
|
+
"values": tree_i_state_class.value_ar,
|
|
373
|
+
}
|
|
374
|
+
est_i.tree_ = Tree(
|
|
375
|
+
self.n_features_in_,
|
|
376
|
+
np.array([n_classes_], dtype=np.intp),
|
|
377
|
+
self.n_outputs_,
|
|
378
|
+
)
|
|
379
|
+
est_i.tree_.__setstate__(tree_i_state_dict)
|
|
380
|
+
estimators_.append(est_i)
|
|
381
|
+
|
|
382
|
+
self._cached_estimators_ = estimators_
|
|
383
|
+
|
|
384
|
+
if sklearn_check_version("1.0"):
|
|
385
|
+
|
|
386
|
+
@deprecated(
|
|
387
|
+
"Attribute `n_features_` was deprecated in version 1.0 and will be "
|
|
388
|
+
"removed in 1.2. Use `n_features_in_` instead."
|
|
389
|
+
)
|
|
390
|
+
@property
|
|
391
|
+
def n_features_(self):
|
|
392
|
+
return self.n_features_in_
|
|
393
|
+
|
|
394
|
+
if not sklearn_check_version("1.2"):
|
|
395
|
+
|
|
396
|
+
@property
|
|
397
|
+
def base_estimator(self):
|
|
398
|
+
return self.estimator
|
|
399
|
+
|
|
400
|
+
@base_estimator.setter
|
|
401
|
+
def base_estimator(self, estimator):
|
|
402
|
+
self.estimator = estimator
|
|
403
|
+
|
|
214
404
|
|
|
215
|
-
class ForestClassifier(sklearn_ForestClassifier,
|
|
405
|
+
class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
216
406
|
# Surprisingly, even though scikit-learn warns against using
|
|
217
407
|
# their ForestClassifier directly, it actually has a more stable
|
|
218
408
|
# API than the user-facing objects (over time). If they change it
|
|
219
409
|
# significantly at some point then this may need to be versioned.
|
|
220
410
|
|
|
411
|
+
_err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
|
|
412
|
+
_get_tree_state = staticmethod(get_tree_state_cls)
|
|
413
|
+
|
|
221
414
|
def __init__(
|
|
222
415
|
self,
|
|
223
416
|
estimator,
|
|
@@ -247,16 +440,27 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
247
440
|
max_samples=max_samples,
|
|
248
441
|
)
|
|
249
442
|
|
|
250
|
-
# The
|
|
251
|
-
|
|
252
|
-
if
|
|
253
|
-
self.
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
443
|
+
# The estimator is checked against the class attribute for conformance.
|
|
444
|
+
# This should only trigger if the user uses this class directly.
|
|
445
|
+
if (
|
|
446
|
+
self.estimator.__class__ == DecisionTreeClassifier
|
|
447
|
+
and self._onedal_factory != onedal_RandomForestClassifier
|
|
448
|
+
):
|
|
449
|
+
self._onedal_factory = onedal_RandomForestClassifier
|
|
450
|
+
elif (
|
|
451
|
+
self.estimator.__class__ == ExtraTreeClassifier
|
|
452
|
+
and self._onedal_factory != onedal_ExtraTreesClassifier
|
|
453
|
+
):
|
|
454
|
+
self._onedal_factory = onedal_ExtraTreesClassifier
|
|
455
|
+
|
|
456
|
+
if self._onedal_factory is None:
|
|
457
|
+
raise TypeError(f" oneDAL estimator has not been set.")
|
|
458
|
+
|
|
459
|
+
def _estimators_(self):
|
|
460
|
+
super()._estimators_()
|
|
461
|
+
classes_ = self.classes_[0]
|
|
462
|
+
for est in self._cached_estimators_:
|
|
463
|
+
est.classes_ = classes_
|
|
260
464
|
|
|
261
465
|
def fit(self, X, y, sample_weight=None):
|
|
262
466
|
dispatch(
|
|
@@ -292,17 +496,17 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
292
496
|
or not self.oob_score,
|
|
293
497
|
"OOB score is only supported starting from 2021.5 version of oneDAL.",
|
|
294
498
|
),
|
|
295
|
-
(
|
|
296
|
-
(
|
|
297
|
-
self.ccp_alpha == 0.0,
|
|
298
|
-
f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
|
|
299
|
-
),
|
|
499
|
+
(self.warm_start is False, "Warm start is not supported."),
|
|
300
500
|
(
|
|
301
501
|
self.criterion == "gini",
|
|
302
502
|
f"'{self.criterion}' criterion is not supported. "
|
|
303
503
|
"Only 'gini' criterion is supported.",
|
|
304
504
|
),
|
|
305
|
-
(
|
|
505
|
+
(
|
|
506
|
+
self.ccp_alpha == 0.0,
|
|
507
|
+
f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
|
|
508
|
+
),
|
|
509
|
+
(not sp.issparse(X), "X is sparse. Sparse input is not supported."),
|
|
306
510
|
(
|
|
307
511
|
self.n_estimators <= 6024,
|
|
308
512
|
"More than 6024 estimators is not supported.",
|
|
@@ -310,12 +514,46 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
310
514
|
]
|
|
311
515
|
)
|
|
312
516
|
|
|
517
|
+
if self.bootstrap:
|
|
518
|
+
patching_status.and_conditions(
|
|
519
|
+
[
|
|
520
|
+
(
|
|
521
|
+
self.class_weight != "balanced_subsample",
|
|
522
|
+
"'balanced_subsample' for class_weight is not supported",
|
|
523
|
+
)
|
|
524
|
+
]
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
if patching_status.get_status() and sklearn_check_version("1.4"):
|
|
528
|
+
try:
|
|
529
|
+
_assert_all_finite(X)
|
|
530
|
+
input_is_finite = True
|
|
531
|
+
except ValueError:
|
|
532
|
+
input_is_finite = False
|
|
533
|
+
patching_status.and_conditions(
|
|
534
|
+
[
|
|
535
|
+
(input_is_finite, "Non-finite input is not supported."),
|
|
536
|
+
(
|
|
537
|
+
self.monotonic_cst is None,
|
|
538
|
+
"Monotonicity constraints are not supported.",
|
|
539
|
+
),
|
|
540
|
+
]
|
|
541
|
+
)
|
|
542
|
+
|
|
313
543
|
if patching_status.get_status():
|
|
314
|
-
if sklearn_check_version("
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
544
|
+
if sklearn_check_version("0.24"):
|
|
545
|
+
X, y = self._validate_data(
|
|
546
|
+
X,
|
|
547
|
+
y,
|
|
548
|
+
multi_output=True,
|
|
549
|
+
accept_sparse=True,
|
|
550
|
+
dtype=[np.float64, np.float32],
|
|
551
|
+
force_all_finite=False,
|
|
552
|
+
)
|
|
553
|
+
else:
|
|
554
|
+
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
555
|
+
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
556
|
+
|
|
319
557
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
320
558
|
warnings.warn(
|
|
321
559
|
"A column-vector y was passed when a 1d array was"
|
|
@@ -324,11 +562,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
324
562
|
DataConversionWarning,
|
|
325
563
|
stacklevel=2,
|
|
326
564
|
)
|
|
327
|
-
check_consistent_length(X, y)
|
|
328
565
|
|
|
329
566
|
if y.ndim == 1:
|
|
330
567
|
y = np.reshape(y, (-1, 1))
|
|
568
|
+
|
|
331
569
|
self.n_outputs_ = y.shape[1]
|
|
570
|
+
|
|
332
571
|
patching_status.and_conditions(
|
|
333
572
|
[
|
|
334
573
|
(
|
|
@@ -343,30 +582,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
343
582
|
)
|
|
344
583
|
# TODO: Fix to support integers as input
|
|
345
584
|
|
|
346
|
-
n_samples
|
|
347
|
-
if isinstance(self.max_samples, numbers.Integral):
|
|
348
|
-
if not sklearn_check_version("1.2"):
|
|
349
|
-
if not (1 <= self.max_samples <= n_samples):
|
|
350
|
-
msg = "`max_samples` must be in range 1 to {} but got value {}"
|
|
351
|
-
raise ValueError(msg.format(n_samples, self.max_samples))
|
|
352
|
-
else:
|
|
353
|
-
if self.max_samples > n_samples:
|
|
354
|
-
msg = "`max_samples` must be <= n_samples={} but got value {}"
|
|
355
|
-
raise ValueError(msg.format(n_samples, self.max_samples))
|
|
356
|
-
elif isinstance(self.max_samples, numbers.Real):
|
|
357
|
-
if sklearn_check_version("1.2"):
|
|
358
|
-
pass
|
|
359
|
-
elif sklearn_check_version("1.0"):
|
|
360
|
-
if not (0 < float(self.max_samples) <= 1):
|
|
361
|
-
msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
|
|
362
|
-
raise ValueError(msg.format(self.max_samples))
|
|
363
|
-
else:
|
|
364
|
-
if not (0 < float(self.max_samples) < 1):
|
|
365
|
-
msg = "`max_samples` must be in range (0, 1) but got value {}"
|
|
366
|
-
raise ValueError(msg.format(self.max_samples))
|
|
367
|
-
elif self.max_samples is not None:
|
|
368
|
-
msg = "`max_samples` should be int or float, but got type '{}'"
|
|
369
|
-
raise TypeError(msg.format(type(self.max_samples)))
|
|
585
|
+
_get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
|
|
370
586
|
|
|
371
587
|
if not self.bootstrap and self.max_samples is not None:
|
|
372
588
|
raise ValueError(
|
|
@@ -375,6 +591,17 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
375
591
|
"`max_sample=None`."
|
|
376
592
|
)
|
|
377
593
|
|
|
594
|
+
if (
|
|
595
|
+
patching_status.get_status()
|
|
596
|
+
and (self.random_state is not None)
|
|
597
|
+
and (not daal_check_version((2024, "P", 0)))
|
|
598
|
+
):
|
|
599
|
+
warnings.warn(
|
|
600
|
+
"Setting 'random_state' value is not supported. "
|
|
601
|
+
"State set by oneDAL to default value (777).",
|
|
602
|
+
RuntimeWarning,
|
|
603
|
+
)
|
|
604
|
+
|
|
378
605
|
return patching_status, X, y, sample_weight
|
|
379
606
|
|
|
380
607
|
@wrap_output_data
|
|
@@ -423,124 +650,57 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
423
650
|
predict.__doc__ = sklearn_ForestClassifier.predict.__doc__
|
|
424
651
|
predict_proba.__doc__ = sklearn_ForestClassifier.predict_proba.__doc__
|
|
425
652
|
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
"
|
|
430
|
-
"removed in 1.2. Use `n_features_in_` instead."
|
|
431
|
-
)
|
|
432
|
-
@property
|
|
433
|
-
def n_features_(self):
|
|
434
|
-
return self.n_features_in_
|
|
435
|
-
|
|
436
|
-
def _estimators_(self):
|
|
437
|
-
# _estimators_ should only be called if _onedal_model exists
|
|
438
|
-
check_is_fitted(self, "_onedal_model")
|
|
439
|
-
classes_ = self.classes_[0]
|
|
440
|
-
n_classes_ = (
|
|
441
|
-
self.n_classes_ if isinstance(self.n_classes_, int) else self.n_classes_[0]
|
|
653
|
+
def _onedal_cpu_supported(self, method_name, *data):
|
|
654
|
+
class_name = self.__class__.__name__
|
|
655
|
+
patching_status = PatchingConditionsChain(
|
|
656
|
+
f"sklearn.ensemble.{class_name}.{method_name}"
|
|
442
657
|
)
|
|
443
|
-
# convert model to estimators
|
|
444
|
-
params = {
|
|
445
|
-
"criterion": self.criterion,
|
|
446
|
-
"max_depth": self.max_depth,
|
|
447
|
-
"min_samples_split": self.min_samples_split,
|
|
448
|
-
"min_samples_leaf": self.min_samples_leaf,
|
|
449
|
-
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
450
|
-
"max_features": self.max_features,
|
|
451
|
-
"max_leaf_nodes": self.max_leaf_nodes,
|
|
452
|
-
"min_impurity_decrease": self.min_impurity_decrease,
|
|
453
|
-
"random_state": None,
|
|
454
|
-
}
|
|
455
|
-
if not sklearn_check_version("1.0"):
|
|
456
|
-
params["min_impurity_split"] = self.min_impurity_split
|
|
457
|
-
est = self._estimator.__class__(**params)
|
|
458
|
-
# we need to set est.tree_ field with Trees constructed from Intel(R)
|
|
459
|
-
# oneAPI Data Analytics Library solution
|
|
460
|
-
estimators_ = []
|
|
461
658
|
|
|
462
|
-
|
|
659
|
+
if method_name == "fit":
|
|
660
|
+
patching_status, X, y, sample_weight = self._onedal_fit_ready(
|
|
661
|
+
patching_status, *data
|
|
662
|
+
)
|
|
463
663
|
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
664
|
+
patching_status.and_conditions(
|
|
665
|
+
[
|
|
666
|
+
(
|
|
667
|
+
daal_check_version((2023, "P", 200))
|
|
668
|
+
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
669
|
+
"ExtraTrees only supported starting from oneDAL version 2023.2",
|
|
670
|
+
),
|
|
671
|
+
(
|
|
672
|
+
not sp.issparse(sample_weight),
|
|
673
|
+
"sample_weight is sparse. " "Sparse input is not supported.",
|
|
674
|
+
),
|
|
675
|
+
]
|
|
468
676
|
)
|
|
469
|
-
if sklearn_check_version("1.0"):
|
|
470
|
-
est_i.n_features_in_ = self.n_features_in_
|
|
471
|
-
else:
|
|
472
|
-
est_i.n_features_ = self.n_features_in_
|
|
473
|
-
est_i.n_outputs_ = self.n_outputs_
|
|
474
|
-
est_i.classes_ = classes_
|
|
475
|
-
est_i.n_classes_ = n_classes_
|
|
476
|
-
tree_i_state_class = get_tree_state_cls(self._onedal_model, i, n_classes_)
|
|
477
|
-
tree_i_state_dict = {
|
|
478
|
-
"max_depth": tree_i_state_class.max_depth,
|
|
479
|
-
"node_count": tree_i_state_class.node_count,
|
|
480
|
-
"nodes": check_tree_nodes(tree_i_state_class.node_ar),
|
|
481
|
-
"values": tree_i_state_class.value_ar,
|
|
482
|
-
}
|
|
483
|
-
est_i.tree_ = Tree(
|
|
484
|
-
self.n_features_in_,
|
|
485
|
-
np.array([n_classes_], dtype=np.intp),
|
|
486
|
-
self.n_outputs_,
|
|
487
|
-
)
|
|
488
|
-
est_i.tree_.__setstate__(tree_i_state_dict)
|
|
489
|
-
estimators_.append(est_i)
|
|
490
|
-
|
|
491
|
-
self._cached_estimators_ = estimators_
|
|
492
|
-
|
|
493
|
-
def _onedal_cpu_supported(self, method_name, *data):
|
|
494
|
-
class_name = self.__class__.__name__
|
|
495
|
-
patching_status = PatchingConditionsChain(
|
|
496
|
-
f"sklearn.ensemble.{class_name}.{method_name}"
|
|
497
|
-
)
|
|
498
|
-
|
|
499
|
-
if method_name == "fit":
|
|
500
|
-
patching_status, X, y, sample_weight = self._onedal_fit_ready(
|
|
501
|
-
patching_status, *data
|
|
502
|
-
)
|
|
503
|
-
|
|
504
|
-
patching_status.and_conditions(
|
|
505
|
-
[
|
|
506
|
-
(
|
|
507
|
-
daal_check_version((2023, "P", 200))
|
|
508
|
-
or self._estimator.__class__ == DecisionTreeClassifier,
|
|
509
|
-
"ExtraTrees only supported starting from oneDAL version 2023.2",
|
|
510
|
-
),
|
|
511
|
-
(
|
|
512
|
-
not sp.issparse(sample_weight),
|
|
513
|
-
"sample_weight is sparse. " "Sparse input is not supported.",
|
|
514
|
-
),
|
|
515
|
-
]
|
|
516
|
-
)
|
|
517
|
-
|
|
518
|
-
if (
|
|
519
|
-
patching_status.get_status()
|
|
520
|
-
and (self.random_state is not None)
|
|
521
|
-
and (not daal_check_version((2024, "P", 0)))
|
|
522
|
-
):
|
|
523
|
-
warnings.warn(
|
|
524
|
-
"Setting 'random_state' value is not supported. "
|
|
525
|
-
"State set by oneDAL to default value (777).",
|
|
526
|
-
RuntimeWarning,
|
|
527
|
-
)
|
|
528
677
|
|
|
529
678
|
elif method_name in ["predict", "predict_proba"]:
|
|
530
679
|
X = data[0]
|
|
531
680
|
|
|
532
681
|
patching_status.and_conditions(
|
|
533
682
|
[
|
|
534
|
-
(hasattr(self, "
|
|
683
|
+
(hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
|
|
535
684
|
(not sp.issparse(X), "X is sparse. Sparse input is not supported."),
|
|
536
685
|
(self.warm_start is False, "Warm start is not supported."),
|
|
537
686
|
(
|
|
538
687
|
daal_check_version((2023, "P", 100))
|
|
539
|
-
or self.
|
|
688
|
+
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
540
689
|
"ExtraTrees only supported starting from oneDAL version 2023.2",
|
|
541
690
|
),
|
|
542
691
|
]
|
|
543
692
|
)
|
|
693
|
+
|
|
694
|
+
if method_name == "predict_proba":
|
|
695
|
+
patching_status.and_conditions(
|
|
696
|
+
[
|
|
697
|
+
(
|
|
698
|
+
daal_check_version((2021, "P", 400)),
|
|
699
|
+
"oneDAL version is lower than 2021.4.",
|
|
700
|
+
)
|
|
701
|
+
]
|
|
702
|
+
)
|
|
703
|
+
|
|
544
704
|
if hasattr(self, "n_outputs_"):
|
|
545
705
|
patching_status.and_conditions(
|
|
546
706
|
[
|
|
@@ -573,24 +733,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
573
733
|
[
|
|
574
734
|
(
|
|
575
735
|
daal_check_version((2023, "P", 100))
|
|
576
|
-
or self.
|
|
736
|
+
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
577
737
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
578
738
|
),
|
|
579
739
|
(sample_weight is not None, "sample_weight is not supported."),
|
|
580
740
|
]
|
|
581
741
|
)
|
|
582
742
|
|
|
583
|
-
if (
|
|
584
|
-
patching_status.get_status()
|
|
585
|
-
and (self.random_state is not None)
|
|
586
|
-
and (not daal_check_version((2024, "P", 0)))
|
|
587
|
-
):
|
|
588
|
-
warnings.warn(
|
|
589
|
-
"Setting 'random_state' value is not supported. "
|
|
590
|
-
"State set by oneDAL to default value (777).",
|
|
591
|
-
RuntimeWarning,
|
|
592
|
-
)
|
|
593
|
-
|
|
594
743
|
elif method_name in ["predict", "predict_proba"]:
|
|
595
744
|
X = data[0]
|
|
596
745
|
|
|
@@ -625,113 +774,13 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
625
774
|
|
|
626
775
|
return patching_status
|
|
627
776
|
|
|
628
|
-
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
629
|
-
if sklearn_check_version("1.2"):
|
|
630
|
-
X, y = self._validate_data(
|
|
631
|
-
X,
|
|
632
|
-
y,
|
|
633
|
-
multi_output=False,
|
|
634
|
-
accept_sparse=False,
|
|
635
|
-
dtype=[np.float64, np.float32],
|
|
636
|
-
)
|
|
637
|
-
else:
|
|
638
|
-
X, y = check_X_y(
|
|
639
|
-
X,
|
|
640
|
-
y,
|
|
641
|
-
accept_sparse=False,
|
|
642
|
-
dtype=[np.float64, np.float32],
|
|
643
|
-
multi_output=False,
|
|
644
|
-
)
|
|
645
|
-
|
|
646
|
-
if sample_weight is not None:
|
|
647
|
-
sample_weight = self.check_sample_weight(sample_weight, X)
|
|
648
|
-
|
|
649
|
-
y = np.atleast_1d(y)
|
|
650
|
-
if y.ndim == 2 and y.shape[1] == 1:
|
|
651
|
-
warnings.warn(
|
|
652
|
-
"A column-vector y was passed when a 1d array was"
|
|
653
|
-
" expected. Please change the shape of y to "
|
|
654
|
-
"(n_samples,), for example using ravel().",
|
|
655
|
-
DataConversionWarning,
|
|
656
|
-
stacklevel=2,
|
|
657
|
-
)
|
|
658
|
-
if y.ndim == 1:
|
|
659
|
-
# reshape is necessary to preserve the data contiguity against vs
|
|
660
|
-
# [:, np.newaxis] that does not.
|
|
661
|
-
y = np.reshape(y, (-1, 1))
|
|
662
|
-
|
|
663
|
-
y, expanded_class_weight = self._validate_y_class_weight(y)
|
|
664
|
-
|
|
665
|
-
n_classes_ = self.n_classes_[0]
|
|
666
|
-
self.n_features_in_ = X.shape[1]
|
|
667
|
-
if not sklearn_check_version("1.0"):
|
|
668
|
-
self.n_features_ = self.n_features_in_
|
|
669
|
-
|
|
670
|
-
if expanded_class_weight is not None:
|
|
671
|
-
if sample_weight is not None:
|
|
672
|
-
sample_weight = sample_weight * expanded_class_weight
|
|
673
|
-
else:
|
|
674
|
-
sample_weight = expanded_class_weight
|
|
675
|
-
if sample_weight is not None:
|
|
676
|
-
sample_weight = [sample_weight]
|
|
677
|
-
|
|
678
|
-
if n_classes_ < 2:
|
|
679
|
-
raise ValueError("Training data only contain information about one class.")
|
|
680
|
-
|
|
681
|
-
if self.oob_score:
|
|
682
|
-
err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
|
|
683
|
-
else:
|
|
684
|
-
err = "none"
|
|
685
|
-
|
|
686
|
-
onedal_params = {
|
|
687
|
-
"n_estimators": self.n_estimators,
|
|
688
|
-
"criterion": self.criterion,
|
|
689
|
-
"max_depth": self.max_depth,
|
|
690
|
-
"min_samples_split": self.min_samples_split,
|
|
691
|
-
"min_samples_leaf": self.min_samples_leaf,
|
|
692
|
-
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
693
|
-
"max_features": self.max_features,
|
|
694
|
-
"max_leaf_nodes": self.max_leaf_nodes,
|
|
695
|
-
"min_impurity_decrease": self.min_impurity_decrease,
|
|
696
|
-
"bootstrap": self.bootstrap,
|
|
697
|
-
"oob_score": self.oob_score,
|
|
698
|
-
"n_jobs": self.n_jobs,
|
|
699
|
-
"random_state": self.random_state,
|
|
700
|
-
"verbose": self.verbose,
|
|
701
|
-
"warm_start": self.warm_start,
|
|
702
|
-
"error_metric_mode": err,
|
|
703
|
-
"variable_importance_mode": "mdi",
|
|
704
|
-
"class_weight": self.class_weight,
|
|
705
|
-
"max_bins": self.max_bins,
|
|
706
|
-
"min_bin_size": self.min_bin_size,
|
|
707
|
-
"max_samples": self.max_samples,
|
|
708
|
-
}
|
|
709
|
-
if daal_check_version((2023, "P", 101)):
|
|
710
|
-
onedal_params["splitter_mode"] = "random"
|
|
711
|
-
if not sklearn_check_version("1.0"):
|
|
712
|
-
onedal_params["min_impurity_split"] = self.min_impurity_split
|
|
713
|
-
else:
|
|
714
|
-
onedal_params["min_impurity_split"] = None
|
|
715
|
-
|
|
716
|
-
# Lazy evaluation of estimators_
|
|
717
|
-
self._cached_estimators_ = None
|
|
718
|
-
|
|
719
|
-
# Compute
|
|
720
|
-
self._onedal_estimator = self._onedal_classifier(**onedal_params)
|
|
721
|
-
self._onedal_estimator.fit(X, np.squeeze(y), sample_weight, queue=queue)
|
|
722
|
-
|
|
723
|
-
self._save_attributes()
|
|
724
|
-
if sklearn_check_version("1.2"):
|
|
725
|
-
self._estimator = ExtraTreeClassifier()
|
|
726
|
-
|
|
727
|
-
# Decapsulate classes_ attributes
|
|
728
|
-
self.n_classes_ = self.n_classes_[0]
|
|
729
|
-
self.classes_ = self.classes_[0]
|
|
730
|
-
return self
|
|
731
|
-
|
|
732
777
|
def _onedal_predict(self, X, queue=None):
|
|
733
|
-
X = check_array(
|
|
734
|
-
|
|
778
|
+
X = check_array(
|
|
779
|
+
X,
|
|
780
|
+
dtype=[np.float64, np.float32],
|
|
781
|
+
force_all_finite=False,
|
|
782
|
+
) # Warning, order of dtype matters
|
|
783
|
+
check_is_fitted(self, "_onedal_estimator")
|
|
735
784
|
|
|
736
785
|
if sklearn_check_version("1.0"):
|
|
737
786
|
self._check_feature_names(X, reset=False)
|
|
@@ -740,8 +789,8 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
740
789
|
return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
|
|
741
790
|
|
|
742
791
|
def _onedal_predict_proba(self, X, queue=None):
|
|
743
|
-
X = check_array(X, dtype=[np.float64, np.float32])
|
|
744
|
-
check_is_fitted(self, "
|
|
792
|
+
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
793
|
+
check_is_fitted(self, "_onedal_estimator")
|
|
745
794
|
|
|
746
795
|
if sklearn_check_version("0.23"):
|
|
747
796
|
self._check_n_features(X, reset=False)
|
|
@@ -750,7 +799,10 @@ class ForestClassifier(sklearn_ForestClassifier, BaseTree):
|
|
|
750
799
|
return self._onedal_estimator.predict_proba(X, queue=queue)
|
|
751
800
|
|
|
752
801
|
|
|
753
|
-
class ForestRegressor(sklearn_ForestRegressor,
|
|
802
|
+
class ForestRegressor(sklearn_ForestRegressor, BaseForest):
|
|
803
|
+
_err = "out_of_bag_error_r2|out_of_bag_error_prediction"
|
|
804
|
+
_get_tree_state = staticmethod(get_tree_state_reg)
|
|
805
|
+
|
|
754
806
|
def __init__(
|
|
755
807
|
self,
|
|
756
808
|
estimator,
|
|
@@ -778,66 +830,21 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
778
830
|
max_samples=max_samples,
|
|
779
831
|
)
|
|
780
832
|
|
|
781
|
-
# The splitter is
|
|
782
|
-
|
|
783
|
-
if
|
|
784
|
-
self.
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
# _estimators_ should only be called if _onedal_model exists
|
|
794
|
-
check_is_fitted(self, "_onedal_model")
|
|
795
|
-
# convert model to estimators
|
|
796
|
-
params = {
|
|
797
|
-
"criterion": self.criterion,
|
|
798
|
-
"max_depth": self.max_depth,
|
|
799
|
-
"min_samples_split": self.min_samples_split,
|
|
800
|
-
"min_samples_leaf": self.min_samples_leaf,
|
|
801
|
-
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
802
|
-
"max_features": self.max_features,
|
|
803
|
-
"max_leaf_nodes": self.max_leaf_nodes,
|
|
804
|
-
"min_impurity_decrease": self.min_impurity_decrease,
|
|
805
|
-
"random_state": None,
|
|
806
|
-
}
|
|
807
|
-
if not sklearn_check_version("1.0"):
|
|
808
|
-
params["min_impurity_split"] = self.min_impurity_split
|
|
809
|
-
est = self._estimator.__class__(**params)
|
|
810
|
-
# we need to set est.tree_ field with Trees constructed from Intel(R)
|
|
811
|
-
# oneAPI Data Analytics Library solution
|
|
812
|
-
estimators_ = []
|
|
813
|
-
random_state_checked = check_random_state(self.random_state)
|
|
814
|
-
|
|
815
|
-
for i in range(self.n_estimators):
|
|
816
|
-
est_i = clone(est)
|
|
817
|
-
est_i.set_params(
|
|
818
|
-
random_state=random_state_checked.randint(np.iinfo(np.int32).max)
|
|
819
|
-
)
|
|
820
|
-
if sklearn_check_version("1.0"):
|
|
821
|
-
est_i.n_features_in_ = self.n_features_in_
|
|
822
|
-
else:
|
|
823
|
-
est_i.n_features_ = self.n_features_in_
|
|
824
|
-
est_i.n_classes_ = 1
|
|
825
|
-
est_i.n_outputs_ = self.n_outputs_
|
|
826
|
-
tree_i_state_class = get_tree_state_reg(self._onedal_model, i)
|
|
827
|
-
tree_i_state_dict = {
|
|
828
|
-
"max_depth": tree_i_state_class.max_depth,
|
|
829
|
-
"node_count": tree_i_state_class.node_count,
|
|
830
|
-
"nodes": check_tree_nodes(tree_i_state_class.node_ar),
|
|
831
|
-
"values": tree_i_state_class.value_ar,
|
|
832
|
-
}
|
|
833
|
-
|
|
834
|
-
est_i.tree_ = Tree(
|
|
835
|
-
self.n_features_in_, np.array([1], dtype=np.intp), self.n_outputs_
|
|
836
|
-
)
|
|
837
|
-
est_i.tree_.__setstate__(tree_i_state_dict)
|
|
838
|
-
estimators_.append(est_i)
|
|
833
|
+
# The splitter is checked against the class attribute for conformance
|
|
834
|
+
# This should only trigger if the user uses this class directly.
|
|
835
|
+
if (
|
|
836
|
+
self.estimator.__class__ == DecisionTreeRegressor
|
|
837
|
+
and self._onedal_factory != onedal_RandomForestRegressor
|
|
838
|
+
):
|
|
839
|
+
self._onedal_factory = onedal_RandomForestRegressor
|
|
840
|
+
elif (
|
|
841
|
+
self.estimator.__class__ == ExtraTreeRegressor
|
|
842
|
+
and self._onedal_factory != onedal_ExtraTreesRegressor
|
|
843
|
+
):
|
|
844
|
+
self._onedal_factory = onedal_ExtraTreesRegressor
|
|
839
845
|
|
|
840
|
-
self.
|
|
846
|
+
if self._onedal_factory is None:
|
|
847
|
+
raise TypeError(f" oneDAL estimator has not been set.")
|
|
841
848
|
|
|
842
849
|
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
|
|
843
850
|
if sp.issparse(y):
|
|
@@ -885,12 +892,35 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
885
892
|
]
|
|
886
893
|
)
|
|
887
894
|
|
|
895
|
+
if patching_status.get_status() and sklearn_check_version("1.4"):
|
|
896
|
+
try:
|
|
897
|
+
_assert_all_finite(X)
|
|
898
|
+
input_is_finite = True
|
|
899
|
+
except ValueError:
|
|
900
|
+
input_is_finite = False
|
|
901
|
+
patching_status.and_conditions(
|
|
902
|
+
[
|
|
903
|
+
(input_is_finite, "Non-finite input is not supported."),
|
|
904
|
+
(
|
|
905
|
+
self.monotonic_cst is None,
|
|
906
|
+
"Monotonicity constraints are not supported.",
|
|
907
|
+
),
|
|
908
|
+
]
|
|
909
|
+
)
|
|
910
|
+
|
|
888
911
|
if patching_status.get_status():
|
|
889
|
-
if sklearn_check_version("
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
912
|
+
if sklearn_check_version("0.24"):
|
|
913
|
+
X, y = self._validate_data(
|
|
914
|
+
X,
|
|
915
|
+
y,
|
|
916
|
+
multi_output=True,
|
|
917
|
+
accept_sparse=True,
|
|
918
|
+
dtype=[np.float64, np.float32],
|
|
919
|
+
force_all_finite=False,
|
|
920
|
+
)
|
|
921
|
+
else:
|
|
922
|
+
X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
|
|
923
|
+
y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
|
|
894
924
|
|
|
895
925
|
if y.ndim == 2 and y.shape[1] == 1:
|
|
896
926
|
warnings.warn(
|
|
@@ -901,15 +931,13 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
901
931
|
stacklevel=2,
|
|
902
932
|
)
|
|
903
933
|
|
|
904
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype)
|
|
905
|
-
check_consistent_length(X, y)
|
|
906
|
-
|
|
907
934
|
if y.ndim == 1:
|
|
908
935
|
# reshape is necessary to preserve the data contiguity against vs
|
|
909
936
|
# [:, np.newaxis] that does not.
|
|
910
937
|
y = np.reshape(y, (-1, 1))
|
|
911
938
|
|
|
912
939
|
self.n_outputs_ = y.shape[1]
|
|
940
|
+
|
|
913
941
|
patching_status.and_conditions(
|
|
914
942
|
[
|
|
915
943
|
(
|
|
@@ -919,30 +947,8 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
919
947
|
]
|
|
920
948
|
)
|
|
921
949
|
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
if not sklearn_check_version("1.2"):
|
|
925
|
-
if not (1 <= self.max_samples <= n_samples):
|
|
926
|
-
msg = "`max_samples` must be in range 1 to {} but got value {}"
|
|
927
|
-
raise ValueError(msg.format(n_samples, self.max_samples))
|
|
928
|
-
else:
|
|
929
|
-
if self.max_samples > n_samples:
|
|
930
|
-
msg = "`max_samples` must be <= n_samples={} but got value {}"
|
|
931
|
-
raise ValueError(msg.format(n_samples, self.max_samples))
|
|
932
|
-
elif isinstance(self.max_samples, numbers.Real):
|
|
933
|
-
if sklearn_check_version("1.2"):
|
|
934
|
-
pass
|
|
935
|
-
elif sklearn_check_version("1.0"):
|
|
936
|
-
if not (0 < float(self.max_samples) <= 1):
|
|
937
|
-
msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
|
|
938
|
-
raise ValueError(msg.format(self.max_samples))
|
|
939
|
-
else:
|
|
940
|
-
if not (0 < float(self.max_samples) < 1):
|
|
941
|
-
msg = "`max_samples` must be in range (0, 1) but got value {}"
|
|
942
|
-
raise ValueError(msg.format(self.max_samples))
|
|
943
|
-
elif self.max_samples is not None:
|
|
944
|
-
msg = "`max_samples` should be int or float, but got type '{}'"
|
|
945
|
-
raise TypeError(msg.format(type(self.max_samples)))
|
|
950
|
+
# Sklearn function used for doing checks on max_samples attribute
|
|
951
|
+
_get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
|
|
946
952
|
|
|
947
953
|
if not self.bootstrap and self.max_samples is not None:
|
|
948
954
|
raise ValueError(
|
|
@@ -951,6 +957,17 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
951
957
|
"`max_sample=None`."
|
|
952
958
|
)
|
|
953
959
|
|
|
960
|
+
if (
|
|
961
|
+
patching_status.get_status()
|
|
962
|
+
and (self.random_state is not None)
|
|
963
|
+
and (not daal_check_version((2024, "P", 0)))
|
|
964
|
+
):
|
|
965
|
+
warnings.warn(
|
|
966
|
+
"Setting 'random_state' value is not supported. "
|
|
967
|
+
"State set by oneDAL to default value (777).",
|
|
968
|
+
RuntimeWarning,
|
|
969
|
+
)
|
|
970
|
+
|
|
954
971
|
return patching_status, X, y, sample_weight
|
|
955
972
|
|
|
956
973
|
def _onedal_cpu_supported(self, method_name, *data):
|
|
@@ -968,7 +985,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
968
985
|
[
|
|
969
986
|
(
|
|
970
987
|
daal_check_version((2023, "P", 200))
|
|
971
|
-
or self.
|
|
988
|
+
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
972
989
|
"ExtraTrees only supported starting from oneDAL version 2023.2",
|
|
973
990
|
),
|
|
974
991
|
(
|
|
@@ -978,28 +995,17 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
978
995
|
]
|
|
979
996
|
)
|
|
980
997
|
|
|
981
|
-
|
|
982
|
-
patching_status.get_status()
|
|
983
|
-
and (self.random_state is not None)
|
|
984
|
-
and (not daal_check_version((2024, "P", 0)))
|
|
985
|
-
):
|
|
986
|
-
warnings.warn(
|
|
987
|
-
"Setting 'random_state' value is not supported. "
|
|
988
|
-
"State set by oneDAL to default value (777).",
|
|
989
|
-
RuntimeWarning,
|
|
990
|
-
)
|
|
991
|
-
|
|
992
|
-
elif method_name in ["predict", "predict_proba"]:
|
|
998
|
+
elif method_name == "predict":
|
|
993
999
|
X = data[0]
|
|
994
1000
|
|
|
995
1001
|
patching_status.and_conditions(
|
|
996
1002
|
[
|
|
997
|
-
(hasattr(self, "
|
|
1003
|
+
(hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
|
|
998
1004
|
(not sp.issparse(X), "X is sparse. Sparse input is not supported."),
|
|
999
1005
|
(self.warm_start is False, "Warm start is not supported."),
|
|
1000
1006
|
(
|
|
1001
1007
|
daal_check_version((2023, "P", 200))
|
|
1002
|
-
or self.
|
|
1008
|
+
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
1003
1009
|
"ExtraTrees only supported starting from oneDAL version 2023.2",
|
|
1004
1010
|
),
|
|
1005
1011
|
]
|
|
@@ -1013,8 +1019,6 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
1013
1019
|
),
|
|
1014
1020
|
]
|
|
1015
1021
|
)
|
|
1016
|
-
else:
|
|
1017
|
-
dal_ready = False
|
|
1018
1022
|
|
|
1019
1023
|
else:
|
|
1020
1024
|
raise RuntimeError(
|
|
@@ -1038,35 +1042,24 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
1038
1042
|
[
|
|
1039
1043
|
(
|
|
1040
1044
|
daal_check_version((2023, "P", 100))
|
|
1041
|
-
or self.
|
|
1045
|
+
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
1042
1046
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
1043
1047
|
),
|
|
1044
1048
|
(sample_weight is not None, "sample_weight is not supported."),
|
|
1045
1049
|
]
|
|
1046
1050
|
)
|
|
1047
1051
|
|
|
1048
|
-
if (
|
|
1049
|
-
patching_status.get_status()
|
|
1050
|
-
and (self.random_state is not None)
|
|
1051
|
-
and (not daal_check_version((2024, "P", 0)))
|
|
1052
|
-
):
|
|
1053
|
-
warnings.warn(
|
|
1054
|
-
"Setting 'random_state' value is not supported. "
|
|
1055
|
-
"State set by oneDAL to default value (777).",
|
|
1056
|
-
RuntimeWarning,
|
|
1057
|
-
)
|
|
1058
|
-
|
|
1059
1052
|
elif method_name == "predict":
|
|
1060
1053
|
X = data[0]
|
|
1061
1054
|
|
|
1062
1055
|
patching_status.and_conditions(
|
|
1063
1056
|
[
|
|
1064
|
-
(hasattr(self, "
|
|
1057
|
+
(hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
|
|
1065
1058
|
(not sp.issparse(X), "X is sparse. Sparse input is not supported."),
|
|
1066
1059
|
(self.warm_start is False, "Warm start is not supported."),
|
|
1067
1060
|
(
|
|
1068
1061
|
daal_check_version((2023, "P", 100))
|
|
1069
|
-
or self.
|
|
1062
|
+
or self.estimator.__class__ == DecisionTreeClassifier,
|
|
1070
1063
|
"ExtraTrees only supported starting from oneDAL version 2023.1",
|
|
1071
1064
|
),
|
|
1072
1065
|
]
|
|
@@ -1088,76 +1081,11 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
1088
1081
|
|
|
1089
1082
|
return patching_status
|
|
1090
1083
|
|
|
1091
|
-
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
|
|
1092
|
-
if sp.issparse(y):
|
|
1093
|
-
raise ValueError("sparse multilabel-indicator for y is not supported.")
|
|
1094
|
-
if sklearn_check_version("1.2"):
|
|
1095
|
-
self._validate_params()
|
|
1096
|
-
else:
|
|
1097
|
-
self._check_parameters()
|
|
1098
|
-
if sample_weight is not None:
|
|
1099
|
-
sample_weight = self.check_sample_weight(sample_weight, X)
|
|
1100
|
-
if sklearn_check_version("1.0"):
|
|
1101
|
-
self._check_feature_names(X, reset=True)
|
|
1102
|
-
X = check_array(X, dtype=[np.float64, np.float32])
|
|
1103
|
-
y = np.atleast_1d(np.asarray(y))
|
|
1104
|
-
if y.ndim == 2 and y.shape[1] == 1:
|
|
1105
|
-
warnings.warn(
|
|
1106
|
-
"A column-vector y was passed when a 1d array was"
|
|
1107
|
-
" expected. Please change the shape of y to "
|
|
1108
|
-
"(n_samples,), for example using ravel().",
|
|
1109
|
-
DataConversionWarning,
|
|
1110
|
-
stacklevel=2,
|
|
1111
|
-
)
|
|
1112
|
-
y = check_array(y, ensure_2d=False, dtype=X.dtype)
|
|
1113
|
-
check_consistent_length(X, y)
|
|
1114
|
-
self.n_features_in_ = X.shape[1]
|
|
1115
|
-
if not sklearn_check_version("1.0"):
|
|
1116
|
-
self.n_features_ = self.n_features_in_
|
|
1117
|
-
|
|
1118
|
-
if self.oob_score:
|
|
1119
|
-
err = "out_of_bag_error_r2|out_of_bag_error_prediction"
|
|
1120
|
-
else:
|
|
1121
|
-
err = "none"
|
|
1122
|
-
|
|
1123
|
-
onedal_params = {
|
|
1124
|
-
"n_estimators": self.n_estimators,
|
|
1125
|
-
"criterion": self.criterion,
|
|
1126
|
-
"max_depth": self.max_depth,
|
|
1127
|
-
"min_samples_split": self.min_samples_split,
|
|
1128
|
-
"min_samples_leaf": self.min_samples_leaf,
|
|
1129
|
-
"min_weight_fraction_leaf": self.min_weight_fraction_leaf,
|
|
1130
|
-
"max_features": self.max_features,
|
|
1131
|
-
"max_leaf_nodes": self.max_leaf_nodes,
|
|
1132
|
-
"min_impurity_decrease": self.min_impurity_decrease,
|
|
1133
|
-
"bootstrap": self.bootstrap,
|
|
1134
|
-
"oob_score": self.oob_score,
|
|
1135
|
-
"n_jobs": self.n_jobs,
|
|
1136
|
-
"random_state": self.random_state,
|
|
1137
|
-
"verbose": self.verbose,
|
|
1138
|
-
"warm_start": self.warm_start,
|
|
1139
|
-
"error_metric_mode": err,
|
|
1140
|
-
"variable_importance_mode": "mdi",
|
|
1141
|
-
"max_samples": self.max_samples,
|
|
1142
|
-
}
|
|
1143
|
-
if daal_check_version((2023, "P", 101)):
|
|
1144
|
-
onedal_params["splitter_mode"] = "random"
|
|
1145
|
-
|
|
1146
|
-
# Lazy evaluation of estimators_
|
|
1147
|
-
self._cached_estimators_ = None
|
|
1148
|
-
|
|
1149
|
-
self._onedal_estimator = self._onedal_regressor(**onedal_params)
|
|
1150
|
-
self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
|
|
1151
|
-
|
|
1152
|
-
self._save_attributes()
|
|
1153
|
-
if sklearn_check_version("1.2"):
|
|
1154
|
-
self._estimator = ExtraTreeRegressor()
|
|
1155
|
-
|
|
1156
|
-
return self
|
|
1157
|
-
|
|
1158
1084
|
def _onedal_predict(self, X, queue=None):
|
|
1159
|
-
X = check_array(
|
|
1160
|
-
|
|
1085
|
+
X = check_array(
|
|
1086
|
+
X, dtype=[np.float64, np.float32], force_all_finite=False
|
|
1087
|
+
) # Warning, order of dtype matters
|
|
1088
|
+
check_is_fitted(self, "_onedal_estimator")
|
|
1161
1089
|
|
|
1162
1090
|
if sklearn_check_version("1.0"):
|
|
1163
1091
|
self._check_feature_names(X, reset=False)
|
|
@@ -1193,28 +1121,85 @@ class ForestRegressor(sklearn_ForestRegressor, BaseTree):
|
|
|
1193
1121
|
fit.__doc__ = sklearn_ForestRegressor.fit.__doc__
|
|
1194
1122
|
predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
|
|
1195
1123
|
|
|
1196
|
-
if sklearn_check_version("1.0"):
|
|
1197
|
-
|
|
1198
|
-
@deprecated(
|
|
1199
|
-
"Attribute `n_features_` was deprecated in version 1.0 and will be "
|
|
1200
|
-
"removed in 1.2. Use `n_features_in_` instead."
|
|
1201
|
-
)
|
|
1202
|
-
@property
|
|
1203
|
-
def n_features_(self):
|
|
1204
|
-
return self.n_features_in_
|
|
1205
|
-
|
|
1206
1124
|
|
|
1207
|
-
class
|
|
1208
|
-
__doc__ =
|
|
1125
|
+
class RandomForestClassifier(ForestClassifier):
|
|
1126
|
+
__doc__ = sklearn_RandomForestClassifier.__doc__
|
|
1127
|
+
_onedal_factory = onedal_RandomForestClassifier
|
|
1209
1128
|
|
|
1210
1129
|
if sklearn_check_version("1.2"):
|
|
1211
1130
|
_parameter_constraints: dict = {
|
|
1212
|
-
**
|
|
1131
|
+
**sklearn_RandomForestClassifier._parameter_constraints,
|
|
1213
1132
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1214
1133
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1215
1134
|
}
|
|
1216
1135
|
|
|
1217
|
-
if sklearn_check_version("1.
|
|
1136
|
+
if sklearn_check_version("1.4"):
|
|
1137
|
+
|
|
1138
|
+
def __init__(
|
|
1139
|
+
self,
|
|
1140
|
+
n_estimators=100,
|
|
1141
|
+
*,
|
|
1142
|
+
criterion="gini",
|
|
1143
|
+
max_depth=None,
|
|
1144
|
+
min_samples_split=2,
|
|
1145
|
+
min_samples_leaf=1,
|
|
1146
|
+
min_weight_fraction_leaf=0.0,
|
|
1147
|
+
max_features="sqrt",
|
|
1148
|
+
max_leaf_nodes=None,
|
|
1149
|
+
min_impurity_decrease=0.0,
|
|
1150
|
+
bootstrap=True,
|
|
1151
|
+
oob_score=False,
|
|
1152
|
+
n_jobs=None,
|
|
1153
|
+
random_state=None,
|
|
1154
|
+
verbose=0,
|
|
1155
|
+
warm_start=False,
|
|
1156
|
+
class_weight=None,
|
|
1157
|
+
ccp_alpha=0.0,
|
|
1158
|
+
max_samples=None,
|
|
1159
|
+
monotonic_cst=None,
|
|
1160
|
+
max_bins=256,
|
|
1161
|
+
min_bin_size=1,
|
|
1162
|
+
):
|
|
1163
|
+
super().__init__(
|
|
1164
|
+
DecisionTreeClassifier(),
|
|
1165
|
+
n_estimators,
|
|
1166
|
+
estimator_params=(
|
|
1167
|
+
"criterion",
|
|
1168
|
+
"max_depth",
|
|
1169
|
+
"min_samples_split",
|
|
1170
|
+
"min_samples_leaf",
|
|
1171
|
+
"min_weight_fraction_leaf",
|
|
1172
|
+
"max_features",
|
|
1173
|
+
"max_leaf_nodes",
|
|
1174
|
+
"min_impurity_decrease",
|
|
1175
|
+
"random_state",
|
|
1176
|
+
"ccp_alpha",
|
|
1177
|
+
"monotonic_cst",
|
|
1178
|
+
),
|
|
1179
|
+
bootstrap=bootstrap,
|
|
1180
|
+
oob_score=oob_score,
|
|
1181
|
+
n_jobs=n_jobs,
|
|
1182
|
+
random_state=random_state,
|
|
1183
|
+
verbose=verbose,
|
|
1184
|
+
warm_start=warm_start,
|
|
1185
|
+
class_weight=class_weight,
|
|
1186
|
+
max_samples=max_samples,
|
|
1187
|
+
)
|
|
1188
|
+
|
|
1189
|
+
self.criterion = criterion
|
|
1190
|
+
self.max_depth = max_depth
|
|
1191
|
+
self.min_samples_split = min_samples_split
|
|
1192
|
+
self.min_samples_leaf = min_samples_leaf
|
|
1193
|
+
self.min_weight_fraction_leaf = min_weight_fraction_leaf
|
|
1194
|
+
self.max_features = max_features
|
|
1195
|
+
self.max_leaf_nodes = max_leaf_nodes
|
|
1196
|
+
self.min_impurity_decrease = min_impurity_decrease
|
|
1197
|
+
self.ccp_alpha = ccp_alpha
|
|
1198
|
+
self.max_bins = max_bins
|
|
1199
|
+
self.min_bin_size = min_bin_size
|
|
1200
|
+
self.monotonic_cst = monotonic_cst
|
|
1201
|
+
|
|
1202
|
+
elif sklearn_check_version("1.0"):
|
|
1218
1203
|
|
|
1219
1204
|
def __init__(
|
|
1220
1205
|
self,
|
|
@@ -1228,7 +1213,7 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1228
1213
|
max_features="sqrt" if sklearn_check_version("1.1") else "auto",
|
|
1229
1214
|
max_leaf_nodes=None,
|
|
1230
1215
|
min_impurity_decrease=0.0,
|
|
1231
|
-
bootstrap=
|
|
1216
|
+
bootstrap=True,
|
|
1232
1217
|
oob_score=False,
|
|
1233
1218
|
n_jobs=None,
|
|
1234
1219
|
random_state=None,
|
|
@@ -1241,7 +1226,7 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1241
1226
|
min_bin_size=1,
|
|
1242
1227
|
):
|
|
1243
1228
|
super().__init__(
|
|
1244
|
-
|
|
1229
|
+
DecisionTreeClassifier(),
|
|
1245
1230
|
n_estimators,
|
|
1246
1231
|
estimator_params=(
|
|
1247
1232
|
"criterion",
|
|
@@ -1292,7 +1277,7 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1292
1277
|
max_leaf_nodes=None,
|
|
1293
1278
|
min_impurity_decrease=0.0,
|
|
1294
1279
|
min_impurity_split=None,
|
|
1295
|
-
bootstrap=
|
|
1280
|
+
bootstrap=True,
|
|
1296
1281
|
oob_score=False,
|
|
1297
1282
|
n_jobs=None,
|
|
1298
1283
|
random_state=None,
|
|
@@ -1305,7 +1290,7 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1305
1290
|
min_bin_size=1,
|
|
1306
1291
|
):
|
|
1307
1292
|
super().__init__(
|
|
1308
|
-
|
|
1293
|
+
DecisionTreeClassifier(),
|
|
1309
1294
|
n_estimators,
|
|
1310
1295
|
estimator_params=(
|
|
1311
1296
|
"criterion",
|
|
@@ -1346,17 +1331,82 @@ class ExtraTreesClassifier(ForestClassifier):
|
|
|
1346
1331
|
self.min_bin_size = min_bin_size
|
|
1347
1332
|
|
|
1348
1333
|
|
|
1349
|
-
class
|
|
1350
|
-
__doc__ =
|
|
1334
|
+
class RandomForestRegressor(ForestRegressor):
|
|
1335
|
+
__doc__ = sklearn_RandomForestRegressor.__doc__
|
|
1336
|
+
_onedal_factory = onedal_RandomForestRegressor
|
|
1351
1337
|
|
|
1352
1338
|
if sklearn_check_version("1.2"):
|
|
1353
1339
|
_parameter_constraints: dict = {
|
|
1354
|
-
**
|
|
1340
|
+
**sklearn_RandomForestRegressor._parameter_constraints,
|
|
1355
1341
|
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
1356
1342
|
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
|
|
1357
1343
|
}
|
|
1358
1344
|
|
|
1359
|
-
if sklearn_check_version("1.
|
|
1345
|
+
if sklearn_check_version("1.4"):
|
|
1346
|
+
|
|
1347
|
+
def __init__(
|
|
1348
|
+
self,
|
|
1349
|
+
n_estimators=100,
|
|
1350
|
+
*,
|
|
1351
|
+
criterion="squared_error",
|
|
1352
|
+
max_depth=None,
|
|
1353
|
+
min_samples_split=2,
|
|
1354
|
+
min_samples_leaf=1,
|
|
1355
|
+
min_weight_fraction_leaf=0.0,
|
|
1356
|
+
max_features=1.0,
|
|
1357
|
+
max_leaf_nodes=None,
|
|
1358
|
+
min_impurity_decrease=0.0,
|
|
1359
|
+
bootstrap=True,
|
|
1360
|
+
oob_score=False,
|
|
1361
|
+
n_jobs=None,
|
|
1362
|
+
random_state=None,
|
|
1363
|
+
verbose=0,
|
|
1364
|
+
warm_start=False,
|
|
1365
|
+
ccp_alpha=0.0,
|
|
1366
|
+
max_samples=None,
|
|
1367
|
+
monotonic_cst=None,
|
|
1368
|
+
max_bins=256,
|
|
1369
|
+
min_bin_size=1,
|
|
1370
|
+
):
|
|
1371
|
+
super().__init__(
|
|
1372
|
+
DecisionTreeRegressor(),
|
|
1373
|
+
n_estimators=n_estimators,
|
|
1374
|
+
estimator_params=(
|
|
1375
|
+
"criterion",
|
|
1376
|
+
"max_depth",
|
|
1377
|
+
"min_samples_split",
|
|
1378
|
+
"min_samples_leaf",
|
|
1379
|
+
"min_weight_fraction_leaf",
|
|
1380
|
+
"max_features",
|
|
1381
|
+
"max_leaf_nodes",
|
|
1382
|
+
"min_impurity_decrease",
|
|
1383
|
+
"random_state",
|
|
1384
|
+
"ccp_alpha",
|
|
1385
|
+
"monotonic_cst",
|
|
1386
|
+
),
|
|
1387
|
+
bootstrap=bootstrap,
|
|
1388
|
+
oob_score=oob_score,
|
|
1389
|
+
n_jobs=n_jobs,
|
|
1390
|
+
random_state=random_state,
|
|
1391
|
+
verbose=verbose,
|
|
1392
|
+
warm_start=warm_start,
|
|
1393
|
+
max_samples=max_samples,
|
|
1394
|
+
)
|
|
1395
|
+
|
|
1396
|
+
self.criterion = criterion
|
|
1397
|
+
self.max_depth = max_depth
|
|
1398
|
+
self.min_samples_split = min_samples_split
|
|
1399
|
+
self.min_samples_leaf = min_samples_leaf
|
|
1400
|
+
self.min_weight_fraction_leaf = min_weight_fraction_leaf
|
|
1401
|
+
self.max_features = max_features
|
|
1402
|
+
self.max_leaf_nodes = max_leaf_nodes
|
|
1403
|
+
self.min_impurity_decrease = min_impurity_decrease
|
|
1404
|
+
self.ccp_alpha = ccp_alpha
|
|
1405
|
+
self.max_bins = max_bins
|
|
1406
|
+
self.min_bin_size = min_bin_size
|
|
1407
|
+
self.monotonic_cst = monotonic_cst
|
|
1408
|
+
|
|
1409
|
+
elif sklearn_check_version("1.0"):
|
|
1360
1410
|
|
|
1361
1411
|
def __init__(
|
|
1362
1412
|
self,
|
|
@@ -1370,7 +1420,7 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1370
1420
|
max_features=1.0 if sklearn_check_version("1.1") else "auto",
|
|
1371
1421
|
max_leaf_nodes=None,
|
|
1372
1422
|
min_impurity_decrease=0.0,
|
|
1373
|
-
bootstrap=
|
|
1423
|
+
bootstrap=True,
|
|
1374
1424
|
oob_score=False,
|
|
1375
1425
|
n_jobs=None,
|
|
1376
1426
|
random_state=None,
|
|
@@ -1382,7 +1432,7 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1382
1432
|
min_bin_size=1,
|
|
1383
1433
|
):
|
|
1384
1434
|
super().__init__(
|
|
1385
|
-
|
|
1435
|
+
DecisionTreeRegressor(),
|
|
1386
1436
|
n_estimators=n_estimators,
|
|
1387
1437
|
estimator_params=(
|
|
1388
1438
|
"criterion",
|
|
@@ -1432,7 +1482,7 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1432
1482
|
max_leaf_nodes=None,
|
|
1433
1483
|
min_impurity_decrease=0.0,
|
|
1434
1484
|
min_impurity_split=None,
|
|
1435
|
-
bootstrap=
|
|
1485
|
+
bootstrap=True,
|
|
1436
1486
|
oob_score=False,
|
|
1437
1487
|
n_jobs=None,
|
|
1438
1488
|
random_state=None,
|
|
@@ -1444,7 +1494,7 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1444
1494
|
min_bin_size=1,
|
|
1445
1495
|
):
|
|
1446
1496
|
super().__init__(
|
|
1447
|
-
|
|
1497
|
+
DecisionTreeRegressor(),
|
|
1448
1498
|
n_estimators=n_estimators,
|
|
1449
1499
|
estimator_params=(
|
|
1450
1500
|
"criterion",
|
|
@@ -1479,3 +1529,419 @@ class ExtraTreesRegressor(ForestRegressor):
|
|
|
1479
1529
|
self.ccp_alpha = ccp_alpha
|
|
1480
1530
|
self.max_bins = max_bins
|
|
1481
1531
|
self.min_bin_size = min_bin_size
|
|
1532
|
+
|
|
1533
|
+
|
|
1534
|
+
class ExtraTreesClassifier(ForestClassifier):
    # The docstring is taken verbatim from the stock scikit-learn estimator.
    __doc__ = sklearn_ExtraTreesClassifier.__doc__
    _onedal_factory = onedal_ExtraTreesClassifier

    # ``max_bins`` and ``min_bin_size`` are extra constructor parameters that
    # are not forwarded to the base-class ``__init__``; their constraints are
    # added on top of scikit-learn's own table. The table is only defined for
    # scikit-learn >= 1.2.
    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # Exactly one ``__init__`` is defined, chosen by the installed
    # scikit-learn version:
    #   * >= 1.4     : adds ``monotonic_cst``
    #   * 1.0 .. 1.3 : no ``monotonic_cst``; ``max_features`` default depends
    #                  on whether the version is >= 1.1
    #   * <  1.0     : keeps the legacy ``min_impurity_split`` parameter
    # Each variant only forwards arguments and stores them as attributes.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                # Names of the attributes listed for the per-tree estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            # Each attribute is assigned exactly once (the original repeated
            # the max_bins / min_bin_size pair a second time).
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
|
|
1741
|
+
|
|
1742
|
+
|
|
1743
|
+
class ExtraTreesRegressor(ForestRegressor):
    # The docstring is taken verbatim from the stock scikit-learn estimator.
    __doc__ = sklearn_ExtraTreesRegressor.__doc__
    _onedal_factory = onedal_ExtraTreesRegressor

    # ``max_bins`` and ``min_bin_size`` are extra constructor parameters that
    # are not forwarded to the base-class ``__init__``; their constraints are
    # added on top of scikit-learn's own table. The table is only defined for
    # scikit-learn >= 1.2.
    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # Exactly one ``__init__`` is defined, chosen by the installed
    # scikit-learn version:
    #   * >= 1.4     : adds ``monotonic_cst``
    #   * 1.0 .. 1.3 : no ``monotonic_cst``; ``max_features`` default depends
    #                  on whether the version is >= 1.1
    #   * <  1.0     : legacy ``"mse"`` criterion and ``min_impurity_split``
    # Each variant only forwards arguments and stores them as attributes.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                # Names of the attributes listed for the per-tree estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # Two separate entries: without the comma between them,
                    # adjacent string literals are concatenated into the
                    # single name "min_impurity_splitrandom_state".
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
|
|
1941
|
+
|
|
1942
|
+
|
|
1943
|
+
# Register every estimator defined here as a virtual subclass of its stock
# scikit-learn counterpart, so isinstance() checks against the scikit-learn
# classes succeed without altering the inheritance chain (ABCMeta.register).
for _sklearn_cls, _sklearnex_cls in (
    (sklearn_RandomForestClassifier, RandomForestClassifier),
    (sklearn_RandomForestRegressor, RandomForestRegressor),
    (sklearn_ExtraTreesClassifier, ExtraTreesClassifier),
    (sklearn_ExtraTreesRegressor, ExtraTreesRegressor),
):
    _sklearn_cls.register(_sklearnex_cls)
# Drop the loop temporaries so the module namespace is unchanged.
del _sklearn_cls, _sklearnex_cls
|