scikit-learn-intelex 2025.1.0__py39-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- daal4py/__init__.py +73 -0
- daal4py/__main__.py +58 -0
- daal4py/_daal4py.cpython-39-x86_64-linux-gnu.so +0 -0
- daal4py/doc/third-party-programs.txt +424 -0
- daal4py/mb/__init__.py +19 -0
- daal4py/mb/model_builders.py +377 -0
- daal4py/mpi_transceiver.cpython-39-x86_64-linux-gnu.so +0 -0
- daal4py/sklearn/__init__.py +40 -0
- daal4py/sklearn/_n_jobs_support.py +248 -0
- daal4py/sklearn/_utils.py +245 -0
- daal4py/sklearn/cluster/__init__.py +20 -0
- daal4py/sklearn/cluster/dbscan.py +165 -0
- daal4py/sklearn/cluster/k_means.py +597 -0
- daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- daal4py/sklearn/decomposition/__init__.py +19 -0
- daal4py/sklearn/decomposition/_pca.py +524 -0
- daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
- daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
- daal4py/sklearn/ensemble/__init__.py +27 -0
- daal4py/sklearn/ensemble/_forest.py +1397 -0
- daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- daal4py/sklearn/linear_model/__init__.py +29 -0
- daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- daal4py/sklearn/linear_model/_linear.py +272 -0
- daal4py/sklearn/linear_model/_ridge.py +325 -0
- daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- daal4py/sklearn/linear_model/linear.py +17 -0
- daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- daal4py/sklearn/linear_model/ridge.py +17 -0
- daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
- daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- daal4py/sklearn/manifold/__init__.py +19 -0
- daal4py/sklearn/manifold/_t_sne.py +405 -0
- daal4py/sklearn/metrics/__init__.py +20 -0
- daal4py/sklearn/metrics/_pairwise.py +236 -0
- daal4py/sklearn/metrics/_ranking.py +210 -0
- daal4py/sklearn/model_selection/__init__.py +19 -0
- daal4py/sklearn/model_selection/_split.py +309 -0
- daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- daal4py/sklearn/monkeypatch/__init__.py +0 -0
- daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- daal4py/sklearn/neighbors/__init__.py +21 -0
- daal4py/sklearn/neighbors/_base.py +503 -0
- daal4py/sklearn/neighbors/_classification.py +139 -0
- daal4py/sklearn/neighbors/_regression.py +74 -0
- daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- daal4py/sklearn/svm/__init__.py +19 -0
- daal4py/sklearn/svm/svm.py +734 -0
- daal4py/sklearn/utils/__init__.py +21 -0
- daal4py/sklearn/utils/base.py +75 -0
- daal4py/sklearn/utils/tests/test_utils.py +51 -0
- daal4py/sklearn/utils/validation.py +693 -0
- onedal/__init__.py +83 -0
- onedal/_config.py +54 -0
- onedal/_device_offload.py +222 -0
- onedal/_onedal_py_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_host.cpython-39-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_spmd_dpc.cpython-39-x86_64-linux-gnu.so +0 -0
- onedal/basic_statistics/__init__.py +20 -0
- onedal/basic_statistics/basic_statistics.py +107 -0
- onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- onedal/cluster/__init__.py +27 -0
- onedal/cluster/dbscan.py +110 -0
- onedal/cluster/kmeans.py +564 -0
- onedal/cluster/kmeans_init.py +115 -0
- onedal/cluster/tests/test_dbscan.py +125 -0
- onedal/cluster/tests/test_kmeans.py +88 -0
- onedal/cluster/tests/test_kmeans_init.py +93 -0
- onedal/common/_base.py +38 -0
- onedal/common/_estimator_checks.py +47 -0
- onedal/common/_mixin.py +62 -0
- onedal/common/_policy.py +59 -0
- onedal/common/_spmd_policy.py +30 -0
- onedal/common/hyperparameters.py +125 -0
- onedal/common/tests/test_policy.py +76 -0
- onedal/covariance/__init__.py +20 -0
- onedal/covariance/covariance.py +125 -0
- onedal/covariance/incremental_covariance.py +146 -0
- onedal/covariance/tests/test_covariance.py +50 -0
- onedal/covariance/tests/test_incremental_covariance.py +122 -0
- onedal/datatypes/__init__.py +19 -0
- onedal/datatypes/_data_conversion.py +154 -0
- onedal/datatypes/tests/common.py +126 -0
- onedal/datatypes/tests/test_data.py +414 -0
- onedal/decomposition/__init__.py +20 -0
- onedal/decomposition/incremental_pca.py +204 -0
- onedal/decomposition/pca.py +186 -0
- onedal/decomposition/tests/test_incremental_pca.py +198 -0
- onedal/ensemble/__init__.py +29 -0
- onedal/ensemble/forest.py +727 -0
- onedal/ensemble/tests/test_random_forest.py +97 -0
- onedal/linear_model/__init__.py +27 -0
- onedal/linear_model/incremental_linear_model.py +258 -0
- onedal/linear_model/linear_model.py +329 -0
- onedal/linear_model/logistic_regression.py +249 -0
- onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- onedal/linear_model/tests/test_linear_regression.py +250 -0
- onedal/linear_model/tests/test_logistic_regression.py +95 -0
- onedal/linear_model/tests/test_ridge.py +95 -0
- onedal/neighbors/__init__.py +19 -0
- onedal/neighbors/neighbors.py +767 -0
- onedal/neighbors/tests/test_knn_classification.py +49 -0
- onedal/primitives/__init__.py +27 -0
- onedal/primitives/get_tree.py +25 -0
- onedal/primitives/kernel_functions.py +153 -0
- onedal/primitives/tests/test_kernel_functions.py +159 -0
- onedal/spmd/__init__.py +25 -0
- onedal/spmd/_base.py +30 -0
- onedal/spmd/basic_statistics/__init__.py +20 -0
- onedal/spmd/basic_statistics/basic_statistics.py +30 -0
- onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
- onedal/spmd/cluster/__init__.py +28 -0
- onedal/spmd/cluster/dbscan.py +23 -0
- onedal/spmd/cluster/kmeans.py +56 -0
- onedal/spmd/covariance/__init__.py +20 -0
- onedal/spmd/covariance/covariance.py +26 -0
- onedal/spmd/covariance/incremental_covariance.py +82 -0
- onedal/spmd/decomposition/__init__.py +20 -0
- onedal/spmd/decomposition/incremental_pca.py +117 -0
- onedal/spmd/decomposition/pca.py +26 -0
- onedal/spmd/ensemble/__init__.py +19 -0
- onedal/spmd/ensemble/forest.py +28 -0
- onedal/spmd/linear_model/__init__.py +21 -0
- onedal/spmd/linear_model/incremental_linear_model.py +97 -0
- onedal/spmd/linear_model/linear_model.py +30 -0
- onedal/spmd/linear_model/logistic_regression.py +38 -0
- onedal/spmd/neighbors/__init__.py +19 -0
- onedal/spmd/neighbors/neighbors.py +75 -0
- onedal/svm/__init__.py +19 -0
- onedal/svm/svm.py +556 -0
- onedal/svm/tests/test_csr_svm.py +351 -0
- onedal/svm/tests/test_nusvc.py +204 -0
- onedal/svm/tests/test_nusvr.py +210 -0
- onedal/svm/tests/test_svc.py +176 -0
- onedal/svm/tests/test_svr.py +243 -0
- onedal/tests/test_common.py +57 -0
- onedal/tests/utils/_dataframes_support.py +162 -0
- onedal/tests/utils/_device_selection.py +102 -0
- onedal/utils/__init__.py +49 -0
- onedal/utils/_array_api.py +81 -0
- onedal/utils/_dpep_helpers.py +56 -0
- onedal/utils/validation.py +440 -0
- scikit_learn_intelex-2025.1.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.1.0.dist-info/METADATA +231 -0
- scikit_learn_intelex-2025.1.0.dist-info/RECORD +280 -0
- scikit_learn_intelex-2025.1.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.1.0.dist-info/top_level.txt +3 -0
- sklearnex/__init__.py +66 -0
- sklearnex/__main__.py +58 -0
- sklearnex/_config.py +116 -0
- sklearnex/_device_offload.py +126 -0
- sklearnex/_utils.py +132 -0
- sklearnex/basic_statistics/__init__.py +20 -0
- sklearnex/basic_statistics/basic_statistics.py +230 -0
- sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
- sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
- sklearnex/cluster/__init__.py +20 -0
- sklearnex/cluster/dbscan.py +197 -0
- sklearnex/cluster/k_means.py +395 -0
- sklearnex/cluster/tests/test_dbscan.py +38 -0
- sklearnex/cluster/tests/test_kmeans.py +159 -0
- sklearnex/conftest.py +82 -0
- sklearnex/covariance/__init__.py +19 -0
- sklearnex/covariance/incremental_covariance.py +398 -0
- sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
- sklearnex/decomposition/__init__.py +19 -0
- sklearnex/decomposition/pca.py +425 -0
- sklearnex/decomposition/tests/test_pca.py +58 -0
- sklearnex/dispatcher.py +543 -0
- sklearnex/doc/third-party-programs.txt +424 -0
- sklearnex/ensemble/__init__.py +29 -0
- sklearnex/ensemble/_forest.py +2029 -0
- sklearnex/ensemble/tests/test_forest.py +135 -0
- sklearnex/glob/__main__.py +72 -0
- sklearnex/glob/dispatcher.py +101 -0
- sklearnex/linear_model/__init__.py +32 -0
- sklearnex/linear_model/coordinate_descent.py +30 -0
- sklearnex/linear_model/incremental_linear.py +482 -0
- sklearnex/linear_model/incremental_ridge.py +425 -0
- sklearnex/linear_model/linear.py +341 -0
- sklearnex/linear_model/logistic_regression.py +413 -0
- sklearnex/linear_model/ridge.py +24 -0
- sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
- sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- sklearnex/linear_model/tests/test_linear.py +167 -0
- sklearnex/linear_model/tests/test_logreg.py +134 -0
- sklearnex/manifold/__init__.py +19 -0
- sklearnex/manifold/t_sne.py +21 -0
- sklearnex/manifold/tests/test_tsne.py +26 -0
- sklearnex/metrics/__init__.py +23 -0
- sklearnex/metrics/pairwise.py +22 -0
- sklearnex/metrics/ranking.py +20 -0
- sklearnex/metrics/tests/test_metrics.py +39 -0
- sklearnex/model_selection/__init__.py +21 -0
- sklearnex/model_selection/split.py +22 -0
- sklearnex/model_selection/tests/test_model_selection.py +34 -0
- sklearnex/neighbors/__init__.py +27 -0
- sklearnex/neighbors/_lof.py +236 -0
- sklearnex/neighbors/common.py +310 -0
- sklearnex/neighbors/knn_classification.py +231 -0
- sklearnex/neighbors/knn_regression.py +207 -0
- sklearnex/neighbors/knn_unsupervised.py +178 -0
- sklearnex/neighbors/tests/test_neighbors.py +82 -0
- sklearnex/preview/__init__.py +17 -0
- sklearnex/preview/covariance/__init__.py +19 -0
- sklearnex/preview/covariance/covariance.py +138 -0
- sklearnex/preview/covariance/tests/test_covariance.py +66 -0
- sklearnex/preview/decomposition/__init__.py +19 -0
- sklearnex/preview/decomposition/incremental_pca.py +233 -0
- sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- sklearnex/preview/linear_model/__init__.py +19 -0
- sklearnex/preview/linear_model/ridge.py +424 -0
- sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- sklearnex/spmd/__init__.py +25 -0
- sklearnex/spmd/basic_statistics/__init__.py +20 -0
- sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
- sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- sklearnex/spmd/cluster/__init__.py +30 -0
- sklearnex/spmd/cluster/dbscan.py +50 -0
- sklearnex/spmd/cluster/kmeans.py +21 -0
- sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- sklearnex/spmd/covariance/__init__.py +20 -0
- sklearnex/spmd/covariance/covariance.py +21 -0
- sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- sklearnex/spmd/decomposition/__init__.py +20 -0
- sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- sklearnex/spmd/decomposition/pca.py +21 -0
- sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- sklearnex/spmd/ensemble/__init__.py +19 -0
- sklearnex/spmd/ensemble/forest.py +71 -0
- sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- sklearnex/spmd/linear_model/__init__.py +21 -0
- sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- sklearnex/spmd/linear_model/linear_model.py +21 -0
- sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
- sklearnex/spmd/neighbors/__init__.py +19 -0
- sklearnex/spmd/neighbors/neighbors.py +25 -0
- sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- sklearnex/svm/__init__.py +29 -0
- sklearnex/svm/_common.py +339 -0
- sklearnex/svm/nusvc.py +371 -0
- sklearnex/svm/nusvr.py +170 -0
- sklearnex/svm/svc.py +399 -0
- sklearnex/svm/svr.py +167 -0
- sklearnex/svm/tests/test_svm.py +93 -0
- sklearnex/tests/test_common.py +390 -0
- sklearnex/tests/test_config.py +123 -0
- sklearnex/tests/test_memory_usage.py +379 -0
- sklearnex/tests/test_monkeypatch.py +276 -0
- sklearnex/tests/test_n_jobs_support.py +108 -0
- sklearnex/tests/test_parallel.py +48 -0
- sklearnex/tests/test_patching.py +385 -0
- sklearnex/tests/test_run_to_run_stability.py +321 -0
- sklearnex/tests/utils/__init__.py +44 -0
- sklearnex/tests/utils/base.py +371 -0
- sklearnex/tests/utils/spmd.py +198 -0
- sklearnex/utils/__init__.py +19 -0
- sklearnex/utils/_array_api.py +82 -0
- sklearnex/utils/parallel.py +59 -0
- sklearnex/utils/tests/test_finite.py +89 -0
- sklearnex/utils/validation.py +17 -0
|
@@ -0,0 +1,727 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numbers
|
|
18
|
+
import warnings
|
|
19
|
+
from abc import ABCMeta, abstractmethod
|
|
20
|
+
from math import ceil
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
from sklearn.ensemble import BaseEnsemble
|
|
24
|
+
from sklearn.utils import check_random_state
|
|
25
|
+
|
|
26
|
+
from daal4py.sklearn._utils import daal_check_version
|
|
27
|
+
from sklearnex import get_hyperparameters
|
|
28
|
+
|
|
29
|
+
from ..common._base import BaseEstimator
|
|
30
|
+
from ..common._estimator_checks import _check_is_fitted
|
|
31
|
+
from ..common._mixin import ClassifierMixin, RegressorMixin
|
|
32
|
+
from ..datatypes import _convert_to_supported, from_table, to_table
|
|
33
|
+
from ..utils import (
|
|
34
|
+
_check_array,
|
|
35
|
+
_check_n_features,
|
|
36
|
+
_check_X_y,
|
|
37
|
+
_column_or_1d,
|
|
38
|
+
_validate_targets,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class BaseForest(BaseEstimator, BaseEnsemble, metaclass=ABCMeta):
    """Common fit / inference plumbing shared by the forest estimators below.

    The class stores the constructor arguments, translates them into the
    parameter dictionary handed to the backend ``module`` (``train`` /
    ``infer``), and converts backend tables back into arrays.
    """

    @abstractmethod
    def __init__(
        self,
        n_estimators,
        criterion,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        min_weight_fraction_leaf,
        max_features,
        max_leaf_nodes,
        min_impurity_decrease,
        min_impurity_split,
        bootstrap,
        oob_score,
        random_state,
        warm_start,
        class_weight,
        ccp_alpha,
        max_samples,
        max_bins,
        min_bin_size,
        infer_mode,
        splitter_mode,
        voting_mode,
        error_metric_mode,
        variable_importance_mode,
        algorithm,
        **kwargs,
    ):
        """Store every named argument unchanged; extra ``kwargs`` are ignored."""
        self.n_estimators = n_estimators
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.random_state = random_state
        self.warm_start = warm_start
        self.class_weight = class_weight
        self.max_samples = max_samples
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split
        self.ccp_alpha = ccp_alpha
        self.max_bins = max_bins
        self.min_bin_size = min_bin_size
        self.infer_mode = infer_mode
        self.splitter_mode = splitter_mode
        self.voting_mode = voting_mode
        self.error_metric_mode = error_metric_mode
        self.variable_importance_mode = variable_importance_mode
        self.algorithm = algorithm

    def _to_absolute_max_features(self, n_features):
        """Translate ``self.max_features`` into an absolute feature count.

        * ``None``           -> ``n_features``
        * ``str``            -> the NumPy function of that name applied to
                                ``n_features``, truncated to int, at least 1
        * integer            -> returned as is
        * positive number    -> ``max_features * n_features`` truncated to int,
                                at least 1
        * anything else      -> 0
        """
        if self.max_features is None:
            return n_features
        elif isinstance(self.max_features, str):
            return max(1, int(getattr(np, self.max_features)(n_features)))
        elif isinstance(self.max_features, (numbers.Integral, np.integer)):
            return self.max_features
        elif self.max_features > 0.0:
            return max(1, int(self.max_features * n_features))
        return 0

    def _get_observations_per_tree_fraction(self, n_samples, max_samples):
        """Return the fraction of rows each tree is trained on.

        Parameters
        ----------
        n_samples : int
            Number of rows in the training data.
        max_samples : None, int or float
            ``None`` means every row (fraction ``1.0``); an integer is divided
            by ``n_samples``; a real number is used as the fraction directly.
            The result is never smaller than ``1 / n_samples``.

        Raises
        ------
        ValueError
            If an integer ``max_samples`` is outside ``[1, n_samples]``.
        TypeError
            If ``max_samples`` is neither ``None`` nor a number.
        """
        if max_samples is None:
            return 1.0

        if isinstance(max_samples, numbers.Integral):
            if not (1 <= max_samples <= n_samples):
                msg = "`max_samples` must be in range 1 to {} but got value {}"
                raise ValueError(msg.format(n_samples, max_samples))
            return max(float(max_samples / n_samples), 1 / n_samples)

        if isinstance(max_samples, numbers.Real):
            return max(float(max_samples), 1 / n_samples)

        msg = "`max_samples` should be int or float, but got type '{}'"
        raise TypeError(msg.format(type(max_samples)))

    def _get_onedal_params(self, data, *, is_training=True):
        """Build the parameter dictionary passed to the backend module.

        Parameters
        ----------
        data : array-like with a two-element ``shape``
            Data the call is about; supplies ``n_samples``, ``n_features`` and
            the floating point type.
        is_training : bool, default=True
            When ``False`` and a fit has already stored
            ``observations_per_tree_fraction``, that stored value is reused
            instead of re-deriving it from ``data``. ``max_samples`` describes
            the training set, so checking it against the row count of a
            prediction batch would reject small batches for no reason.

        Returns
        -------
        dict
            Backend parameters.

        Raises
        ------
        ValueError
            If ``max_samples`` or ``oob_score`` is requested with
            ``bootstrap=False``, or ``max_samples`` is out of range.
        TypeError
            If ``max_samples`` has an unsupported type.
        """
        n_samples, n_features = data.shape

        # Validate the bootstrap-dependent options before touching the instance,
        # so a rejected configuration leaves no half-updated state behind.
        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_sample` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_sample=None`."
            )
        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        if not self.bootstrap:
            fraction = 1.0
        elif is_training or not hasattr(self, "observations_per_tree_fraction"):
            fraction = self._get_observations_per_tree_fraction(
                n_samples=n_samples, max_samples=self.max_samples
            )
        else:
            # Inference: keep the value derived from the training set.
            fraction = self.observations_per_tree_fraction
        self.observations_per_tree_fraction = fraction

        # Fractional minimums are scaled by the number of rows and rounded up.
        min_observations_in_leaf_node = (
            self.min_samples_leaf
            if isinstance(self.min_samples_leaf, numbers.Integral)
            else int(ceil(self.min_samples_leaf * n_samples))
        )

        min_observations_in_split_node = (
            self.min_samples_split
            if isinstance(self.min_samples_split, numbers.Integral)
            else int(ceil(self.min_samples_split * n_samples))
        )

        # NOTE(review): a seed is drawn on every call, including inference, so a
        # shared RandomState instance advances on predict as well - confirm
        # whether that is intended.
        rs = check_random_state(self.random_state)
        seed = rs.randint(0, np.iinfo("i").max)

        onedal_params = {
            "fptype": "float" if data.dtype == np.float32 else "double",
            "method": self.algorithm,
            "infer_mode": self.infer_mode,
            "voting_mode": self.voting_mode,
            "observations_per_tree_fraction": self.observations_per_tree_fraction,
            "impurity_threshold": float(
                0.0 if self.min_impurity_split is None else self.min_impurity_split
            ),
            "min_weight_fraction_in_leaf_node": self.min_weight_fraction_leaf,
            "min_impurity_decrease_in_split_node": self.min_impurity_decrease,
            "tree_count": int(self.n_estimators),
            "features_per_node": self._to_absolute_max_features(n_features),
            # ``None`` is passed to the backend as 0 for both limits below.
            "max_tree_depth": int(0 if self.max_depth is None else self.max_depth),
            "min_observations_in_leaf_node": min_observations_in_leaf_node,
            "min_observations_in_split_node": min_observations_in_split_node,
            "max_leaf_nodes": (0 if self.max_leaf_nodes is None else self.max_leaf_nodes),
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "seed": seed,
            "memory_saving_mode": False,
            "bootstrap": bool(self.bootstrap),
            "error_metric_mode": self.error_metric_mode,
            "variable_importance_mode": self.variable_importance_mode,
        }
        if isinstance(self, ClassifierMixin):
            onedal_params["class_count"] = (
                0 if self.classes_ is None else len(self.classes_)
            )
        # ``splitter_mode`` is only forwarded when the version check passes.
        if daal_check_version((2023, "P", 101)):
            onedal_params["splitter_mode"] = self.splitter_mode
        return onedal_params

    def _check_parameters(self):
        """Validate the stored hyperparameters.

        Raises
        ------
        ValueError
            If any checked parameter is outside its accepted range or, for
            ``max_leaf_nodes`` / ``max_bins`` / ``min_bin_size``, not integral.

        Warns
        -----
        FutureWarning
            If ``min_impurity_split`` is set.
        """
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the integer %s" % self.min_samples_split
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the float %s" % self.min_samples_split
                )
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
        if self.min_impurity_split is not None:
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            # Only meaningful once we know the value is not None.
            if self.min_impurity_split < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than " "or equal to 0"
                )
        if self.min_impurity_decrease < 0.0:
            raise ValueError(
                "min_impurity_decrease must be greater than " "or equal to 0"
            )
        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" % self.max_leaf_nodes
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
                        self.max_leaf_nodes
                    )
                )
        if isinstance(self.max_bins, numbers.Integral):
            if not 2 <= self.max_bins:
                raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
        else:
            raise ValueError(
                "max_bins must be integral number but was " "%r" % self.max_bins
            )
        if isinstance(self.min_bin_size, numbers.Integral):
            if not 1 <= self.min_bin_size:
                raise ValueError(
                    "min_bin_size must be at least 1, got %s" % self.min_bin_size
                )
        else:
            raise ValueError(
                "min_bin_size must be integral number but was " "%r" % self.min_bin_size
            )

    def _validate_targets(self, y, dtype):
        """Flatten ``y`` to 1-D and cast it to ``dtype``.

        Resets ``class_weight_`` and ``classes_`` to ``None``.
        """
        self.class_weight_ = None
        self.classes_ = None
        return _column_or_1d(y, warn=True).astype(dtype, copy=False)

    def _get_sample_weight(self, sample_weight, X):
        """Return ``sample_weight`` as a flat array of ``X.dtype``.

        Raises
        ------
        ValueError
            If the number of weights differs from the number of rows in ``X``.
        """
        sample_weight = np.asarray(sample_weight, dtype=X.dtype).ravel()

        sample_weight = _check_array(
            sample_weight, accept_sparse=False, ensure_2d=False, dtype=X.dtype, order="C"
        )

        if sample_weight.size != X.shape[0]:
            raise ValueError(
                "sample_weight and X have incompatible shapes: "
                "%r vs %r\n"
                "Note: Sparse matrices cannot be indexed w/"
                "boolean masks (use `indices=True` in CV)."
                % (sample_weight.shape, X.shape)
            )

        return sample_weight

    def _fit(self, X, y, sample_weight, module, queue):
        """Validate the inputs, train through ``module`` and store the model.

        Parameters
        ----------
        X, y : array-like
            Training data and targets.
        sample_weight : array-like or None
            Optional per-row weights; ``None`` or an empty sequence means the
            backend is called without weights.
        module : object
            Backend exposing ``train(policy, params, *tables)``.
        queue : object
            Forwarded to ``self._get_policy``.

        Returns
        -------
        self
        """
        X, y = _check_X_y(
            X,
            y,
            dtype=[np.float64, np.float32],
            force_all_finite=True,
            accept_sparse="csr",
        )
        y = self._validate_targets(y, X.dtype)

        self.n_features_in_ = X.shape[1]

        if sample_weight is not None and len(sample_weight) > 0:
            sample_weight = self._get_sample_weight(sample_weight, X)
            data = (X, y, sample_weight)
        else:
            data = (X, y)
        policy = self._get_policy(queue, *data)
        data = _convert_to_supported(policy, *data)
        params = self._get_onedal_params(data[0])
        train_result = module.train(policy, params, *to_table(*data))

        self._onedal_model = train_result.model

        if self.oob_score:
            if isinstance(self, ClassifierMixin):
                self.oob_score_ = from_table(train_result.oob_err_accuracy).item()
                self.oob_decision_function_ = from_table(
                    train_result.oob_err_decision_function
                )
                oob_values = self.oob_decision_function_
            else:
                self.oob_score_ = from_table(train_result.oob_err_r2).item()
                self.oob_prediction_ = from_table(
                    train_result.oob_err_prediction
                ).reshape(-1)
                oob_values = self.oob_prediction_
            # An exact zero is treated as "no out-of-bag estimate available".
            if np.any(oob_values == 0):
                warnings.warn(
                    "Some inputs do not have OOB scores. This probably means "
                    "too few trees were used to compute any reliable OOB "
                    "estimates.",
                    UserWarning,
                )

        return self

    def _create_model(self, module):
        """Not supported; always raises ``NotImplementedError``."""
        # TODO:
        # update error msg.
        raise NotImplementedError("Creating model is not supported.")

    def _run_inference(self, X, module, queue, hparams=None, infer_mode=None):
        """Validate ``X`` and return the raw result of ``module.infer``.

        ``infer_mode``, when given, overrides the ``"infer_mode"`` entry of the
        parameter dictionary. ``hparams`` is forwarded to the backend only when
        it is set and not in its default state.
        """
        _check_is_fitted(self)
        X = _check_array(
            X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False
        )
        _check_n_features(self, X, False)
        policy = self._get_policy(queue, X)

        model = self._onedal_model
        X = _convert_to_supported(policy, X)
        # Inference must not re-validate training-only settings against X.
        params = self._get_onedal_params(X, is_training=False)
        if infer_mode is not None:
            params["infer_mode"] = infer_mode

        if hparams is not None and not hparams.is_default:
            return module.infer(policy, params, hparams.backend, model, to_table(X))
        return module.infer(policy, params, model, to_table(X))

    def _predict(self, X, module, queue, hparams=None):
        """Return the backend ``responses`` for ``X`` as an array."""
        result = self._run_inference(X, module, queue, hparams)
        y = from_table(result.responses)
        return y

    def _predict_proba(self, X, module, queue, hparams=None):
        """Return the backend ``probabilities`` for ``X`` as an array.

        Runs inference with ``infer_mode`` forced to ``"class_probabilities"``.
        """
        result = self._run_inference(
            X, module, queue, hparams, infer_mode="class_probabilities"
        )
        y = from_table(result.probabilities)
        return y
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
class RandomForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
|
|
385
|
+
def __init__(
|
|
386
|
+
self,
|
|
387
|
+
n_estimators=100,
|
|
388
|
+
criterion="gini",
|
|
389
|
+
max_depth=None,
|
|
390
|
+
min_samples_split=2,
|
|
391
|
+
min_samples_leaf=1,
|
|
392
|
+
min_weight_fraction_leaf=0.0,
|
|
393
|
+
max_features="sqrt",
|
|
394
|
+
max_leaf_nodes=None,
|
|
395
|
+
min_impurity_decrease=0.0,
|
|
396
|
+
min_impurity_split=None,
|
|
397
|
+
bootstrap=True,
|
|
398
|
+
oob_score=False,
|
|
399
|
+
random_state=None,
|
|
400
|
+
warm_start=False,
|
|
401
|
+
class_weight=None,
|
|
402
|
+
ccp_alpha=0.0,
|
|
403
|
+
max_samples=None,
|
|
404
|
+
max_bins=256,
|
|
405
|
+
min_bin_size=1,
|
|
406
|
+
infer_mode="class_responses",
|
|
407
|
+
splitter_mode="best",
|
|
408
|
+
voting_mode="weighted",
|
|
409
|
+
error_metric_mode="none",
|
|
410
|
+
variable_importance_mode="none",
|
|
411
|
+
algorithm="hist",
|
|
412
|
+
**kwargs,
|
|
413
|
+
):
|
|
414
|
+
super().__init__(
|
|
415
|
+
n_estimators=n_estimators,
|
|
416
|
+
criterion=criterion,
|
|
417
|
+
max_depth=max_depth,
|
|
418
|
+
min_samples_split=min_samples_split,
|
|
419
|
+
min_samples_leaf=min_samples_leaf,
|
|
420
|
+
min_weight_fraction_leaf=min_weight_fraction_leaf,
|
|
421
|
+
max_features=max_features,
|
|
422
|
+
max_leaf_nodes=max_leaf_nodes,
|
|
423
|
+
min_impurity_decrease=min_impurity_decrease,
|
|
424
|
+
min_impurity_split=min_impurity_split,
|
|
425
|
+
bootstrap=bootstrap,
|
|
426
|
+
oob_score=oob_score,
|
|
427
|
+
random_state=random_state,
|
|
428
|
+
warm_start=warm_start,
|
|
429
|
+
class_weight=class_weight,
|
|
430
|
+
ccp_alpha=ccp_alpha,
|
|
431
|
+
max_samples=max_samples,
|
|
432
|
+
max_bins=max_bins,
|
|
433
|
+
min_bin_size=min_bin_size,
|
|
434
|
+
infer_mode=infer_mode,
|
|
435
|
+
splitter_mode=splitter_mode,
|
|
436
|
+
voting_mode=voting_mode,
|
|
437
|
+
error_metric_mode=error_metric_mode,
|
|
438
|
+
variable_importance_mode=variable_importance_mode,
|
|
439
|
+
algorithm=algorithm,
|
|
440
|
+
)
|
|
441
|
+
|
|
442
|
+
def _validate_targets(self, y, dtype):
    """Validate the targets and record class information on the estimator.

    Calls the module-level ``_validate_targets`` helper with ``y``, the
    configured ``self.class_weight`` and ``dtype``. The helper's second and
    third results are stored as ``self.class_weight_`` and ``self.classes_``;
    its first result is returned to the caller.
    """
    validated = _validate_targets(y, self.class_weight, dtype)
    checked_y, self.class_weight_, self.classes_ = validated

    # TODO: align the `n_classes_` and `classes_` attributes with the
    # daal4py implementations (the sketched idea was to set `n_classes_`
    # from `classes_` whenever `classes_` is present).
    return checked_y
