scikit-learn-intelex 2025.0.0__py311-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- daal4py/__init__.py +73 -0
- daal4py/__main__.py +58 -0
- daal4py/_daal4py.cpython-311-x86_64-linux-gnu.so +0 -0
- daal4py/doc/third-party-programs.txt +424 -0
- daal4py/mb/__init__.py +19 -0
- daal4py/mb/model_builders.py +377 -0
- daal4py/mpi_transceiver.cpython-311-x86_64-linux-gnu.so +0 -0
- daal4py/sklearn/__init__.py +40 -0
- daal4py/sklearn/_n_jobs_support.py +242 -0
- daal4py/sklearn/_utils.py +241 -0
- daal4py/sklearn/cluster/__init__.py +20 -0
- daal4py/sklearn/cluster/dbscan.py +165 -0
- daal4py/sklearn/cluster/k_means.py +597 -0
- daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- daal4py/sklearn/decomposition/__init__.py +19 -0
- daal4py/sklearn/decomposition/_pca.py +524 -0
- daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
- daal4py/sklearn/ensemble/__init__.py +27 -0
- daal4py/sklearn/ensemble/_forest.py +1397 -0
- daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- daal4py/sklearn/linear_model/__init__.py +29 -0
- daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- daal4py/sklearn/linear_model/_linear.py +272 -0
- daal4py/sklearn/linear_model/_ridge.py +325 -0
- daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- daal4py/sklearn/linear_model/linear.py +17 -0
- daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- daal4py/sklearn/linear_model/ridge.py +17 -0
- daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
- daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- daal4py/sklearn/manifold/__init__.py +19 -0
- daal4py/sklearn/manifold/_t_sne.py +405 -0
- daal4py/sklearn/metrics/__init__.py +20 -0
- daal4py/sklearn/metrics/_pairwise.py +155 -0
- daal4py/sklearn/metrics/_ranking.py +210 -0
- daal4py/sklearn/model_selection/__init__.py +19 -0
- daal4py/sklearn/model_selection/_split.py +309 -0
- daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- daal4py/sklearn/monkeypatch/__init__.py +0 -0
- daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
- daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
- daal4py/sklearn/neighbors/__init__.py +21 -0
- daal4py/sklearn/neighbors/_base.py +503 -0
- daal4py/sklearn/neighbors/_classification.py +139 -0
- daal4py/sklearn/neighbors/_regression.py +74 -0
- daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- daal4py/sklearn/svm/__init__.py +19 -0
- daal4py/sklearn/svm/svm.py +734 -0
- daal4py/sklearn/utils/__init__.py +21 -0
- daal4py/sklearn/utils/base.py +75 -0
- daal4py/sklearn/utils/tests/test_utils.py +51 -0
- daal4py/sklearn/utils/validation.py +693 -0
- onedal/__init__.py +83 -0
- onedal/_config.py +53 -0
- onedal/_device_offload.py +229 -0
- onedal/_onedal_py_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_host.cpython-311-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_spmd_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
- onedal/basic_statistics/__init__.py +20 -0
- onedal/basic_statistics/basic_statistics.py +107 -0
- onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- onedal/cluster/__init__.py +27 -0
- onedal/cluster/dbscan.py +110 -0
- onedal/cluster/kmeans.py +560 -0
- onedal/cluster/kmeans_init.py +115 -0
- onedal/cluster/tests/test_dbscan.py +125 -0
- onedal/cluster/tests/test_kmeans.py +88 -0
- onedal/cluster/tests/test_kmeans_init.py +93 -0
- onedal/common/_base.py +38 -0
- onedal/common/_estimator_checks.py +47 -0
- onedal/common/_mixin.py +62 -0
- onedal/common/_policy.py +59 -0
- onedal/common/_spmd_policy.py +30 -0
- onedal/common/hyperparameters.py +116 -0
- onedal/common/tests/test_policy.py +75 -0
- onedal/covariance/__init__.py +20 -0
- onedal/covariance/covariance.py +125 -0
- onedal/covariance/incremental_covariance.py +146 -0
- onedal/covariance/tests/test_covariance.py +50 -0
- onedal/covariance/tests/test_incremental_covariance.py +122 -0
- onedal/datatypes/__init__.py +19 -0
- onedal/datatypes/_data_conversion.py +95 -0
- onedal/datatypes/tests/test_data.py +235 -0
- onedal/decomposition/__init__.py +20 -0
- onedal/decomposition/incremental_pca.py +204 -0
- onedal/decomposition/pca.py +186 -0
- onedal/decomposition/tests/test_incremental_pca.py +198 -0
- onedal/ensemble/__init__.py +29 -0
- onedal/ensemble/forest.py +720 -0
- onedal/ensemble/tests/test_random_forest.py +97 -0
- onedal/linear_model/__init__.py +27 -0
- onedal/linear_model/incremental_linear_model.py +258 -0
- onedal/linear_model/linear_model.py +329 -0
- onedal/linear_model/logistic_regression.py +249 -0
- onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- onedal/linear_model/tests/test_linear_regression.py +149 -0
- onedal/linear_model/tests/test_logistic_regression.py +95 -0
- onedal/linear_model/tests/test_ridge.py +95 -0
- onedal/neighbors/__init__.py +19 -0
- onedal/neighbors/neighbors.py +778 -0
- onedal/neighbors/tests/test_knn_classification.py +49 -0
- onedal/primitives/__init__.py +27 -0
- onedal/primitives/get_tree.py +25 -0
- onedal/primitives/kernel_functions.py +153 -0
- onedal/primitives/tests/test_kernel_functions.py +159 -0
- onedal/spmd/__init__.py +25 -0
- onedal/spmd/_base.py +30 -0
- onedal/spmd/basic_statistics/__init__.py +20 -0
- onedal/spmd/basic_statistics/basic_statistics.py +30 -0
- onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
- onedal/spmd/cluster/__init__.py +28 -0
- onedal/spmd/cluster/dbscan.py +23 -0
- onedal/spmd/cluster/kmeans.py +56 -0
- onedal/spmd/covariance/__init__.py +20 -0
- onedal/spmd/covariance/covariance.py +26 -0
- onedal/spmd/covariance/incremental_covariance.py +82 -0
- onedal/spmd/decomposition/__init__.py +20 -0
- onedal/spmd/decomposition/incremental_pca.py +117 -0
- onedal/spmd/decomposition/pca.py +26 -0
- onedal/spmd/ensemble/__init__.py +19 -0
- onedal/spmd/ensemble/forest.py +28 -0
- onedal/spmd/linear_model/__init__.py +21 -0
- onedal/spmd/linear_model/incremental_linear_model.py +97 -0
- onedal/spmd/linear_model/linear_model.py +30 -0
- onedal/spmd/linear_model/logistic_regression.py +38 -0
- onedal/spmd/neighbors/__init__.py +19 -0
- onedal/spmd/neighbors/neighbors.py +75 -0
- onedal/svm/__init__.py +19 -0
- onedal/svm/svm.py +556 -0
- onedal/svm/tests/test_csr_svm.py +351 -0
- onedal/svm/tests/test_nusvc.py +204 -0
- onedal/svm/tests/test_nusvr.py +210 -0
- onedal/svm/tests/test_svc.py +168 -0
- onedal/svm/tests/test_svr.py +243 -0
- onedal/tests/test_common.py +41 -0
- onedal/tests/utils/_dataframes_support.py +168 -0
- onedal/tests/utils/_device_selection.py +107 -0
- onedal/utils/__init__.py +49 -0
- onedal/utils/_array_api.py +91 -0
- onedal/utils/validation.py +432 -0
- scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
- scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
- scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
- sklearnex/__init__.py +65 -0
- sklearnex/__main__.py +58 -0
- sklearnex/_config.py +98 -0
- sklearnex/_device_offload.py +121 -0
- sklearnex/_utils.py +109 -0
- sklearnex/basic_statistics/__init__.py +20 -0
- sklearnex/basic_statistics/basic_statistics.py +140 -0
- sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
- sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
- sklearnex/cluster/__init__.py +20 -0
- sklearnex/cluster/dbscan.py +192 -0
- sklearnex/cluster/k_means.py +383 -0
- sklearnex/cluster/tests/test_dbscan.py +38 -0
- sklearnex/cluster/tests/test_kmeans.py +153 -0
- sklearnex/conftest.py +73 -0
- sklearnex/covariance/__init__.py +19 -0
- sklearnex/covariance/incremental_covariance.py +368 -0
- sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
- sklearnex/decomposition/__init__.py +19 -0
- sklearnex/decomposition/pca.py +414 -0
- sklearnex/decomposition/tests/test_pca.py +58 -0
- sklearnex/dispatcher.py +543 -0
- sklearnex/doc/third-party-programs.txt +424 -0
- sklearnex/ensemble/__init__.py +29 -0
- sklearnex/ensemble/_forest.py +2016 -0
- sklearnex/ensemble/tests/test_forest.py +120 -0
- sklearnex/glob/__main__.py +72 -0
- sklearnex/glob/dispatcher.py +101 -0
- sklearnex/linear_model/__init__.py +32 -0
- sklearnex/linear_model/coordinate_descent.py +30 -0
- sklearnex/linear_model/incremental_linear.py +463 -0
- sklearnex/linear_model/incremental_ridge.py +418 -0
- sklearnex/linear_model/linear.py +302 -0
- sklearnex/linear_model/logistic_path.py +17 -0
- sklearnex/linear_model/logistic_regression.py +403 -0
- sklearnex/linear_model/ridge.py +24 -0
- sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
- sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- sklearnex/linear_model/tests/test_linear.py +142 -0
- sklearnex/linear_model/tests/test_logreg.py +134 -0
- sklearnex/manifold/__init__.py +19 -0
- sklearnex/manifold/t_sne.py +21 -0
- sklearnex/manifold/tests/test_tsne.py +26 -0
- sklearnex/metrics/__init__.py +23 -0
- sklearnex/metrics/pairwise.py +22 -0
- sklearnex/metrics/ranking.py +20 -0
- sklearnex/metrics/tests/test_metrics.py +39 -0
- sklearnex/model_selection/__init__.py +21 -0
- sklearnex/model_selection/split.py +22 -0
- sklearnex/model_selection/tests/test_model_selection.py +34 -0
- sklearnex/neighbors/__init__.py +27 -0
- sklearnex/neighbors/_lof.py +231 -0
- sklearnex/neighbors/common.py +310 -0
- sklearnex/neighbors/knn_classification.py +226 -0
- sklearnex/neighbors/knn_regression.py +203 -0
- sklearnex/neighbors/knn_unsupervised.py +170 -0
- sklearnex/neighbors/tests/test_neighbors.py +80 -0
- sklearnex/preview/__init__.py +17 -0
- sklearnex/preview/covariance/__init__.py +19 -0
- sklearnex/preview/covariance/covariance.py +133 -0
- sklearnex/preview/covariance/tests/test_covariance.py +66 -0
- sklearnex/preview/decomposition/__init__.py +19 -0
- sklearnex/preview/decomposition/incremental_pca.py +228 -0
- sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- sklearnex/preview/linear_model/__init__.py +19 -0
- sklearnex/preview/linear_model/ridge.py +419 -0
- sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- sklearnex/spmd/__init__.py +25 -0
- sklearnex/spmd/basic_statistics/__init__.py +20 -0
- sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
- sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- sklearnex/spmd/cluster/__init__.py +30 -0
- sklearnex/spmd/cluster/dbscan.py +50 -0
- sklearnex/spmd/cluster/kmeans.py +21 -0
- sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- sklearnex/spmd/covariance/__init__.py +20 -0
- sklearnex/spmd/covariance/covariance.py +21 -0
- sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- sklearnex/spmd/decomposition/__init__.py +20 -0
- sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- sklearnex/spmd/decomposition/pca.py +21 -0
- sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- sklearnex/spmd/ensemble/__init__.py +19 -0
- sklearnex/spmd/ensemble/forest.py +71 -0
- sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- sklearnex/spmd/linear_model/__init__.py +21 -0
- sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- sklearnex/spmd/linear_model/linear_model.py +21 -0
- sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
- sklearnex/spmd/neighbors/__init__.py +19 -0
- sklearnex/spmd/neighbors/neighbors.py +25 -0
- sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- sklearnex/svm/__init__.py +29 -0
- sklearnex/svm/_common.py +328 -0
- sklearnex/svm/nusvc.py +332 -0
- sklearnex/svm/nusvr.py +148 -0
- sklearnex/svm/svc.py +360 -0
- sklearnex/svm/svr.py +149 -0
- sklearnex/svm/tests/test_svm.py +93 -0
- sklearnex/tests/_utils.py +328 -0
- sklearnex/tests/_utils_spmd.py +198 -0
- sklearnex/tests/test_common.py +54 -0
- sklearnex/tests/test_config.py +43 -0
- sklearnex/tests/test_memory_usage.py +291 -0
- sklearnex/tests/test_monkeypatch.py +276 -0
- sklearnex/tests/test_n_jobs_support.py +103 -0
- sklearnex/tests/test_parallel.py +48 -0
- sklearnex/tests/test_patching.py +385 -0
- sklearnex/tests/test_run_to_run_stability.py +296 -0
- sklearnex/utils/__init__.py +19 -0
- sklearnex/utils/_array_api.py +82 -0
- sklearnex/utils/parallel.py +59 -0
- sklearnex/utils/tests/test_finite.py +89 -0
- sklearnex/utils/validation.py +17 -0
|
@@ -0,0 +1,2016 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numbers
|
|
18
|
+
import warnings
|
|
19
|
+
from abc import ABC
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
from scipy import sparse as sp
|
|
23
|
+
from sklearn.base import clone
|
|
24
|
+
from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifier
|
|
25
|
+
from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
|
|
26
|
+
from sklearn.ensemble import RandomForestClassifier as sklearn_RandomForestClassifier
|
|
27
|
+
from sklearn.ensemble import RandomForestRegressor as sklearn_RandomForestRegressor
|
|
28
|
+
from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
|
|
29
|
+
from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
|
|
30
|
+
from sklearn.ensemble._forest import _get_n_samples_bootstrap
|
|
31
|
+
from sklearn.exceptions import DataConversionWarning
|
|
32
|
+
from sklearn.metrics import accuracy_score, r2_score
|
|
33
|
+
from sklearn.tree import (
|
|
34
|
+
DecisionTreeClassifier,
|
|
35
|
+
DecisionTreeRegressor,
|
|
36
|
+
ExtraTreeClassifier,
|
|
37
|
+
ExtraTreeRegressor,
|
|
38
|
+
)
|
|
39
|
+
from sklearn.tree._tree import Tree
|
|
40
|
+
from sklearn.utils import check_random_state, deprecated
|
|
41
|
+
from sklearn.utils.validation import (
|
|
42
|
+
_check_sample_weight,
|
|
43
|
+
check_array,
|
|
44
|
+
check_is_fitted,
|
|
45
|
+
check_X_y,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
from daal4py.sklearn._n_jobs_support import control_n_jobs
|
|
49
|
+
from daal4py.sklearn._utils import (
|
|
50
|
+
check_tree_nodes,
|
|
51
|
+
daal_check_version,
|
|
52
|
+
sklearn_check_version,
|
|
53
|
+
)
|
|
54
|
+
from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
|
|
55
|
+
from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
|
|
56
|
+
from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
|
|
57
|
+
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
|
|
58
|
+
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
59
|
+
from onedal.utils import _num_features, _num_samples
|
|
60
|
+
|
|
61
|
+
from .._device_offload import dispatch, wrap_output_data
|
|
62
|
+
from .._utils import PatchingConditionsChain
|
|
63
|
+
from ..utils._array_api import get_namespace
|
|
64
|
+
|
|
65
|
+
if sklearn_check_version("1.2"):
|
|
66
|
+
from sklearn.utils._param_validation import Interval
|
|
67
|
+
if sklearn_check_version("1.4"):
|
|
68
|
+
from daal4py.sklearn.utils import _assert_all_finite
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class BaseForest(ABC):
    """Shared oneDAL-backed logic for the forest estimators in this module.

    This class is used as a mixin next to a scikit-learn forest base class.
    The methods below rely on attributes and helpers that are not defined
    here and must be supplied by the concrete class: ``_onedal_factory``
    (callable building the oneDAL estimator), ``_err`` (error metric mode
    used when ``oob_score`` is enabled), ``_get_tree_state``,
    ``_validate_data``, ``_validate_y_class_weight``,
    ``_validate_estimator`` and the usual forest hyper-parameters.
    """

    # Callable that builds the backend estimator; concrete classes override it.
    _onedal_factory = None

    def _onedal_fit(self, X, y, sample_weight=None, queue=None):
        """Fit the forest through the oneDAL backend estimator.

        Parameters
        ----------
        X : array-like
            Training data; validated by ``self._validate_data`` as a dense 2D
            array of ``float64`` or ``float32``.
        y : array-like
            Targets; a 1D ``y`` is reshaped to a single column.
        sample_weight : array-like, default=None
            Optional per-sample weights; combined with the expanded class
            weights returned by ``self._validate_y_class_weight``.
        queue : object, default=None
            Forwarded unchanged to the backend estimator's ``fit``.

        Returns
        -------
        self
            The fitted estimator.
        """
        X, y = self._validate_data(
            X,
            y,
            multi_output=True,
            accept_sparse=False,
            dtype=[np.float64, np.float32],
            force_all_finite=False,
            ensure_2d=True,
        )

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        if y.ndim == 1:
            # reshape preserves the data contiguity, whereas
            # [:, np.newaxis] does not.
            y = np.reshape(y, (-1, 1))

        self._n_samples, self.n_outputs_ = y.shape

        y, expanded_class_weight = self._validate_y_class_weight(y)

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight
        if sample_weight is not None:
            # The backend estimator receives the weights wrapped in a list.
            sample_weight = [sample_weight]

        onedal_params = {
            "n_estimators": self.n_estimators,
            "criterion": self.criterion,
            "max_depth": self.max_depth,
            "min_samples_split": self.min_samples_split,
            "min_samples_leaf": self.min_samples_leaf,
            "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
            "max_features": self._to_absolute_max_features(
                self.max_features, self.n_features_in_
            ),
            "max_leaf_nodes": self.max_leaf_nodes,
            "min_impurity_decrease": self.min_impurity_decrease,
            "bootstrap": self.bootstrap,
            "oob_score": self.oob_score,
            "n_jobs": self.n_jobs,
            "random_state": self.random_state,
            "verbose": self.verbose,
            "warm_start": self.warm_start,
            "error_metric_mode": self._err if self.oob_score else "none",
            "variable_importance_mode": "mdi",
            "class_weight": self.class_weight,
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "max_samples": self.max_samples,
        }

        # ``min_impurity_split`` is only read from the estimator on
        # scikit-learn < 1.0; otherwise the backend always receives None.
        if not sklearn_check_version("1.0"):
            onedal_params["min_impurity_split"] = self.min_impurity_split
        else:
            onedal_params["min_impurity_split"] = None

        # Lazy evaluation of estimators_
        self._cached_estimators_ = None

        # Compute
        self._onedal_estimator = self._onedal_factory(**onedal_params)
        self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)

        self._save_attributes()

        # Decapsulate classes_ attributes
        if hasattr(self, "classes_") and self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        return self

    def _save_attributes(self):
        """Copy fitted attributes from the backend estimator onto ``self``.

        Sets the out-of-bag attributes when ``oob_score`` is enabled, computes
        ``_n_samples_bootstrap`` (``None`` when ``bootstrap`` is disabled) and
        finally calls ``self._validate_estimator()``.

        Returns
        -------
        self
        """
        if self.oob_score:
            self.oob_score_ = self._onedal_estimator.oob_score_
            if hasattr(self._onedal_estimator, "oob_prediction_"):
                self.oob_prediction_ = self._onedal_estimator.oob_prediction_
            if hasattr(self._onedal_estimator, "oob_decision_function_"):
                self.oob_decision_function_ = (
                    self._onedal_estimator.oob_decision_function_
                )
        if self.bootstrap:
            # At least one sample is always drawn per tree.
            self._n_samples_bootstrap = max(
                round(
                    self._onedal_estimator.observations_per_tree_fraction
                    * self._n_samples
                ),
                1,
            )
        else:
            self._n_samples_bootstrap = None
        self._validate_estimator()
        return self

    def _to_absolute_max_features(self, max_features, n_features):
        """Convert a ``max_features`` setting into an absolute feature count.

        Parameters
        ----------
        max_features : None, str, int or float
            ``None`` selects all features. ``"sqrt"`` and ``"log2"`` apply the
            corresponding function to ``n_features`` (never below 1).
            ``"auto"`` is only accepted on scikit-learn < 1.3, where it means
            ``sqrt`` for ``ForestClassifier`` instances and all features
            otherwise. Integral values are returned unchanged; any other value
            is treated as a fraction of ``n_features``.
        n_features : int
            Number of features in the training data.

        Returns
        -------
        int
            Number of features to consider; ``0`` for a non-positive fraction.

        Raises
        ------
        ValueError
            If ``max_features`` is a string that is not an allowed value.
        """
        if max_features is None:
            return n_features
        if isinstance(max_features, str):
            if max_features == "auto":
                # On scikit-learn >= 1.3 "auto" falls through to the
                # ValueError raised for unknown strings below.
                if not sklearn_check_version("1.3"):
                    if sklearn_check_version("1.1"):
                        warnings.warn(
                            "`max_features='auto'` has been deprecated in 1.1 "
                            "and will be removed in 1.3. To keep the past behaviour, "
                            "explicitly set `max_features=1.0` or remove this "
                            "parameter as it is also the default value for "
                            "RandomForestRegressors and ExtraTreesRegressors.",
                            FutureWarning,
                        )
                    return (
                        max(1, int(np.sqrt(n_features)))
                        if isinstance(self, ForestClassifier)
                        else n_features
                    )
            if max_features == "sqrt":
                return max(1, int(np.sqrt(n_features)))
            if max_features == "log2":
                return max(1, int(np.log2(n_features)))
            allowed_string_values = (
                '"sqrt" or "log2"'
                if sklearn_check_version("1.3")
                else '"auto", "sqrt" or "log2"'
            )
            raise ValueError(
                "Invalid value for max_features. Allowed string "
                f"values are {allowed_string_values}."
            )
        if isinstance(max_features, (numbers.Integral, np.integer)):
            return max_features
        if max_features > 0.0:
            return max(1, int(max_features * n_features))
        return 0

    def _check_parameters(self):
        """Validate hyper-parameters by hand.

        Checks ``min_samples_leaf``, ``min_samples_split``,
        ``min_weight_fraction_leaf``, ``min_impurity_split`` (when present and
        set), ``min_impurity_decrease``, ``max_leaf_nodes``, ``max_bins`` and
        ``min_bin_size``.

        Raises
        ------
        ValueError
            If any of the checked parameters is outside its allowed range or
            has the wrong type.
        """
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the integer %s" % self.min_samples_split
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the float %s" % self.min_samples_split
                )
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
        # The attribute may be absent, or present but unset (None). Only warn
        # about and validate a value the user actually supplied; comparing
        # None with 0.0 would raise TypeError.
        min_impurity_split = getattr(self, "min_impurity_split", None)
        if min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            if min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than " "or equal to 0"
                )
        if self.min_impurity_decrease < 0.0:
            raise ValueError(
                "min_impurity_decrease must be greater than " "or equal to 0"
            )
        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" % self.max_leaf_nodes
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
                        self.max_leaf_nodes
                    )
                )
        if isinstance(self.max_bins, numbers.Integral):
            if not 2 <= self.max_bins:
                raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
        else:
            raise ValueError(
                "max_bins must be integral number but was " "%r" % self.max_bins
            )
        if isinstance(self.min_bin_size, numbers.Integral):
            if not 1 <= self.min_bin_size:
                raise ValueError(
                    "min_bin_size must be at least 1, got %s" % self.min_bin_size
                )
        else:
            raise ValueError(
                "min_bin_size must be integral number but was " "%r" % self.min_bin_size
            )

    @property
    def estimators_(self):
        """List of per-tree estimators, built on first access.

        Raises
        ------
        AttributeError
            If ``_cached_estimators_`` has never been set on this object.
        """
        if hasattr(self, "_cached_estimators_"):
            if self._cached_estimators_ is None:
                # Populate the cache from the backend model on first access.
                self._estimators_()
            return self._cached_estimators_
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'estimators_'"
            )

    @estimators_.setter
    def estimators_(self, estimators):
        # Needed to allow for proper sklearn operation in fallback mode
        self._cached_estimators_ = estimators

    def _estimators_(self):
        """Build scikit-learn tree estimators from the fitted backend model.

        One clone of ``self.estimator``'s class is created per backend tree,
        given a ``tree_`` restored from the state returned by
        ``self._get_tree_state``, and the resulting list is stored in
        ``self._cached_estimators_``.
        """
        # _estimators_ should only be called if _onedal_estimator exists
        check_is_fitted(self, "_onedal_estimator")
        if hasattr(self, "n_classes_"):
            # ``n_classes_`` is a scalar once ``_onedal_fit`` has decapsulated
            # it and a sequence otherwise. ``numbers.Integral`` also matches
            # NumPy integer scalars, which a plain ``int`` check would miss.
            n_classes_ = (
                self.n_classes_
                if isinstance(self.n_classes_, numbers.Integral)
                else self.n_classes_[0]
            )
        else:
            n_classes_ = 1

        # convert model to estimators
        params = {
            "criterion": self._onedal_estimator.criterion,
            "max_depth": self._onedal_estimator.max_depth,
            "min_samples_split": self._onedal_estimator.min_samples_split,
            "min_samples_leaf": self._onedal_estimator.min_samples_leaf,
            "min_weight_fraction_leaf": self._onedal_estimator.min_weight_fraction_leaf,
            "max_features": self._onedal_estimator.max_features,
            "max_leaf_nodes": self._onedal_estimator.max_leaf_nodes,
            "min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
            "random_state": None,
        }
        if not sklearn_check_version("1.0"):
            params["min_impurity_split"] = self._onedal_estimator.min_impurity_split
        est = self.estimator.__class__(**params)
        # we need to set est.tree_ field with Trees constructed from Intel(R)
        # oneAPI Data Analytics Library solution
        estimators_ = []

        random_state_checked = check_random_state(self.random_state)

        for i in range(self._onedal_estimator.n_estimators):
            est_i = clone(est)
            est_i.set_params(
                random_state=random_state_checked.randint(np.iinfo(np.int32).max)
            )
            if sklearn_check_version("1.0"):
                est_i.n_features_in_ = self.n_features_in_
            else:
                est_i.n_features_ = self.n_features_in_
            est_i.n_outputs_ = self.n_outputs_
            est_i.n_classes_ = n_classes_
            tree_i_state_class = self._get_tree_state(
                self._onedal_estimator._onedal_model, i, n_classes_
            )
            tree_i_state_dict = {
                "max_depth": tree_i_state_class.max_depth,
                "node_count": tree_i_state_class.node_count,
                "nodes": check_tree_nodes(tree_i_state_class.node_ar),
                "values": tree_i_state_class.value_ar,
            }
            est_i.tree_ = Tree(
                self.n_features_in_,
                np.array([n_classes_], dtype=np.intp),
                self.n_outputs_,
            )
            est_i.tree_.__setstate__(tree_i_state_dict)
            estimators_.append(est_i)

        self._cached_estimators_ = estimators_

    if sklearn_check_version("1.0"):

        @deprecated(
            "Attribute `n_features_` was deprecated in version 1.0 and will be "
            "removed in 1.2. Use `n_features_in_` instead."
        )
        @property
        def n_features_(self):
            # Deprecated alias of ``n_features_in_``.
            return self.n_features_in_

    if not sklearn_check_version("1.2"):

        # On scikit-learn < 1.2 ``base_estimator`` is exposed as a read/write
        # alias of ``estimator``.
        @property
        def base_estimator(self):
            return self.estimator

        @base_estimator.setter
        def base_estimator(self, estimator):
            self.estimator = estimator
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
class ForestClassifier(sklearn_ForestClassifier, BaseForest):
|
|
398
|
+
# Surprisingly, even though scikit-learn warns against using
|
|
399
|
+
# their ForestClassifier directly, it actually has a more stable
|
|
400
|
+
# API than the user-facing objects (over time). If they change it
|
|
401
|
+
# significantly at some point then this may need to be versioned.
|
|
402
|
+
|
|
403
|
+
_err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
|
|
404
|
+
_get_tree_state = staticmethod(get_tree_state_cls)
|
|
405
|
+
|
|
406
|
+
def __init__(
|
|
407
|
+
self,
|
|
408
|
+
estimator,
|
|
409
|
+
n_estimators=100,
|
|
410
|
+
*,
|
|
411
|
+
estimator_params=tuple(),
|
|
412
|
+
bootstrap=False,
|
|
413
|
+
oob_score=False,
|
|
414
|
+
n_jobs=None,
|
|
415
|
+
random_state=None,
|
|
416
|
+
verbose=0,
|
|
417
|
+
warm_start=False,
|
|
418
|
+
class_weight=None,
|
|
419
|
+
max_samples=None,
|
|
420
|
+
):
|
|
421
|
+
super().__init__(
|
|
422
|
+
estimator,
|
|
423
|
+
n_estimators=n_estimators,
|
|
424
|
+
estimator_params=estimator_params,
|
|
425
|
+
bootstrap=bootstrap,
|
|
426
|
+
oob_score=oob_score,
|
|
427
|
+
n_jobs=n_jobs,
|
|
428
|
+
random_state=random_state,
|
|
429
|
+
verbose=verbose,
|
|
430
|
+
warm_start=warm_start,
|
|
431
|
+
class_weight=class_weight,
|
|
432
|
+
max_samples=max_samples,
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
# The estimator is checked against the class attribute for conformance.
|
|
436
|
+
# This should only trigger if the user uses this class directly.
|
|
437
|
+
if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
|
|
438
|
+
self._onedal_factory, onedal_RandomForestClassifier
|
|
439
|
+
):
|
|
440
|
+
self._onedal_factory = onedal_RandomForestClassifier
|
|
441
|
+
elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
|
|
442
|
+
self._onedal_factory, onedal_ExtraTreesClassifier
|
|
443
|
+
):
|
|
444
|
+
self._onedal_factory = onedal_ExtraTreesClassifier
|
|
445
|
+
|
|
446
|
+
if self._onedal_factory is None:
|
|
447
|
+
raise TypeError(f" oneDAL estimator has not been set.")
|
|
448
|
+
|
|
449
|
+
def _estimators_(self):
    """Run the parent ``_estimators_`` and tag each cached estimator.

    After delegating to the parent implementation, every object in
    ``self._cached_estimators_`` receives ``self.classes_[0]`` as its
    ``classes_`` attribute.
    """
    super()._estimators_()
    first_output_classes = self.classes_[0]
    for cached_tree in self._cached_estimators_:
        cached_tree.classes_ = first_output_classes