|
|
453
|
+
|
|
454
|
+
def fit(self, X, y, sample_weight=None, queue=None):
    """Train the forest with the classification backend.

    Looks up the backend via
    ``self._get_backend("decision_forest", "classification", None)`` and
    delegates to ``self._fit`` with the inputs passed through unchanged.

    Returns whatever ``self._fit`` returns.
    """
    backend = self._get_backend("decision_forest", "classification", None)
    return self._fit(X, y, sample_weight, backend, queue)
|
|
462
|
+
|
|
463
|
+
def predict(self, X, queue=None):
    """Predict class labels for ``X``.

    The parent ``_predict`` is called with the classification backend and
    the ``("decision_forest", "infer")`` hyperparameters. Its output is
    flattened, cast to ``int64`` and used as indices into ``self.classes_``
    so that the original labels are returned.
    """
    infer_hparams = get_hyperparameters("decision_forest", "infer")
    backend = self._get_backend("decision_forest", "classification", None)
    raw_pred = super()._predict(X, backend, queue, infer_hparams)

    # The backend output is treated as positions in `classes_`; the unsafe
    # cast allows a non-integer result dtype to be used for indexing.
    label_positions = raw_pred.ravel().astype(np.int64, casting="unsafe")
    return np.take(self.classes_, label_positions)