|
|
454
|
+
|
|
455
|
+
def fit(self, X, y, sample_weight=None):
    # No docstring here on purpose: `fit.__doc__` is overwritten with the
    # scikit-learn docstring further down in the class body.
    #
    # `dispatch` is handed both the oneDAL-backed and the stock scikit-learn
    # implementation; its return value is ignored and the estimator itself is
    # returned.
    backends = {
        "onedal": self.__class__._onedal_fit,
        "sklearn": sklearn_ForestClassifier.fit,
    }
    dispatch(self, "fit", backends, X, y, sample_weight)
    return self
|
|
468
|
+
|
|
469
|
+
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
    """Validate fit arguments and record whether oneDAL can run this fit.

    Parameters
    ----------
    patching_status : object
        Condition accumulator; only its ``and_conditions`` and ``get_status``
        methods are used here.
    X, y
        Training data and targets as passed to ``fit``.
    sample_weight
        Returned unchanged.

    Returns
    -------
    tuple
        ``(patching_status, X, y, sample_weight)``. When every condition
        recorded before input validation holds, ``X`` and ``y`` are the
        outputs of ``check_X_y`` (``y`` reshaped to two dimensions);
        otherwise they are the objects that were passed in.

    Raises
    ------
    ValueError
        If ``y`` is sparse, if ``oob_score`` is requested without
        ``bootstrap``, or (once input validation ran) if ``max_samples`` is
        set without ``bootstrap``.
    """
    if sp.issparse(y):
        raise ValueError("sparse multilabel-indicator for y is not supported.")

    # The parameter-validation entry point depends on the scikit-learn release.
    validate_hyperparameters = (
        self._validate_params
        if sklearn_check_version("1.2")
        else self._check_parameters
    )
    validate_hyperparameters()

    if self.oob_score and not self.bootstrap:
        raise ValueError("Out of bag estimation only available if bootstrap=True")

    oob_supported = (
        self.oob_score and daal_check_version((2021, "P", 500))
    ) or not self.oob_score
    parameter_conditions = [
        (
            oob_supported,
            "OOB score is only supported starting from 2021.5 version of oneDAL.",
        ),
        (self.warm_start is False, "Warm start is not supported."),
        (
            self.criterion == "gini",
            f"'{self.criterion}' criterion is not supported. "
            "Only 'gini' criterion is supported.",
        ),
        (
            self.ccp_alpha == 0.0,
            f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
        ),
        (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
        (
            self.n_estimators <= 6024,
            "More than 6024 estimators is not supported.",
        ),
    ]
    patching_status.and_conditions(parameter_conditions)

    if self.bootstrap:
        bootstrap_conditions = [
            (
                self.class_weight != "balanced_subsample",
                "'balanced_subsample' for class_weight is not supported",
            )
        ]
        patching_status.and_conditions(bootstrap_conditions)

    if patching_status.get_status() and sklearn_check_version("1.4"):
        # A failed finiteness check is turned into a patching condition
        # instead of being allowed to propagate.
        try:
            _assert_all_finite(X)
        except ValueError:
            input_is_finite = False
        else:
            input_is_finite = True
        patching_status.and_conditions(
            [
                (input_is_finite, "Non-finite input is not supported."),
                (
                    self.monotonic_cst is None,
                    "Monotonicity constraints are not supported.",
                ),
            ]
        )

    # Everything below needs validated arrays; skip it once any condition
    # recorded so far has failed.
    if not patching_status.get_status():
        return patching_status, X, y, sample_weight

    X, y = check_X_y(
        X,
        y,
        multi_output=True,
        accept_sparse=True,
        dtype=[np.float64, np.float32],
        force_all_finite=False,
    )

    if y.ndim == 2 and y.shape[1] == 1:
        warnings.warn(
            "A column-vector y was passed when a 1d array was"
            " expected. Please change the shape of y to "
            "(n_samples,), for example using ravel().",
            DataConversionWarning,
            stacklevel=2,
        )

    if y.ndim == 1:
        y = np.reshape(y, (-1, 1))

    self.n_outputs_ = y.shape[1]

    supported_target_dtypes = (np.float32, np.float64, np.int32, np.int64)
    patching_status.and_conditions(
        [
            (
                self.n_outputs_ == 1,
                f"Number of outputs ({self.n_outputs_}) is not 1.",
            ),
            (
                y.dtype in supported_target_dtypes,
                f"Datatype ({y.dtype}) for y is not supported.",
            ),
        ]
    )
    # TODO: Fix to support integers as input

    # Called only for the checks it performs on max_samples; the returned
    # value is discarded.
    _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

    if not self.bootstrap and self.max_samples is not None:
        raise ValueError(
            "`max_sample` cannot be set if `bootstrap=False`. "
            "Either switch to `bootstrap=True` or set "
            "`max_sample=None`."
        )

    seed_ignored = (
        patching_status.get_status()
        and (self.random_state is not None)
        and (not daal_check_version((2024, "P", 0)))
    )
    if seed_ignored:
        warnings.warn(
            "Setting 'random_state' value is not supported. "
            "State set by oneDAL to default value (777).",
            RuntimeWarning,
        )

    return patching_status, X, y, sample_weight
|
|
592
|
+
|
|
593
|
+
@wrap_output_data
def predict(self, X):
    # The docstring is attached later in the class body from the scikit-learn
    # implementation, so only comments are used here.
    # Both implementations are offered to `dispatch`, whose result is returned.
    backends = {
        "onedal": self.__class__._onedal_predict,
        "sklearn": sklearn_ForestClassifier.predict,
    }
    return dispatch(self, "predict", backends, X)
|
|
604
|
+
|
|
605
|
+
@wrap_output_data
def predict_proba(self, X):
    # The docstring is attached later in the class body from the scikit-learn
    # implementation, so only comments are used here.
    # TODO:
    # _check_proba()
    # self._check_proba()
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)

    # Once the estimator knows its input width, reject mismatching input
    # before any backend is chosen.
    if hasattr(self, "n_features_in_"):
        try:
            seen_features = _num_features(X)
        except TypeError:
            # `_num_features` refused this input; use the sample count instead.
            seen_features = _num_samples(X)
        if seen_features != self.n_features_in_:
            mismatch_message = (
                f"X has {seen_features} features, "
                f"but {self.__class__.__name__} is expecting "
                f"{self.n_features_in_} features as input"
            )
            raise ValueError(mismatch_message)

    backends = {
        "onedal": self.__class__._onedal_predict_proba,
        "sklearn": sklearn_ForestClassifier.predict_proba,
    }
    return dispatch(self, "predict_proba", backends, X)