|
|
473
|
+
|
|
474
|
+
def predict_proba(self, X, queue=None):
    """Return the parent class's ``_predict_proba`` result for ``X``.

    The call is made with the backend obtained from
    ``self._get_backend("decision_forest", "classification", None)`` and
    the given ``queue``.
    """
    backend = self._get_backend("decision_forest", "classification", None)
    return super()._predict_proba(X, backend, queue)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
class RandomForestRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Forest regressor that trains and predicts through the
    ``("decision_forest", "regression")`` backend of ``BaseForest``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=True,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="best",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        """Forward every named option, unchanged, to the parent initializer.

        Extra keyword arguments captured by ``**kwargs`` are accepted but
        are neither stored nor forwarded by this method.
        """
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the forest with the regression backend.

        When ``sample_weight`` is given and is array-like (exposes
        ``__array__``), entries equal to ``0.0`` are replaced by ``1.0``.
        The replacement is applied to a copy, so the caller's object is
        left untouched and read-only inputs are supported. The weights are
        then wrapped in a one-element list and passed to the parent
        ``_fit``.

        Returns whatever the parent ``_fit`` returns.
        """
        if sample_weight is not None:
            if hasattr(sample_weight, "__array__"):
                import copy

                # Never overwrite the caller's weights: work on a copy.
                sample_weight = copy.copy(sample_weight)
                # NOTE(review): zero weights become 1.0 instead of excluding
                # the sample; the reason is not visible here -- confirm.
                sample_weight[sample_weight == 0.0] = 1.0
            # presumably the parent `_fit` expects the weights wrapped in a
            # list; verify against `BaseForest._fit`.
            sample_weight = [sample_weight]
        return super()._fit(
            X,
            y,
            sample_weight,
            self._get_backend("decision_forest", "regression", None),
            queue,
        )

    def predict(self, X, queue=None):
        """Return the parent ``_predict`` output for ``X``, flattened with ``ravel()``."""
        backend = self._get_backend("decision_forest", "regression", None)
        return super()._predict(X, backend, queue).ravel()
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
class ExtraTreesClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """Forest classifier that trains and predicts through the
    ``("decision_forest", "classification")`` backend of ``BaseForest``.

    Defaults visible here: ``bootstrap=False`` and ``splitter_mode="random"``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="gini",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        """Hand every named option, unchanged, to the parent initializer.

        Extra keyword arguments captured by ``**kwargs`` are accepted but
        are neither stored nor forwarded by this method.
        """
        # Gather the options first so the parent call is a single expansion.
        forest_options = {
            "n_estimators": n_estimators,
            "criterion": criterion,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_samples_leaf": min_samples_leaf,
            "min_weight_fraction_leaf": min_weight_fraction_leaf,
            "max_features": max_features,
            "max_leaf_nodes": max_leaf_nodes,
            "min_impurity_decrease": min_impurity_decrease,
            "min_impurity_split": min_impurity_split,
            "bootstrap": bootstrap,
            "oob_score": oob_score,
            "random_state": random_state,
            "warm_start": warm_start,
            "class_weight": class_weight,
            "ccp_alpha": ccp_alpha,
            "max_samples": max_samples,
            "max_bins": max_bins,
            "min_bin_size": min_bin_size,
            "infer_mode": infer_mode,
            "splitter_mode": splitter_mode,
            "voting_mode": voting_mode,
            "error_metric_mode": error_metric_mode,
            "variable_importance_mode": variable_importance_mode,
            "algorithm": algorithm,
        }
        super().__init__(**forest_options)

    def _validate_targets(self, y, dtype):
        """Validate the targets and record class information on the estimator.

        Calls the module-level ``_validate_targets`` helper with ``y``, the
        configured ``self.class_weight`` and ``dtype``. The helper's second
        and third results are stored as ``self.class_weight_`` and
        ``self.classes_``; its first result is returned.
        """
        validated = _validate_targets(y, self.class_weight, dtype)
        checked_y, self.class_weight_, self.classes_ = validated

        # TODO: align the `n_classes_` and `classes_` attributes with the
        # daal4py implementations (the sketched idea was to set `n_classes_`
        # from `classes_` whenever `classes_` is present).
        return checked_y

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the forest; inputs are passed unchanged to ``self._fit``.

        Returns whatever ``self._fit`` returns.
        """
        backend = self._get_backend("decision_forest", "classification", None)
        return self._fit(X, y, sample_weight, backend, queue)

    def predict(self, X, queue=None):
        """Predict class labels for ``X``.

        The parent ``_predict`` output is flattened, cast to ``int64`` and
        used as indices into ``self.classes_``.
        """
        backend = self._get_backend("decision_forest", "classification", None)
        raw_pred = super()._predict(X, backend, queue)

        # The unsafe cast allows a non-integer result dtype to be used for
        # indexing into `classes_`.
        label_positions = raw_pred.ravel().astype(np.int64, casting="unsafe")
        return np.take(self.classes_, label_positions)

    def predict_proba(self, X, queue=None):
        """Return the parent class's ``_predict_proba`` result for ``X``."""
        backend = self._get_backend("decision_forest", "classification", None)
        return super()._predict_proba(X, backend, queue)
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
class ExtraTreesRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """Forest regressor that trains and predicts through the
    ``("decision_forest", "regression")`` backend of ``BaseForest``.

    Defaults visible here: ``bootstrap=False`` and ``splitter_mode="random"``.
    """

    def __init__(
        self,
        n_estimators=100,
        criterion="squared_error",
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        min_weight_fraction_leaf=0.0,
        max_features=1.0,
        max_leaf_nodes=None,
        min_impurity_decrease=0.0,
        min_impurity_split=None,
        bootstrap=False,
        oob_score=False,
        random_state=None,
        warm_start=False,
        class_weight=None,
        ccp_alpha=0.0,
        max_samples=None,
        max_bins=256,
        min_bin_size=1,
        infer_mode="class_responses",
        splitter_mode="random",
        voting_mode="weighted",
        error_metric_mode="none",
        variable_importance_mode="none",
        algorithm="hist",
        **kwargs,
    ):
        """Forward every named option, unchanged, to the parent initializer.

        Extra keyword arguments captured by ``**kwargs`` are accepted but
        are neither stored nor forwarded by this method.
        """
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            random_state=random_state,
            warm_start=warm_start,
            class_weight=class_weight,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
            max_bins=max_bins,
            min_bin_size=min_bin_size,
            infer_mode=infer_mode,
            splitter_mode=splitter_mode,
            voting_mode=voting_mode,
            error_metric_mode=error_metric_mode,
            variable_importance_mode=variable_importance_mode,
            algorithm=algorithm,
        )

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train the forest with the regression backend.

        When ``sample_weight`` is given and is array-like (exposes
        ``__array__``), entries equal to ``0.0`` are replaced by ``1.0``.
        The replacement is applied to a copy, so the caller's object is
        left untouched and read-only inputs are supported. The weights are
        then wrapped in a one-element list and passed to the parent
        ``_fit``.

        Returns whatever the parent ``_fit`` returns.
        """
        if sample_weight is not None:
            if hasattr(sample_weight, "__array__"):
                import copy

                # Never overwrite the caller's weights: work on a copy.
                sample_weight = copy.copy(sample_weight)
                # NOTE(review): zero weights become 1.0 instead of excluding
                # the sample; the reason is not visible here -- confirm.
                sample_weight[sample_weight == 0.0] = 1.0
            # presumably the parent `_fit` expects the weights wrapped in a
            # list; verify against `BaseForest._fit`.
            sample_weight = [sample_weight]
        return super()._fit(
            X,
            y,
            sample_weight,
            self._get_backend("decision_forest", "regression", None),
            queue,
        )

    def predict(self, X, queue=None):
        """Return the parent ``_predict`` output for ``X``, flattened with ``ravel()``."""
        backend = self._get_backend("decision_forest", "regression", None)
        return super()._predict(X, backend, queue).ravel()
|