|
|
634
|
+
|
|
635
|
+
def predict_log_proba(self, X):
    # The docstring is attached later in the class body from the scikit-learn
    # implementation, so only comments are used here.
    # `xp` is the array namespace reported by `get_namespace` for X; its
    # `log` is applied to the output of `predict_proba`.
    xp, _ = get_namespace(X)
    proba = self.predict_proba(X)

    if self.n_outputs_ != 1:
        # Several outputs: replace each entry in place, one output at a time.
        for output_idx in range(self.n_outputs_):
            proba[output_idx] = xp.log(proba[output_idx])
        return proba

    return xp.log(proba)
|
|
647
|
+
|
|
648
|
+
@wrap_output_data
def score(self, X, y, sample_weight=None):
    # The docstring is attached later in the class body from the scikit-learn
    # implementation, so only comments are used here.
    # `sample_weight` is forwarded to `dispatch` by keyword.
    backends = {
        "onedal": self.__class__._onedal_score,
        "sklearn": sklearn_ForestClassifier.score,
    }
    return dispatch(self, "score", backends, X, y, sample_weight=sample_weight)
|
|
661
|
+
|
|
662
|
+
# Copy the scikit-learn docstrings onto the public methods of this class so
# that help() shows the upstream documentation.
score.__doc__ = sklearn_ForestClassifier.score.__doc__
predict_log_proba.__doc__ = sklearn_ForestClassifier.predict_log_proba.__doc__
predict_proba.__doc__ = sklearn_ForestClassifier.predict_proba.__doc__
predict.__doc__ = sklearn_ForestClassifier.predict.__doc__
fit.__doc__ = sklearn_ForestClassifier.fit.__doc__
|
|
667
|
+
|
|
668
|
+
def _onedal_cpu_supported(self, method_name, *data):
    """Report whether ``method_name`` can run through oneDAL on CPU.

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"``, ``"predict_proba"`` or ``"score"``.
    *data
        Positional arguments of the dispatched call. For ``"fit"`` they are
        forwarded to ``_onedal_fit_ready``; for the other methods only
        ``data[0]`` (``X``) is inspected.

    Returns
    -------
    PatchingConditionsChain
        Chain holding every (condition, message) pair that was evaluated.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the names listed above.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{class_name}.{method_name}"
    )

    # The version gate below concerns ExtraTrees only; an estimator whose
    # class is exactly DecisionTreeClassifier bypasses it.
    is_decision_tree = self.estimator.__class__ == DecisionTreeClassifier

    if method_name == "fit":
        # Only the chain and sample_weight are needed here; the X and y
        # returned by `_onedal_fit_ready` are not used by this method.
        patching_status, _, _, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        patching_status.and_conditions(
            [
                (
                    daal_check_version((2023, "P", 200)) or is_decision_tree,
                    "ExtraTrees only supported starting from oneDAL version 2023.2",
                ),
                (
                    not sp.issparse(sample_weight),
                    "sample_weight is sparse. " "Sparse input is not supported.",
                ),
            ]
        )

    elif method_name in ["predict", "predict_proba", "score"]:
        X = data[0]

        patching_status.and_conditions(
            [
                (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    # Threshold aligned with the message and with the fit
                    # branch above (2023.2); the previous value checked for
                    # 2023.1 while reporting 2023.2.
                    daal_check_version((2023, "P", 200)) or is_decision_tree,
                    "ExtraTrees only supported starting from oneDAL version 2023.2",
                ),
            ]
        )

        if method_name == "predict_proba":
            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2021, "P", 400)),
                        "oneDAL version is lower than 2021.4.",
                    )
                ]
            )

        # n_outputs_ only exists after a fit attempt, hence the guard.
        if hasattr(self, "n_outputs_"):
            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                ]
            )

    else:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    return patching_status
|
|
735
|
+
|
|
736
|
+
def _onedal_gpu_supported(self, method_name, *data):
    """Report whether ``method_name`` can run through oneDAL on GPU.

    Parameters
    ----------
    method_name : str
        One of ``"fit"``, ``"predict"``, ``"predict_proba"`` or ``"score"``.
    *data
        Positional arguments of the dispatched call. For ``"fit"`` they are
        forwarded to ``_onedal_fit_ready``; for the other methods only
        ``data[0]`` (``X``) is inspected.

    Returns
    -------
    PatchingConditionsChain
        Chain holding every (condition, message) pair that was evaluated.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the names listed above.
    """
    class_name = self.__class__.__name__
    patching_status = PatchingConditionsChain(
        f"sklearn.ensemble.{class_name}.{method_name}"
    )

    # The version gate below concerns ExtraTrees only; an estimator whose
    # class is exactly DecisionTreeClassifier bypasses it.
    is_decision_tree = self.estimator.__class__ == DecisionTreeClassifier

    if method_name == "fit":
        # Only the chain and sample_weight are needed here; the X and y
        # returned by `_onedal_fit_ready` are not used by this method.
        patching_status, _, _, sample_weight = self._onedal_fit_ready(
            patching_status, *data
        )

        patching_status.and_conditions(
            [
                (
                    daal_check_version((2023, "P", 100)) or is_decision_tree,
                    "ExtraTrees only supported starting from oneDAL version 2023.1",
                ),
                (
                    not self.oob_score,
                    "oob_scores using r2 or accuracy not implemented.",
                ),
                (sample_weight is None, "sample_weight is not supported."),
            ]
        )

    elif method_name in ["predict", "predict_proba", "score"]:
        X = data[0]

        patching_status.and_conditions(
            [
                (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained"),
                (
                    not sp.issparse(X),
                    "X is sparse. Sparse input is not supported.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    # Previously this gate had no decision-tree escape, so the
                    # "ExtraTrees" version requirement was also applied to
                    # random forests; now consistent with the fit branch and
                    # with the CPU check.
                    daal_check_version((2023, "P", 100)) or is_decision_tree,
                    "ExtraTrees supported starting from oneDAL version 2023.1",
                ),
            ]
        )

        # n_outputs_ only exists after a fit attempt, hence the guard.
        if hasattr(self, "n_outputs_"):
            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                ]
            )

    else:
        raise RuntimeError(
            f"Unknown method {method_name} in {self.__class__.__name__}"
        )

    return patching_status
|
|
795
|
+
|
|
796
|
+
def _onedal_predict(self, X, queue=None):
    """Predict class labels with the fitted oneDAL estimator.

    ``X`` is validated as floating-point input (non-finite values are let
    through), the oneDAL estimator's ``predict`` is called with ``queue``,
    and its flattened result is cast to int64 and used as indices into
    ``self.classes_``.
    """
    check_is_fitted(self, "_onedal_estimator")

    if not sklearn_check_version("1.0"):
        # Older scikit-learn: validate the array, then check the feature
        # count separately.
        X = check_array(
            X,
            dtype=[np.float64, np.float32],
            force_all_finite=False,
        )  # Warning, order of dtype matters
        self._check_n_features(X, reset=False)
    else:
        X = self._validate_data(
            X,
            dtype=[np.float64, np.float32],
            force_all_finite=False,
            reset=False,
            ensure_2d=True,
        )

    raw_prediction = self._onedal_estimator.predict(X, queue=queue)
    known_classes = self.classes_
    class_indices = raw_prediction.ravel().astype(np.int64, casting="unsafe")
    return np.take(known_classes, class_indices)
|
|
817
|
+
|
|
818
|
+
def _onedal_predict_proba(self, X, queue=None):
    """Return the oneDAL estimator's ``predict_proba`` output for ``X``.

    ``X`` is validated as floating-point input (non-finite values are let
    through) before being handed to the oneDAL estimator together with
    ``queue``.
    """
    check_is_fitted(self, "_onedal_estimator")

    if not sklearn_check_version("1.0"):
        # Older scikit-learn: validate the array, then check the feature
        # count separately.
        X = check_array(
            X,
            dtype=[np.float64, np.float32],
            force_all_finite=False,
        )  # Warning, order of dtype matters
        self._check_n_features(X, reset=False)
    else:
        X = self._validate_data(
            X,
            dtype=[np.float64, np.float32],
            force_all_finite=False,
            reset=False,
            ensure_2d=True,
        )

    onedal_model = self._onedal_estimator
    return onedal_model.predict_proba(X, queue=queue)
|
|
838
|
+
|
|
839
|
+
def _onedal_score(self, X, y, sample_weight=None, queue=None):
    """Return ``accuracy_score`` of the oneDAL predictions for ``X`` against ``y``.

    ``sample_weight`` is forwarded to ``accuracy_score``; ``queue`` is
    forwarded to ``_onedal_predict``.
    """
    predicted_labels = self._onedal_predict(X, queue=queue)
    return accuracy_score(y, predicted_labels, sample_weight=sample_weight)
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
class ForestRegressor(sklearn_ForestRegressor, BaseForest):
    # Regression forest that offers a oneDAL-backed and a stock scikit-learn
    # implementation of fit / predict / score to `dispatch`.
    # (Comments are used instead of a class docstring so that `__doc__`
    # stays exactly as it was.)

    # Two out-of-bag result names joined by "|".
    # NOTE(review): presumably consumed by BaseForest - confirm.
    _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
    _get_tree_state = staticmethod(get_tree_state_reg)

    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
    ):
        """Initialise the scikit-learn forest base and pick the oneDAL factory.

        Raises
        ------
        TypeError
            If ``self._onedal_factory`` is ``None`` after reconciliation
            with the estimator class.
        """
        super().__init__(
            estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            max_samples=max_samples,
        )

        # The splitter is checked against the class attribute for conformance
        # This should only trigger if the user uses this class directly.
        if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
            self._onedal_factory, onedal_RandomForestRegressor
        ):
            self._onedal_factory = onedal_RandomForestRegressor
        elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
            self._onedal_factory, onedal_ExtraTreesRegressor
        ):
            self._onedal_factory = onedal_ExtraTreesRegressor

        if self._onedal_factory is None:
            raise TypeError(" oneDAL estimator has not been set.")

    def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
        """Validate fit arguments and record whether oneDAL can run this fit.

        Returns
        -------
        tuple
            ``(patching_status, X, y, sample_weight)``. ``X`` and ``y`` are
            the outputs of ``check_X_y`` (``y`` reshaped to two dimensions)
            when every condition recorded before input validation holds;
            otherwise they are the objects that were passed in.

        Raises
        ------
        ValueError
            If ``y`` is sparse, if ``oob_score`` is requested without
            ``bootstrap``, or (once input validation ran) if ``max_samples``
            is set without ``bootstrap``.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        # The parameter-validation entry point depends on the sklearn release.
        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        if sklearn_check_version("1.0") and self.criterion == "mse":
            warnings.warn(
                "Criterion 'mse' was deprecated in v1.0 and will be "
                "removed in version 1.2. Use `criterion='squared_error'` "
                "which is equivalent.",
                FutureWarning,
            )

        patching_status.and_conditions(
            [
                (
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion in ["mse", "squared_error"],
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'mse' and 'squared_error' criteria are supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    self.n_estimators <= 6024,
                    "More than 6024 estimators is not supported.",
                ),
            ]
        )

        if patching_status.get_status() and sklearn_check_version("1.4"):
            # A failed finiteness check becomes a patching condition instead
            # of propagating.
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            patching_status.and_conditions(
                [
                    (input_is_finite, "Non-finite input is not supported."),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if patching_status.get_status():
            X, y = check_X_y(
                X,
                y,
                multi_output=True,
                accept_sparse=True,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
            )

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            if y.ndim == 1:
                # reshape is necessary to preserve the data contiguity against vs
                # [:, np.newaxis] that does not.
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]

            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

            # Sklearn function used for doing checks on max_samples attribute
            _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

            if not self.bootstrap and self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )

            if (
                patching_status.get_status()
                and (self.random_state is not None)
                and (not daal_check_version((2024, "P", 0)))
            ):
                warnings.warn(
                    "Setting 'random_state' value is not supported. "
                    "State set by oneDAL to default value (777).",
                    RuntimeWarning,
                )

        return patching_status, X, y, sample_weight

    def _onedal_cpu_supported(self, method_name, *data):
        """Report whether ``method_name`` can run through oneDAL on CPU.

        ``method_name`` must be ``"fit"``, ``"predict"`` or ``"score"``;
        anything else raises ``RuntimeError``. Returns the
        ``PatchingConditionsChain`` holding the evaluated conditions.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        # The version gate concerns ExtraTrees only. The estimator of a
        # regression forest is a regressor tree, so the comparison must be
        # against DecisionTreeRegressor; the classifier class used before
        # could never match here.
        is_decision_tree = self.estimator.__class__ == DecisionTreeRegressor

        if method_name == "fit":
            # X and y returned by `_onedal_fit_ready` are not used here.
            patching_status, _, _, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2023, "P", 200)) or is_decision_tree,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                    (
                        not sp.issparse(sample_weight),
                        "sample_weight is sparse. " "Sparse input is not supported.",
                    ),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 200)) or is_decision_tree,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                ]
            )
            # n_outputs_ only exists after a fit attempt, hence the guard.
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_gpu_supported(self, method_name, *data):
        """Report whether ``method_name`` can run through oneDAL on GPU.

        ``method_name`` must be ``"fit"``, ``"predict"`` or ``"score"``;
        anything else raises ``RuntimeError``. Returns the
        ``PatchingConditionsChain`` holding the evaluated conditions.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        # See `_onedal_cpu_supported`-style gate: compare against the
        # regressor tree class, which is what this forest's estimator is.
        is_decision_tree = self.estimator.__class__ == DecisionTreeRegressor

        if method_name == "fit":
            # X and y returned by `_onedal_fit_ready` are not used here.
            patching_status, _, _, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2023, "P", 100)) or is_decision_tree,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                    (not self.oob_score, "oob_score value is not sklearn conformant."),
                    (sample_weight is None, "sample_weight is not supported."),
                ]
            )

        elif method_name in ["predict", "score"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        daal_check_version((2023, "P", 100)) or is_decision_tree,
                        "ExtraTrees only supported starting from oneDAL version 2023.1",
                    ),
                ]
            )
            # n_outputs_ only exists after a fit attempt, hence the guard.
            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status

    def _onedal_predict(self, X, queue=None):
        """Validate ``X`` and return the oneDAL estimator's prediction for it."""
        check_is_fitted(self, "_onedal_estimator")

        if sklearn_check_version("1.0"):
            X = self._validate_data(
                X,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
                reset=False,
                ensure_2d=True,
            )  # Warning, order of dtype matters
        else:
            X = check_array(
                X, dtype=[np.float64, np.float32], force_all_finite=False
            )  # Warning, order of dtype matters

        return self._onedal_estimator.predict(X, queue=queue)

    def _onedal_score(self, X, y, sample_weight=None, queue=None):
        """Return ``r2_score`` of the oneDAL predictions for ``X`` against ``y``."""
        return r2_score(
            y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
        )

    def fit(self, X, y, sample_weight=None):
        # Docstring is copied from scikit-learn at the end of the class body.
        # The result of `dispatch` is ignored; the estimator itself is returned.
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_ForestRegressor.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):
        # Docstring is copied from scikit-learn at the end of the class body.
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_ForestRegressor.predict,
            },
            X,
        )

    @wrap_output_data
    def score(self, X, y, sample_weight=None):
        # Docstring is copied from scikit-learn at the end of the class body.
        return dispatch(
            self,
            "score",
            {
                "onedal": self.__class__._onedal_score,
                "sklearn": sklearn_ForestRegressor.score,
            },
            X,
            y,
            sample_weight=sample_weight,
        )

    # Reuse the scikit-learn docstrings for the public methods.
    fit.__doc__ = sklearn_ForestRegressor.fit.__doc__
    predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
    score.__doc__ = sklearn_ForestRegressor.score.__doc__
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class RandomForestClassifier(ForestClassifier):
    # The class docstring is taken verbatim from scikit-learn.
    __doc__ = sklearn_RandomForestClassifier.__doc__
    _onedal_factory = onedal_RandomForestClassifier

    if sklearn_check_version("1.2"):
        # Extend the scikit-learn constraints with the two extra parameters
        # this class accepts: max_bins (integer >= 2) and min_bin_size
        # (integer >= 1).
        _parameter_constraints: dict = {
            **sklearn_RandomForestClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # One explicit constructor per supported scikit-learn signature; only the
    # definition matching the installed version is created.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Parameters not handled by the parent constructor are stored
            # as given.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            # max_bins / min_bin_size were assigned twice here; once suffices.
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
|
|
1398
|
+
|
|
1399
|
+
|
|
1400
|
+
@control_n_jobs(decorated_methods=["fit", "predict"])
class RandomForestRegressor(ForestRegressor):
    # Reuse the stock scikit-learn documentation; the constructor below adds
    # only the two binning parameters ``max_bins`` and ``min_bin_size``.
    __doc__ = sklearn_RandomForestRegressor.__doc__
    _onedal_factory = onedal_RandomForestRegressor

    if sklearn_check_version("1.2"):
        # Extend scikit-learn's declarative parameter validation with the two
        # extra parameters: max_bins is an integer >= 2 and min_bin_size is an
        # integer >= 1 (no upper bound on either).
        _parameter_constraints: dict = {
            **sklearn_RandomForestRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # The constructor signature depends on the installed scikit-learn version,
    # so exactly one of the three ``__init__`` definitions below is bound.
    if sklearn_check_version("1.4"):

        # scikit-learn >= 1.4: signature additionally carries ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        # scikit-learn 1.0 up to (not including) 1.4: no ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # The default is chosen once, at class-definition time.
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        # scikit-learn < 1.0: legacy "mse" criterion and ``min_impurity_split``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=True,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                DecisionTreeRegressor(),
                n_estimators=n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # Each name must be its own element: adjacent string
                    # literals without a comma are silently concatenated.
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
|
|
1599
|
+
|
|
1600
|
+
|
|
1601
|
+
@control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
class ExtraTreesClassifier(ForestClassifier):
    # Reuse the stock scikit-learn documentation; the constructor below adds
    # only the two binning parameters ``max_bins`` and ``min_bin_size``.
    __doc__ = sklearn_ExtraTreesClassifier.__doc__
    _onedal_factory = onedal_ExtraTreesClassifier

    if sklearn_check_version("1.2"):
        # Extend scikit-learn's declarative parameter validation with the two
        # extra parameters: max_bins is an integer >= 2 and min_bin_size is an
        # integer >= 1 (no upper bound on either).
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # The constructor signature depends on the installed scikit-learn version,
    # so exactly one of the three ``__init__`` definitions below is bound.
    if sklearn_check_version("1.4"):

        # scikit-learn >= 1.4: signature additionally carries ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        # scikit-learn 1.0 up to (not including) 1.4: no ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # The default is chosen once, at class-definition time.
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        # scikit-learn < 1.0: signature still carries ``min_impurity_split``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified; each
            # attribute is assigned exactly once.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
|
|
1809
|
+
|
|
1810
|
+
|
|
1811
|
+
@control_n_jobs(decorated_methods=["fit", "predict"])
class ExtraTreesRegressor(ForestRegressor):
    # Reuse the stock scikit-learn documentation; the constructor below adds
    # only the two binning parameters ``max_bins`` and ``min_bin_size``.
    __doc__ = sklearn_ExtraTreesRegressor.__doc__
    _onedal_factory = onedal_ExtraTreesRegressor

    if sklearn_check_version("1.2"):
        # Extend scikit-learn's declarative parameter validation with the two
        # extra parameters: max_bins is an integer >= 2 and min_bin_size is an
        # integer >= 1 (no upper bound on either).
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # The constructor signature depends on the installed scikit-learn version,
    # so exactly one of the three ``__init__`` definitions below is bound.
    if sklearn_check_version("1.4"):

        # scikit-learn >= 1.4: signature additionally carries ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        # scikit-learn 1.0 up to (not including) 1.4: no ``monotonic_cst``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # The default is chosen once, at class-definition time.
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:

        # scikit-learn < 1.0: legacy "mse" criterion and ``min_impurity_split``.
        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                # Attribute names handed to the base estimator.
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # Each name must be its own element: adjacent string
                    # literals without a comma are silently concatenated.
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store every remaining constructor argument unmodified.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
|
|
2010
|
+
|
|
2011
|
+
|
|
2012
|
+
# Register every estimator defined above with its scikit-learn counterpart via
# ABCMeta, so isinstance calls succeed without any change to inheritance.
for _sklearn_cls, _local_cls in (
    (sklearn_RandomForestClassifier, RandomForestClassifier),
    (sklearn_RandomForestRegressor, RandomForestRegressor),
    (sklearn_ExtraTreesClassifier, ExtraTreesClassifier),
    (sklearn_ExtraTreesRegressor, ExtraTreesRegressor),
):
    _sklearn_cls.register(_local_cls)
# Drop the loop temporaries so the module namespace is left unchanged.
del _sklearn_cls, _local_cls
|