scikit-learn-intelex 2023.2.1__py39-none-win_amd64.whl → 2024.0.1__py39-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +16 -12
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +90 -56
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +4 -4
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +12 -6
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +5 -5
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +5 -4
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +102 -72
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +12 -4
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +31 -16
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +21 -14
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +10 -10
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +2 -2
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +173 -83
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +23 -7
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +4 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +4 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +5 -5
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +8 -6
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +6 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +9 -5
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +100 -77
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +116 -58
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +118 -56
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +18 -20
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +7 -7
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +104 -73
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +4 -1
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +128 -100
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +18 -16
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd}/__init__.py +24 -22
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +11 -5
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +16 -14
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +11 -8
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +56 -56
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +110 -55
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +65 -31
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +136 -78
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +65 -31
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +9 -8
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +63 -69
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +55 -53
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +8 -7
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +39 -39
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -3
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +2 -2
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
- scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/_utils.py +0 -82
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -18
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -46
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -228
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -213
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -57
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -18
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -28
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py +0 -1261
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1155
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py +0 -67
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -23
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -63
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -159
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -383
- scikit_learn_intelex-2023.2.1.dist-info/RECORD +0 -95
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py
DELETED
|
@@ -1,1261 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python
|
|
2
|
-
#===============================================================================
|
|
3
|
-
# Copyright 2021 Intel Corporation
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
# you may not use this file except in compliance with the License.
|
|
7
|
-
# You may obtain a copy of the License at
|
|
8
|
-
#
|
|
9
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
#
|
|
11
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
# See the License for the specific language governing permissions and
|
|
15
|
-
# limitations under the License.
|
|
16
|
-
#===============================================================================
|
|
17
|
-
|
|
18
|
-
from daal4py.sklearn._utils import (
|
|
19
|
-
daal_check_version, sklearn_check_version,
|
|
20
|
-
make2d, PatchingConditionsChain, check_tree_nodes
|
|
21
|
-
)
|
|
22
|
-
|
|
23
|
-
import numpy as np
|
|
24
|
-
|
|
25
|
-
import numbers
|
|
26
|
-
|
|
27
|
-
import warnings
|
|
28
|
-
|
|
29
|
-
from abc import ABC
|
|
30
|
-
|
|
31
|
-
from sklearn.exceptions import DataConversionWarning
|
|
32
|
-
|
|
33
|
-
from ..._config import get_config
|
|
34
|
-
from ..._device_offload import dispatch, wrap_output_data
|
|
35
|
-
|
|
36
|
-
from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifier
|
|
37
|
-
from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
|
|
38
|
-
|
|
39
|
-
from sklearn.utils.validation import (
|
|
40
|
-
check_is_fitted,
|
|
41
|
-
check_consistent_length,
|
|
42
|
-
check_array,
|
|
43
|
-
check_X_y)
|
|
44
|
-
|
|
45
|
-
from onedal.datatypes import _num_features, _num_samples
|
|
46
|
-
|
|
47
|
-
from sklearn.utils import check_random_state, deprecated
|
|
48
|
-
|
|
49
|
-
from sklearn.base import clone
|
|
50
|
-
|
|
51
|
-
from sklearn.tree import ExtraTreeClassifier, ExtraTreeRegressor
|
|
52
|
-
from sklearn.tree._tree import Tree
|
|
53
|
-
|
|
54
|
-
from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
|
|
55
|
-
from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
|
|
56
|
-
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
|
|
57
|
-
|
|
58
|
-
from scipy import sparse as sp
|
|
59
|
-
|
|
60
|
-
if sklearn_check_version('1.2'):
|
|
61
|
-
from sklearn.utils._param_validation import Interval
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
class BaseTree(ABC):
|
|
65
|
-
def _fit_proba(self, X, y, sample_weight=None, queue=None):
|
|
66
|
-
params = self.get_params()
|
|
67
|
-
self.__class__(**params)
|
|
68
|
-
|
|
69
|
-
# We use stock metaestimators below, so the only way
|
|
70
|
-
# to pass a queue is using config_context.
|
|
71
|
-
cfg = get_config()
|
|
72
|
-
cfg['target_offload'] = queue
|
|
73
|
-
|
|
74
|
-
def _save_attributes(self):
|
|
75
|
-
self._onedal_model = self._onedal_estimator._onedal_model
|
|
76
|
-
|
|
77
|
-
if self.oob_score:
|
|
78
|
-
self.oob_score_ = self._onedal_estimator.oob_score_
|
|
79
|
-
if hasattr(self._onedal_estimator, "oob_prediction_"):
|
|
80
|
-
self.oob_prediction_ = self._onedal_estimator.oob_prediction_
|
|
81
|
-
if hasattr(self._onedal_estimator, "oob_decision_function_"):
|
|
82
|
-
self.oob_decision_function_ = \
|
|
83
|
-
self._onedal_estimator.oob_decision_function_
|
|
84
|
-
return self
|
|
85
|
-
|
|
86
|
-
def _onedal_classifier(self, **onedal_params):
|
|
87
|
-
return onedal_ExtraTreesClassifier(**onedal_params)
|
|
88
|
-
|
|
89
|
-
def _onedal_regressor(self, **onedal_params):
|
|
90
|
-
return onedal_ExtraTreesRegressor(**onedal_params)
|
|
91
|
-
|
|
92
|
-
# TODO:
|
|
93
|
-
# move to onedal modul.
|
|
94
|
-
def _check_parameters(self):
|
|
95
|
-
|
|
96
|
-
if isinstance(self.min_samples_leaf, numbers.Integral):
|
|
97
|
-
if not 1 <= self.min_samples_leaf:
|
|
98
|
-
raise ValueError("min_samples_leaf must be at least 1 "
|
|
99
|
-
"or in (0, 0.5], got %s"
|
|
100
|
-
% self.min_samples_leaf)
|
|
101
|
-
else: # float
|
|
102
|
-
if not 0. < self.min_samples_leaf <= 0.5:
|
|
103
|
-
raise ValueError("min_samples_leaf must be at least 1 "
|
|
104
|
-
"or in (0, 0.5], got %s"
|
|
105
|
-
% self.min_samples_leaf)
|
|
106
|
-
if isinstance(self.min_samples_split, numbers.Integral):
|
|
107
|
-
if not 2 <= self.min_samples_split:
|
|
108
|
-
raise ValueError("min_samples_split must be an integer "
|
|
109
|
-
"greater than 1 or a float in (0.0, 1.0]; "
|
|
110
|
-
"got the integer %s"
|
|
111
|
-
% self.min_samples_split)
|
|
112
|
-
else: # float
|
|
113
|
-
if not 0. < self.min_samples_split <= 1.:
|
|
114
|
-
raise ValueError("min_samples_split must be an integer "
|
|
115
|
-
"greater than 1 or a float in (0.0, 1.0]; "
|
|
116
|
-
"got the float %s"
|
|
117
|
-
% self.min_samples_split)
|
|
118
|
-
if not 0 <= self.min_weight_fraction_leaf <= 0.5:
|
|
119
|
-
raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
|
|
120
|
-
if getattr(self, "min_impurity_split", None) is not None:
|
|
121
|
-
warnings.warn("The min_impurity_split parameter is deprecated. "
|
|
122
|
-
"Its default value has changed from 1e-7 to 0 in "
|
|
123
|
-
"version 0.23, and it will be removed in 0.25. "
|
|
124
|
-
"Use the min_impurity_decrease parameter instead.",
|
|
125
|
-
FutureWarning)
|
|
126
|
-
|
|
127
|
-
if getattr(self, "min_impurity_split") < 0.:
|
|
128
|
-
raise ValueError("min_impurity_split must be greater than "
|
|
129
|
-
"or equal to 0")
|
|
130
|
-
if self.min_impurity_decrease < 0.:
|
|
131
|
-
raise ValueError("min_impurity_decrease must be greater than "
|
|
132
|
-
"or equal to 0")
|
|
133
|
-
if self.max_leaf_nodes is not None:
|
|
134
|
-
if not isinstance(self.max_leaf_nodes, numbers.Integral):
|
|
135
|
-
raise ValueError(
|
|
136
|
-
"max_leaf_nodes must be integral number but was "
|
|
137
|
-
"%r" %
|
|
138
|
-
self.max_leaf_nodes)
|
|
139
|
-
if self.max_leaf_nodes < 2:
|
|
140
|
-
raise ValueError(
|
|
141
|
-
("max_leaf_nodes {0} must be either None "
|
|
142
|
-
"or larger than 1").format(
|
|
143
|
-
self.max_leaf_nodes))
|
|
144
|
-
if isinstance(self.max_bins, numbers.Integral):
|
|
145
|
-
if not 2 <= self.max_bins:
|
|
146
|
-
raise ValueError("max_bins must be at least 2, got %s"
|
|
147
|
-
% self.max_bins)
|
|
148
|
-
else:
|
|
149
|
-
raise ValueError("max_bins must be integral number but was "
|
|
150
|
-
"%r" % self.max_bins)
|
|
151
|
-
if isinstance(self.min_bin_size, numbers.Integral):
|
|
152
|
-
if not 1 <= self.min_bin_size:
|
|
153
|
-
raise ValueError("min_bin_size must be at least 1, got %s"
|
|
154
|
-
% self.min_bin_size)
|
|
155
|
-
else:
|
|
156
|
-
raise ValueError("min_bin_size must be integral number but was "
|
|
157
|
-
"%r" % self.min_bin_size)
|
|
158
|
-
|
|
159
|
-
def check_sample_weight(self, sample_weight, X, dtype=None):
|
|
160
|
-
n_samples = _num_samples(X)
|
|
161
|
-
|
|
162
|
-
if dtype is not None and dtype not in [np.float32, np.float64]:
|
|
163
|
-
dtype = np.float64
|
|
164
|
-
|
|
165
|
-
if sample_weight is None:
|
|
166
|
-
sample_weight = np.ones(n_samples, dtype=dtype)
|
|
167
|
-
elif isinstance(sample_weight, numbers.Number):
|
|
168
|
-
sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
|
|
169
|
-
else:
|
|
170
|
-
if dtype is None:
|
|
171
|
-
dtype = [np.float64, np.float32]
|
|
172
|
-
sample_weight = check_array(
|
|
173
|
-
sample_weight,
|
|
174
|
-
accept_sparse=False,
|
|
175
|
-
ensure_2d=False,
|
|
176
|
-
dtype=dtype,
|
|
177
|
-
order="C")
|
|
178
|
-
if sample_weight.ndim != 1:
|
|
179
|
-
raise ValueError("Sample weights must be 1D array or scalar")
|
|
180
|
-
|
|
181
|
-
if sample_weight.shape != (n_samples,):
|
|
182
|
-
raise ValueError("sample_weight.shape == {}, expected {}!"
|
|
183
|
-
.format(sample_weight.shape, (n_samples,)))
|
|
184
|
-
return sample_weight
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
class ExtraTreesClassifier(sklearn_ExtraTreesClassifier, BaseTree):
|
|
188
|
-
__doc__ = sklearn_ExtraTreesClassifier.__doc__
|
|
189
|
-
|
|
190
|
-
if sklearn_check_version('1.2'):
|
|
191
|
-
_parameter_constraints: dict = {
|
|
192
|
-
**sklearn_ExtraTreesClassifier._parameter_constraints,
|
|
193
|
-
"max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
|
|
194
|
-
"min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")]
|
|
195
|
-
}
|
|
196
|
-
|
|
197
|
-
if sklearn_check_version('1.0'):
|
|
198
|
-
def __init__(
|
|
199
|
-
self,
|
|
200
|
-
n_estimators=100,
|
|
201
|
-
criterion="gini",
|
|
202
|
-
max_depth=None,
|
|
203
|
-
min_samples_split=2,
|
|
204
|
-
min_samples_leaf=1,
|
|
205
|
-
min_weight_fraction_leaf=0.,
|
|
206
|
-
max_features='sqrt' if sklearn_check_version('1.1') else 'auto',
|
|
207
|
-
max_leaf_nodes=None,
|
|
208
|
-
min_impurity_decrease=0.,
|
|
209
|
-
bootstrap=False,
|
|
210
|
-
oob_score=False,
|
|
211
|
-
n_jobs=None,
|
|
212
|
-
random_state=None,
|
|
213
|
-
verbose=0,
|
|
214
|
-
warm_start=False,
|
|
215
|
-
class_weight=None,
|
|
216
|
-
ccp_alpha=0.0,
|
|
217
|
-
max_samples=None,
|
|
218
|
-
max_bins=256,
|
|
219
|
-
min_bin_size=1):
|
|
220
|
-
super(ExtraTreesClassifier, self).__init__(
|
|
221
|
-
n_estimators=n_estimators,
|
|
222
|
-
criterion=criterion,
|
|
223
|
-
max_depth=max_depth,
|
|
224
|
-
min_samples_split=min_samples_split,
|
|
225
|
-
min_samples_leaf=min_samples_leaf,
|
|
226
|
-
min_weight_fraction_leaf=min_weight_fraction_leaf,
|
|
227
|
-
max_features=max_features,
|
|
228
|
-
max_leaf_nodes=max_leaf_nodes,
|
|
229
|
-
min_impurity_decrease=min_impurity_decrease,
|
|
230
|
-
bootstrap=bootstrap,
|
|
231
|
-
oob_score=oob_score,
|
|
232
|
-
n_jobs=n_jobs,
|
|
233
|
-
random_state=random_state,
|
|
234
|
-
verbose=verbose,
|
|
235
|
-
warm_start=warm_start,
|
|
236
|
-
class_weight=class_weight
|
|
237
|
-
)
|
|
238
|
-
self.warm_start = warm_start
|
|
239
|
-
self.ccp_alpha = ccp_alpha
|
|
240
|
-
self.max_samples = max_samples
|
|
241
|
-
self.max_bins = max_bins
|
|
242
|
-
self.min_bin_size = min_bin_size
|
|
243
|
-
|
|
244
|
-
else:
|
|
245
|
-
def __init__(self,
|
|
246
|
-
n_estimators=100,
|
|
247
|
-
criterion="gini",
|
|
248
|
-
max_depth=None,
|
|
249
|
-
min_samples_split=2,
|
|
250
|
-
min_samples_leaf=1,
|
|
251
|
-
min_weight_fraction_leaf=0.,
|
|
252
|
-
max_features="auto",
|
|
253
|
-
max_leaf_nodes=None,
|
|
254
|
-
min_impurity_decrease=0.,
|
|
255
|
-
min_impurity_split=None,
|
|
256
|
-
bootstrap=False,
|
|
257
|
-
oob_score=False,
|
|
258
|
-
n_jobs=None,
|
|
259
|
-
random_state=None,
|
|
260
|
-
verbose=0,
|
|
261
|
-
warm_start=False,
|
|
262
|
-
class_weight=None,
|
|
263
|
-
ccp_alpha=0.0,
|
|
264
|
-
max_samples=None,
|
|
265
|
-
max_bins=256,
|
|
266
|
-
min_bin_size=1):
|
|
267
|
-
super(ExtraTreesClassifier, self).__init__(
|
|
268
|
-
n_estimators=n_estimators,
|
|
269
|
-
criterion=criterion,
|
|
270
|
-
max_depth=max_depth,
|
|
271
|
-
min_samples_split=min_samples_split,
|
|
272
|
-
min_samples_leaf=min_samples_leaf,
|
|
273
|
-
min_weight_fraction_leaf=min_weight_fraction_leaf,
|
|
274
|
-
max_features=max_features,
|
|
275
|
-
max_leaf_nodes=max_leaf_nodes,
|
|
276
|
-
min_impurity_decrease=min_impurity_decrease,
|
|
277
|
-
min_impurity_split=min_impurity_split,
|
|
278
|
-
bootstrap=bootstrap,
|
|
279
|
-
oob_score=oob_score,
|
|
280
|
-
n_jobs=n_jobs,
|
|
281
|
-
random_state=random_state,
|
|
282
|
-
verbose=verbose,
|
|
283
|
-
warm_start=warm_start,
|
|
284
|
-
class_weight=class_weight,
|
|
285
|
-
ccp_alpha=ccp_alpha,
|
|
286
|
-
max_samples=max_samples
|
|
287
|
-
)
|
|
288
|
-
self.warm_start = warm_start
|
|
289
|
-
self.ccp_alpha = ccp_alpha
|
|
290
|
-
self.max_samples = max_samples
|
|
291
|
-
self.max_bins = max_bins
|
|
292
|
-
self.min_bin_size = min_bin_size
|
|
293
|
-
|
|
294
|
-
def fit(self, X, y, sample_weight=None):
|
|
295
|
-
"""
|
|
296
|
-
Build a forest of trees from the training set (X, y).
|
|
297
|
-
|
|
298
|
-
Parameters
|
|
299
|
-
----------
|
|
300
|
-
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
|
301
|
-
The training input samples. Internally, its dtype will be converted
|
|
302
|
-
to ``dtype=np.float32``. If a sparse matrix is provided, it will be
|
|
303
|
-
converted into a sparse ``csc_matrix``.
|
|
304
|
-
|
|
305
|
-
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
|
|
306
|
-
The target values (class labels in classification, real numbers in
|
|
307
|
-
regression).
|
|
308
|
-
|
|
309
|
-
sample_weight : array-like of shape (n_samples,), default=None
|
|
310
|
-
Sample weights. If None, then samples are equally weighted. Splits
|
|
311
|
-
that would create child nodes with net zero or negative weight are
|
|
312
|
-
ignored while searching for a split in each node. In the case of
|
|
313
|
-
classification, splits are also ignored if they would result in any
|
|
314
|
-
single class carrying a negative weight in either child node.
|
|
315
|
-
|
|
316
|
-
Returns
|
|
317
|
-
-------
|
|
318
|
-
self : object
|
|
319
|
-
"""
|
|
320
|
-
dispatch(self, 'fit', {
|
|
321
|
-
'onedal': self.__class__._onedal_fit,
|
|
322
|
-
'sklearn': sklearn_ExtraTreesClassifier.fit,
|
|
323
|
-
}, X, y, sample_weight)
|
|
324
|
-
return self
|
|
325
|
-
|
|
326
|
-
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
|
|
327
|
-
if sp.issparse(y):
|
|
328
|
-
raise ValueError(
|
|
329
|
-
"sparse multilabel-indicator for y is not supported."
|
|
330
|
-
)
|
|
331
|
-
|
|
332
|
-
if sklearn_check_version("1.2"):
|
|
333
|
-
self._validate_params()
|
|
334
|
-
else:
|
|
335
|
-
self._check_parameters()
|
|
336
|
-
|
|
337
|
-
if not self.bootstrap and self.oob_score:
|
|
338
|
-
raise ValueError("Out of bag estimation only available"
|
|
339
|
-
" if bootstrap=True")
|
|
340
|
-
|
|
341
|
-
ready = patching_status.and_conditions([
|
|
342
|
-
(self.oob_score and daal_check_version((2021, 'P', 500)) or not
|
|
343
|
-
self.oob_score,
|
|
344
|
-
"OOB score is only supported starting from 2021.5 version of oneDAL."),
|
|
345
|
-
(not sp.issparse(X), "X is sparse. Sparse input is not supported."),
|
|
346
|
-
(self.ccp_alpha == 0.0,
|
|
347
|
-
f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported."),
|
|
348
|
-
(self.criterion == "gini",
|
|
349
|
-
f"'{self.criterion}' criterion is not supported. "
|
|
350
|
-
"Only 'gini' criterion is supported."),
|
|
351
|
-
(self.warm_start is False, "Warm start is not supported."),
|
|
352
|
-
(self.n_estimators <= 6024, "More than 6024 estimators is not supported.")
|
|
353
|
-
])
|
|
354
|
-
|
|
355
|
-
if ready:
|
|
356
|
-
if sklearn_check_version("1.0"):
|
|
357
|
-
self._check_feature_names(X, reset=True)
|
|
358
|
-
X = check_array(X, dtype=[np.float32, np.float64])
|
|
359
|
-
y = np.asarray(y)
|
|
360
|
-
y = np.atleast_1d(y)
|
|
361
|
-
if y.ndim == 2 and y.shape[1] == 1:
|
|
362
|
-
warnings.warn(
|
|
363
|
-
"A column-vector y was passed when a 1d array was"
|
|
364
|
-
" expected. Please change the shape of y to "
|
|
365
|
-
"(n_samples,), for example using ravel().",
|
|
366
|
-
DataConversionWarning,
|
|
367
|
-
stacklevel=2)
|
|
368
|
-
check_consistent_length(X, y)
|
|
369
|
-
|
|
370
|
-
y = make2d(y)
|
|
371
|
-
self.n_outputs_ = y.shape[1]
|
|
372
|
-
ready = patching_status.and_conditions([
|
|
373
|
-
(self.n_outputs_ == 1,
|
|
374
|
-
f"Number of outputs ({self.n_outputs_}) is not 1."),
|
|
375
|
-
(y.dtype in [np.float32, np.float64, np.int32, np.int64],
|
|
376
|
-
f"Datatype ({y.dtype}) for y is not supported.")
|
|
377
|
-
])
|
|
378
|
-
# TODO: Fix to support integers as input
|
|
379
|
-
|
|
380
|
-
n_samples = X.shape[0]
|
|
381
|
-
if isinstance(self.max_samples, numbers.Integral):
|
|
382
|
-
if not sklearn_check_version('1.2'):
|
|
383
|
-
if not (1 <= self.max_samples <= n_samples):
|
|
384
|
-
msg = "`max_samples` must be in range 1 to {} but got value {}"
|
|
385
|
-
raise ValueError(msg.format(n_samples, self.max_samples))
|
|
386
|
-
else:
|
|
387
|
-
if self.max_samples > n_samples:
|
|
388
|
-
msg = "`max_samples` must be <= n_samples={} but got value {}"
|
|
389
|
-
raise ValueError(msg.format(n_samples, self.max_samples))
|
|
390
|
-
elif isinstance(self.max_samples, numbers.Real):
|
|
391
|
-
if sklearn_check_version('1.2'):
|
|
392
|
-
pass
|
|
393
|
-
elif sklearn_check_version('1.0'):
|
|
394
|
-
if not (0 < float(self.max_samples) <= 1):
|
|
395
|
-
msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
|
|
396
|
-
raise ValueError(msg.format(self.max_samples))
|
|
397
|
-
else:
|
|
398
|
-
if not (0 < float(self.max_samples) < 1):
|
|
399
|
-
msg = "`max_samples` must be in range (0, 1) but got value {}"
|
|
400
|
-
raise ValueError(msg.format(self.max_samples))
|
|
401
|
-
elif self.max_samples is not None:
|
|
402
|
-
msg = "`max_samples` should be int or float, but got type '{}'"
|
|
403
|
-
raise TypeError(msg.format(type(self.max_samples)))
|
|
404
|
-
|
|
405
|
-
if not self.bootstrap and self.max_samples is not None:
|
|
406
|
-
raise ValueError(
|
|
407
|
-
"`max_sample` cannot be set if `bootstrap=False`. "
|
|
408
|
-
"Either switch to `bootstrap=True` or set "
|
|
409
|
-
"`max_sample=None`."
|
|
410
|
-
)
|
|
411
|
-
|
|
412
|
-
return ready, X, y, sample_weight
|
|
413
|
-
|
|
414
|
-
@wrap_output_data
def predict(self, X):
    """
    Predict a class label for every sample in X.

    The request is handed to ``dispatch`` together with two candidate
    implementations: this class's ``_onedal_predict`` and the stock
    ``sklearn_ExtraTreesClassifier.predict``.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples.

    Returns
    -------
    y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
        The predicted classes.
    """
    # Candidate implementations, keyed by backend name.
    backends = {
        'onedal': self.__class__._onedal_predict,
        'sklearn': sklearn_ExtraTreesClassifier.predict,
    }
    return dispatch(self, 'predict', backends, X)
|
|
440
|
-
|
|
441
|
-
@wrap_output_data
def predict_proba(self, X):
    """
    Predict class probabilities for X.

    Before dispatching, the feature names of X are checked (scikit-learn
    1.0 and newer) and, when ``n_features_in_`` is already set, the width
    of X is compared against it.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples.

    Returns
    -------
    p : ndarray of shape (n_samples, n_classes), or a list of n_outputs
        such arrays if n_outputs > 1.
        The class probabilities of the input samples. The order of the
        classes corresponds to that in the attribute :term:`classes_`.

    Raises
    ------
    ValueError
        If the number of features found in X differs from
        ``n_features_in_``.
    """
    # TODO: _check_proba() / self._check_proba() is not invoked here yet.
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)

    if hasattr(self, 'n_features_in_'):
        try:
            seen = _num_features(X)
        except TypeError:
            # _num_features rejected X; fall back to its sample count.
            seen = _num_samples(X)
        expected = self.n_features_in_
        if seen != expected:
            raise ValueError(
                f'X has {seen} features, '
                f'but ExtraTreesClassifier is expecting '
                f'{expected} features as input')

    backends = {
        'onedal': self.__class__._onedal_predict_proba,
        'sklearn': sklearn_ExtraTreesClassifier.predict_proba,
    }
    return dispatch(self, 'predict_proba', backends, X)
|
|
484
|
-
|
|
485
|
-
if sklearn_check_version('1.0'):
    # Legacy alias for ``n_features_in_``; only defined for scikit-learn
    # 1.0 and newer. The text below is passed to the ``deprecated``
    # decorator.
    @deprecated(
        "Attribute `n_features_` was deprecated in version 1.0 "
        "and will be removed in 1.2. "
        "Use `n_features_in_` instead."
    )
    @property
    def n_features_(self):
        return self.n_features_in_
|
|
492
|
-
|
|
493
|
-
@property
def _estimators_(self):
    """
    Build (or return the cached) list of per-tree ``ExtraTreeClassifier``
    objects from ``self._onedal_model``.

    A truthy ``_cached_estimators_`` is returned as is. Otherwise one
    estimator is cloned per tree, given a seed drawn from a generator
    derived from ``self.random_state``, and its ``tree_`` is populated
    from the state returned by ``get_tree_state_cls``. The resulting list
    is stored in ``_cached_estimators_`` before being returned.
    """
    cached = getattr(self, '_cached_estimators_', None)
    if cached:
        return cached

    if sklearn_check_version('0.22'):
        check_is_fitted(self)
    else:
        check_is_fitted(self, '_onedal_model')

    # Both attributes are indexed with [0] here.
    # NOTE(review): presumably they are still list-wrapped at this point
    # - confirm against the fit routine.
    first_classes = self.classes_[0]
    first_n_classes = self.n_classes_[0]

    # Parameters of the template tree estimator.
    template_params = {
        'criterion': self.criterion,
        'max_depth': self.max_depth,
        'min_samples_split': self.min_samples_split,
        'min_samples_leaf': self.min_samples_leaf,
        'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
        'max_features': self.max_features,
        'max_leaf_nodes': self.max_leaf_nodes,
        'min_impurity_decrease': self.min_impurity_decrease,
        'random_state': None,
    }
    if not sklearn_check_version('1.0'):
        template_params['min_impurity_split'] = self.min_impurity_split
    template = ExtraTreeClassifier(**template_params)

    # Each tree_ is filled from the oneDAL model rather than trained here.
    rng = check_random_state(self.random_state)
    seed_bound = np.iinfo(np.int32).max
    has_n_features_in = sklearn_check_version('1.0')

    trees = []
    for idx in range(self.n_estimators):
        tree_est = clone(template)
        tree_est.set_params(random_state=rng.randint(seed_bound))

        if has_n_features_in:
            tree_est.n_features_in_ = self.n_features_in_
        else:
            tree_est.n_features_ = self.n_features_in_
        tree_est.n_outputs_ = self.n_outputs_
        tree_est.classes_ = first_classes
        tree_est.n_classes_ = first_n_classes

        state = get_tree_state_cls(self._onedal_model, idx, first_n_classes)
        state_dict = {
            'max_depth': state.max_depth,
            'node_count': state.node_count,
            'nodes': check_tree_nodes(state.node_ar),
            'values': state.value_ar,
        }
        tree_est.tree_ = Tree(
            self.n_features_in_,
            np.array([first_n_classes], dtype=np.intp),
            self.n_outputs_)
        tree_est.tree_.__setstate__(state_dict)
        trees.append(tree_est)

    self._cached_estimators_ = trees
    return trees
|
|
554
|
-
|
|
555
|
-
def _onedal_cpu_supported(self, method_name, *data):
    """
    Decide whether ``method_name`` can run on the oneDAL CPU backend.

    Conditions are recorded on a ``PatchingConditionsChain`` named after
    the class and method; its log is written before returning.

    Parameters
    ----------
    method_name : str
        One of ``'fit'``, ``'predict'`` or ``'predict_proba'``.
    *data
        Arguments of the dispatched call (``X, y, sample_weight`` for
        ``fit``; ``X`` first for the prediction methods).

    Returns
    -------
    Truthy when every recorded condition holds, falsy otherwise.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the supported names.
    """
    cls_name = self.__class__.__name__
    status = PatchingConditionsChain(
        f'sklearn.ensemble.{cls_name}.{method_name}')

    if method_name == 'fit':
        fit_ok, _X, _y, sample_weight = self._onedal_fit_ready(status, *data)

        # Short-circuits: the CPU-specific conditions are only evaluated
        # when the common fit checks passed.
        dal_ready = fit_ok and status.and_conditions([
            (daal_check_version((2023, 'P', 200)),
             "ExtraTrees only supported starting from oneDAL version 2023.2"),
            (not sp.issparse(sample_weight),
             "sample_weight is sparse. "
             "Sparse input is not supported."),
        ])

        # An existing ``estimators_`` attribute disables the oneDAL path.
        dal_ready = dal_ready and not hasattr(self, 'estimators_')

        if dal_ready and self.random_state is not None:
            warnings.warn("Setting 'random_state' value is not supported. "
                          "State set by oneDAL to default value (777).",
                          RuntimeWarning)

    elif method_name in ('predict', 'predict_proba'):
        X = data[0]

        # NOTE(review): the version tuple below is (2023, 'P', 100) while
        # the message says 2023.2 - confirm which one is intended.
        dal_ready = status.and_conditions([
            (hasattr(self, '_onedal_model'), "oneDAL model was not trained."),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            (self.warm_start is False, "Warm start is not supported."),
            (daal_check_version((2023, 'P', 100)),
             "ExtraTrees only supported starting from oneDAL version 2023.2"),
        ])
        if not hasattr(self, 'n_outputs_'):
            dal_ready = False
        else:
            dal_ready = dal_ready and status.and_conditions([
                (self.n_outputs_ == 1,
                 f"Number of outputs ({self.n_outputs_}) is not 1."),
            ])

    else:
        raise RuntimeError(
            f'Unknown method {method_name} in {cls_name}')

    status.write_log()
    return dal_ready
|
|
603
|
-
|
|
604
|
-
def _onedal_gpu_supported(self, method_name, *data):
    """
    Decide whether ``method_name`` can run on the oneDAL GPU backend.

    Conditions are recorded on a ``PatchingConditionsChain`` named after
    the class and method; its log is written before returning.

    Parameters
    ----------
    method_name : str
        One of ``'fit'``, ``'predict'`` or ``'predict_proba'``.
    *data
        Arguments of the dispatched call (``X, y, sample_weight`` for
        ``fit``; ``X`` first for the prediction methods).

    Returns
    -------
    Truthy when every recorded condition holds, falsy otherwise.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the supported names.
    """
    cls_name = self.__class__.__name__
    status = PatchingConditionsChain(
        f'sklearn.ensemble.{cls_name}.{method_name}')

    if method_name == 'fit':
        fit_ok, _X, _y, sample_weight = self._onedal_fit_ready(status, *data)

        dal_ready = fit_ok and status.and_conditions([
            (daal_check_version((2023, 'P', 100)),
             "ExtraTrees only supported starting from oneDAL version 2023.1"),
            # A condition must be True for the backend to be usable. The
            # message states that sample weights are NOT supported, so the
            # supported case is "no sample_weight given". The previous
            # ``is not None`` test was inverted.
            (sample_weight is None, "sample_weight is not supported."),
        ])

        # An existing ``estimators_`` attribute disables the oneDAL path.
        dal_ready &= not hasattr(self, 'estimators_')

        if dal_ready and self.random_state is not None:
            warnings.warn("Setting 'random_state' value is not supported. "
                          "State set by oneDAL to default value (777).",
                          RuntimeWarning)

    elif method_name in ('predict', 'predict_proba'):
        X = data[0]

        dal_ready = (hasattr(self, '_onedal_model')
                     and hasattr(self, 'n_outputs_'))

        if dal_ready:
            dal_ready = status.and_conditions([
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (self.warm_start is False, "Warm start is not supported."),
                (daal_check_version((2023, 'P', 100)),
                 "ExtraTrees only supported starting from oneDAL version 2023.1"),
            ])
            # ``n_outputs_`` is known to exist here (checked above). ``&=``
            # keeps this condition evaluated even if an earlier one failed.
            dal_ready &= status.and_conditions([
                (self.n_outputs_ == 1,
                 f"Number of outputs ({self.n_outputs_}) is not 1."),
            ])

    else:
        raise RuntimeError(
            f'Unknown method {method_name} in {cls_name}')

    status.write_log()
    return dal_ready
|
|
652
|
-
|
|
653
|
-
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
    """
    Fit the classifier through the oneDAL estimator.

    X and y are validated as dense float32/float64 data, y is reshaped to
    a column, class weights are folded into ``sample_weight``, and a
    ``self._onedal_classifier`` instance is created from the collected
    parameters and fitted. Afterwards ``_save_attributes`` is called,
    ``estimators_`` is populated, and ``classes_`` / ``n_classes_`` are
    replaced by their first elements.

    Parameters
    ----------
    X, y : training data and targets.
    sample_weight : optional per-sample weights.
    queue : forwarded to the oneDAL estimator's ``fit``.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If fewer than two classes are present in y.
    """
    accepted_dtypes = [np.float64, np.float32]
    if sklearn_check_version('1.2'):
        X, y = self._validate_data(
            X, y, multi_output=False, accept_sparse=False,
            dtype=accepted_dtypes)
    else:
        X, y = check_X_y(
            X, y, accept_sparse=False, dtype=accepted_dtypes,
            multi_output=False)

    if sample_weight is not None:
        sample_weight = self.check_sample_weight(sample_weight, X)

    y = np.atleast_1d(y)
    passed_column_vector = y.ndim == 2 and y.shape[1] == 1
    if passed_column_vector:
        warnings.warn(
            "A column-vector y was passed when a 1d array was"
            " expected. Please change the shape of y to "
            "(n_samples,), for example using ravel().",
            DataConversionWarning,
            stacklevel=2,
        )
    if y.ndim == 1:
        # reshape (unlike [:, np.newaxis]) keeps the data contiguous.
        y = np.reshape(y, (-1, 1))

    y, expanded_class_weight = self._validate_y_class_weight(y)

    n_classes_ = self.n_classes_[0]
    self.n_features_in_ = X.shape[1]
    if not sklearn_check_version('1.0'):
        self.n_features_ = self.n_features_in_

    # Fold class weights into the per-sample weights.
    if expanded_class_weight is not None:
        sample_weight = (
            expanded_class_weight if sample_weight is None
            else sample_weight * expanded_class_weight)
    if sample_weight is not None:
        sample_weight = [sample_weight]

    if n_classes_ < 2:
        raise ValueError(
            "Training data only contain information about one class.")

    err = (
        'out_of_bag_error_accuracy|out_of_bag_error_decision_function'
        if self.oob_score else 'none')

    onedal_params = {
        'n_estimators': self.n_estimators,
        'criterion': self.criterion,
        'max_depth': self.max_depth,
        'min_samples_split': self.min_samples_split,
        'min_samples_leaf': self.min_samples_leaf,
        'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
        'max_features': self.max_features,
        'max_leaf_nodes': self.max_leaf_nodes,
        'min_impurity_decrease': self.min_impurity_decrease,
        'bootstrap': self.bootstrap,
        'oob_score': self.oob_score,
        'n_jobs': self.n_jobs,
        'random_state': self.random_state,
        'verbose': self.verbose,
        'warm_start': self.warm_start,
        'error_metric_mode': err,
        'variable_importance_mode': 'mdi',
        'class_weight': self.class_weight,
        'max_bins': self.max_bins,
        'min_bin_size': self.min_bin_size,
        'max_samples': self.max_samples,
    }
    if daal_check_version((2023, 'P', 101)):
        onedal_params['splitter_mode'] = "random"
    # ``min_impurity_split`` is only read from self before scikit-learn 1.0.
    onedal_params['min_impurity_split'] = (
        None if sklearn_check_version('1.0') else self.min_impurity_split)

    # Drop any previously built per-tree estimators.
    self._cached_estimators_ = None

    # Compute
    self._onedal_estimator = self._onedal_classifier(**onedal_params)
    self._onedal_estimator.fit(X, y, sample_weight, queue=queue)

    self._save_attributes()
    if sklearn_check_version("1.2"):
        self._estimator = ExtraTreeClassifier()
    self.estimators_ = self._estimators_

    # Decapsulate classes_ attributes
    self.n_classes_ = self.n_classes_[0]
    self.classes_ = self.classes_[0]
    return self
|
|
749
|
-
|
|
750
|
-
def _onedal_predict(self, X, queue=None):
    """
    Predict class labels with the fitted oneDAL estimator.

    X is validated as float32/float64, the estimator is checked to be
    fitted, and the oneDAL output is flattened, cast to int64 and used as
    indices into ``self.classes_``.

    Parameters
    ----------
    X : input samples.
    queue : forwarded to the oneDAL estimator's ``predict``.

    Returns
    -------
    ndarray
        Entries of ``self.classes_`` selected by the predicted indices.
    """
    X = check_array(X, dtype=[np.float32, np.float64])
    check_is_fitted(self)
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)

    raw = self._onedal_estimator.predict(X, queue=queue)
    label_idx = raw.ravel().astype(np.int64, casting='unsafe')
    return np.take(self.classes_, label_idx)
|
|
759
|
-
|
|
760
|
-
def _onedal_predict_proba(self, X, queue=None):
    """
    Return class probabilities from the fitted oneDAL estimator.

    X is validated as float64/float32, the estimator is checked to be
    fitted, and the feature count / feature names are verified (subject
    to the installed scikit-learn version) before delegating.

    Parameters
    ----------
    X : input samples.
    queue : forwarded to the oneDAL estimator's ``predict_proba``.
    """
    X = check_array(X, dtype=[np.float64, np.float32])
    check_is_fitted(self)

    # Each consistency check only exists from a given scikit-learn version.
    if sklearn_check_version('0.23'):
        self._check_n_features(X, reset=False)
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)

    onedal_est = self._onedal_estimator
    return onedal_est.predict_proba(X, queue=queue)
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
class ExtraTreesRegressor(sklearn_ExtraTreesRegressor, BaseTree):
|
|
771
|
-
# Reuse the docstring of the stock scikit-learn regressor.
__doc__ = sklearn_ExtraTreesRegressor.__doc__

if sklearn_check_version('1.2'):
    # Parent constraints plus the two extra parameters accepted here:
    # ``max_bins`` must be an integer >= 2, ``min_bin_size`` an integer >= 1.
    _parameter_constraints: dict = dict(
        sklearn_ExtraTreesRegressor._parameter_constraints,
        max_bins=[Interval(numbers.Integral, 2, None, closed="left")],
        min_bin_size=[Interval(numbers.Integral, 1, None, closed="left")],
    )
|
|
779
|
-
|
|
780
|
-
if sklearn_check_version('1.0'):
    def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.,
            max_features=1.0 if sklearn_check_version('1.1') else 'auto',
            max_leaf_nodes=None,
            min_impurity_decrease=0.,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1):
        # Constructor variant for scikit-learn >= 1.0 (no
        # ``min_impurity_split``). ``ccp_alpha`` and ``max_samples`` are
        # not forwarded to the parent; they are stored on self below.
        parent_kwargs = dict(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
        )
        super(ExtraTreesRegressor, self).__init__(**parent_kwargs)
        self.warm_start = warm_start
        self.ccp_alpha = ccp_alpha
        self.max_samples = max_samples
        self.max_bins = max_bins
        self.min_bin_size = min_bin_size
else:
    def __init__(self,
                 n_estimators=100, *,
                 criterion="mse",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=1,
                 min_weight_fraction_leaf=0.,
                 max_features="auto",
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 min_impurity_split=None,
                 bootstrap=False,
                 oob_score=False,
                 n_jobs=None,
                 random_state=None,
                 verbose=0,
                 warm_start=False,
                 ccp_alpha=0.0,
                 max_samples=None,
                 max_bins=256,
                 min_bin_size=1
                 ):
        # Constructor variant for scikit-learn < 1.0: ``criterion``
        # defaults to "mse" and ``min_impurity_split`` is still accepted.
        parent_kwargs = dict(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            ccp_alpha=ccp_alpha,
            max_samples=max_samples,
        )
        super(ExtraTreesRegressor, self).__init__(**parent_kwargs)
        self.warm_start = warm_start
        self.ccp_alpha = ccp_alpha
        self.max_samples = max_samples
        self.max_bins = max_bins
        self.min_bin_size = min_bin_size
|
|
873
|
-
|
|
874
|
-
@property
def _estimators_(self):
    """
    Build (or return the cached) list of per-tree ``ExtraTreeRegressor``
    objects from ``self._onedal_model``.

    A truthy ``_cached_estimators_`` is returned as is. Otherwise one
    estimator is cloned per tree, given a seed drawn from a generator
    derived from ``self.random_state``, and its ``tree_`` is populated
    from the state returned by ``get_tree_state_reg``.

    The freshly built list is stored in ``_cached_estimators_``.
    Previously the cache was consulted but never filled, so every access
    rebuilt all trees.
    """
    cached = getattr(self, '_cached_estimators_', None)
    if cached:
        return cached

    if sklearn_check_version('0.22'):
        check_is_fitted(self)
    else:
        check_is_fitted(self, '_onedal_model')

    # Parameters of the template tree estimator.
    params = {
        'criterion': self.criterion,
        'max_depth': self.max_depth,
        'min_samples_split': self.min_samples_split,
        'min_samples_leaf': self.min_samples_leaf,
        'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
        'max_features': self.max_features,
        'max_leaf_nodes': self.max_leaf_nodes,
        'min_impurity_decrease': self.min_impurity_decrease,
        'random_state': None,
    }
    if not sklearn_check_version('1.0'):
        params['min_impurity_split'] = self.min_impurity_split
    est = ExtraTreeRegressor(**params)

    # Each tree_ is filled from the oneDAL model rather than trained here.
    estimators_ = []
    random_state_checked = check_random_state(self.random_state)
    seed_bound = np.iinfo(np.int32).max
    for i in range(self.n_estimators):
        est_i = clone(est)
        est_i.set_params(
            random_state=random_state_checked.randint(seed_bound))
        if sklearn_check_version('1.0'):
            est_i.n_features_in_ = self.n_features_in_
        else:
            est_i.n_features_ = self.n_features_in_
        est_i.n_classes_ = 1
        est_i.n_outputs_ = self.n_outputs_

        tree_i_state_class = get_tree_state_reg(self._onedal_model, i)
        tree_i_state_dict = {
            'max_depth': tree_i_state_class.max_depth,
            'node_count': tree_i_state_class.node_count,
            'nodes': check_tree_nodes(tree_i_state_class.node_ar),
            'values': tree_i_state_class.value_ar,
        }

        est_i.tree_ = Tree(
            self.n_features_in_,
            np.array([1], dtype=np.intp),
            self.n_outputs_)
        est_i.tree_.__setstate__(tree_i_state_dict)
        estimators_.append(est_i)

    # Remember the result so later accesses reuse it.
    self._cached_estimators_ = estimators_
    return estimators_
|
|
929
|
-
|
|
930
|
-
def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
    """
    Run the checks shared by the CPU and GPU ``fit`` support queries.

    Estimator parameters are validated, support conditions are recorded
    on ``patching_status``, and, when those hold, X and y are converted
    to arrays (y as a column) and ``n_outputs_`` is set.

    ``max_samples`` is validated without assuming X has a ``shape``
    attribute: the sample count comes from ``_num_samples`` and is only
    computed when ``max_samples`` is an integer, so an unconverted
    list-like X cannot raise ``AttributeError`` here.

    Parameters
    ----------
    patching_status : object exposing ``and_conditions``.
    X, y, sample_weight : arguments of the ``fit`` call.

    Returns
    -------
    tuple
        ``(ready, X, y, sample_weight)``; X and y are the converted
        arrays when the first group of conditions held.

    Raises
    ------
    ValueError
        For sparse y, ``oob_score`` without ``bootstrap``, an
        out-of-range ``max_samples``, or ``max_samples`` set while
        ``bootstrap`` is False.
    TypeError
        If ``max_samples`` is neither a number nor None.
    """
    if sp.issparse(y):
        raise ValueError(
            "sparse multilabel-indicator for y is not supported."
        )

    if sklearn_check_version("1.2"):
        self._validate_params()
    else:
        self._check_parameters()

    if not self.bootstrap and self.oob_score:
        raise ValueError("Out of bag estimation only available"
                         " if bootstrap=True")

    if sklearn_check_version('1.0') and self.criterion == "mse":
        warnings.warn(
            "Criterion 'mse' was deprecated in v1.0 and will be "
            "removed in version 1.2. Use `criterion='squared_error'` "
            "which is equivalent.",
            FutureWarning
        )

    ready = patching_status.and_conditions([
        (not self.oob_score or daal_check_version((2021, 'P', 500)),
         "OOB score is only supported starting from 2021.5 version of oneDAL."),
        (self.warm_start is False, "Warm start is not supported."),
        (self.criterion in ["mse", "squared_error"],
         f"'{self.criterion}' criterion is not supported. "
         "Only 'mse' and 'squared_error' criteria are supported."),
        (self.ccp_alpha == 0.0,
         f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported."),
        (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
        (self.n_estimators <= 6024, "More than 6024 estimators is not supported.")
    ])

    if ready:
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=True)
        X = check_array(X, dtype=[np.float64, np.float32])
        y = np.atleast_1d(np.asarray(y))

        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn("A column-vector y was passed when a 1d array was"
                          " expected. Please change the shape of y to "
                          "(n_samples,), for example using ravel().",
                          DataConversionWarning, stacklevel=2)

        y = check_array(y, ensure_2d=False, dtype=X.dtype)
        check_consistent_length(X, y)

        if y.ndim == 1:
            # reshape (unlike [:, np.newaxis]) keeps the data contiguous.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]
        ready = patching_status.and_conditions([
            (self.n_outputs_ == 1,
             f"Number of outputs ({self.n_outputs_}) is not 1.")
        ])

    max_samples = self.max_samples
    if isinstance(max_samples, numbers.Integral):
        # X may still be the caller's raw input here, so do not rely on
        # a ``shape`` attribute.
        n_samples = _num_samples(X)
        if not sklearn_check_version('1.2'):
            if not (1 <= max_samples <= n_samples):
                msg = "`max_samples` must be in range 1 to {} but got value {}"
                raise ValueError(msg.format(n_samples, max_samples))
        elif max_samples > n_samples:
            msg = "`max_samples` must be <= n_samples={} but got value {}"
            raise ValueError(msg.format(n_samples, max_samples))
    elif isinstance(max_samples, numbers.Real):
        if sklearn_check_version('1.2'):
            # No range check is performed in this branch.
            pass
        elif sklearn_check_version('1.0'):
            if not (0 < float(max_samples) <= 1):
                msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
                raise ValueError(msg.format(max_samples))
        else:
            if not (0 < float(max_samples) < 1):
                msg = "`max_samples` must be in range (0, 1) but got value {}"
                raise ValueError(msg.format(max_samples))
    elif max_samples is not None:
        msg = "`max_samples` should be int or float, but got type '{}'"
        raise TypeError(msg.format(type(max_samples)))

    if not self.bootstrap and max_samples is not None:
        raise ValueError(
            "`max_sample` cannot be set if `bootstrap=False`. "
            "Either switch to `bootstrap=True` or set "
            "`max_sample=None`."
        )

    return ready, X, y, sample_weight
|
|
1027
|
-
|
|
1028
|
-
def _onedal_cpu_supported(self, method_name, *data):
    """
    Decide whether ``method_name`` can run on the oneDAL CPU backend.

    Conditions are recorded on a ``PatchingConditionsChain`` named after
    the class and method; its log is written before returning.

    Parameters
    ----------
    method_name : str
        ``'fit'``, ``'predict'`` or ``'predict_proba'``.
    *data
        Arguments of the dispatched call (``X, y, sample_weight`` for
        ``fit``; ``X`` first for the prediction methods).

    Returns
    -------
    Truthy when every recorded condition holds, falsy otherwise.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the supported names.
    """
    cls_name = self.__class__.__name__
    status = PatchingConditionsChain(
        f'sklearn.ensemble.{cls_name}.{method_name}')

    if method_name == 'fit':
        fit_ok, _X, _y, sample_weight = self._onedal_fit_ready(status, *data)

        # Short-circuits: the CPU-specific conditions are only evaluated
        # when the common fit checks passed.
        dal_ready = fit_ok and status.and_conditions([
            (daal_check_version((2023, 'P', 200)),
             "ExtraTrees only supported starting from oneDAL version 2023.2"),
            (not sp.issparse(sample_weight),
             "sample_weight is sparse. "
             "Sparse input is not supported."),
        ])

        # An existing ``estimators_`` attribute disables the oneDAL path.
        dal_ready &= not hasattr(self, 'estimators_')

        if dal_ready and self.random_state is not None:
            warnings.warn("Setting 'random_state' value is not supported. "
                          "State set by oneDAL to default value (777).",
                          RuntimeWarning)

    elif method_name in ('predict', 'predict_proba'):
        X = data[0]

        dal_ready = status.and_conditions([
            (hasattr(self, '_onedal_model'), "oneDAL model was not trained."),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            (self.warm_start is False, "Warm start is not supported."),
            (daal_check_version((2023, 'P', 200)),
             "ExtraTrees only supported starting from oneDAL version 2023.2"),
        ])
        if not hasattr(self, 'n_outputs_'):
            dal_ready = False
        else:
            dal_ready &= status.and_conditions([
                (self.n_outputs_ == 1,
                 f"Number of outputs ({self.n_outputs_}) is not 1."),
            ])

    else:
        raise RuntimeError(
            f'Unknown method {method_name} in {cls_name}')

    status.write_log()
    return dal_ready
|
|
1076
|
-
|
|
1077
|
-
def _onedal_gpu_supported(self, method_name, *data):
    """
    Decide whether ``method_name`` can run on the oneDAL GPU backend.

    Conditions are recorded on a ``PatchingConditionsChain`` named after
    the class and method; its log is written before returning.

    Parameters
    ----------
    method_name : str
        ``'fit'``, ``'predict'`` or ``'predict_proba'``.
    *data
        Arguments of the dispatched call (``X, y, sample_weight`` for
        ``fit``; ``X`` first for the prediction methods).

    Returns
    -------
    Truthy when every recorded condition holds, falsy otherwise.

    Raises
    ------
    RuntimeError
        If ``method_name`` is not one of the supported names.
    """
    cls_name = self.__class__.__name__
    status = PatchingConditionsChain(
        f'sklearn.ensemble.{cls_name}.{method_name}')

    if method_name == 'fit':
        fit_ok, _X, _y, sample_weight = self._onedal_fit_ready(status, *data)

        dal_ready = fit_ok and status.and_conditions([
            (daal_check_version((2023, 'P', 100)),
             "ExtraTrees only supported starting from oneDAL version 2023.1"),
            # A condition must be True for the backend to be usable. The
            # message states that sample weights are NOT supported, so the
            # supported case is "no sample_weight given". The previous
            # ``is not None`` test was inverted.
            (sample_weight is None, "sample_weight is not supported."),
        ])

        # An existing ``estimators_`` attribute disables the oneDAL path.
        dal_ready &= not hasattr(self, 'estimators_')

        if dal_ready and self.random_state is not None:
            warnings.warn("Setting 'random_state' value is not supported. "
                          "State set by oneDAL to default value (777).",
                          RuntimeWarning)

    elif method_name in ('predict', 'predict_proba'):
        X = data[0]

        dal_ready = status.and_conditions([
            (hasattr(self, '_onedal_model'), "oneDAL model was not trained."),
            (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
            (self.warm_start is False, "Warm start is not supported."),
            (daal_check_version((2023, 'P', 100)),
             "ExtraTrees only supported starting from oneDAL version 2023.1"),
        ])
        if hasattr(self, 'n_outputs_'):
            dal_ready &= status.and_conditions([
                (self.n_outputs_ == 1,
                 f"Number of outputs ({self.n_outputs_}) is not 1."),
            ])

    else:
        raise RuntimeError(
            f'Unknown method {method_name} in {cls_name}')

    status.write_log()
    return dal_ready
|
|
1123
|
-
|
|
1124
|
-
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
    """
    Fit the regressor through the oneDAL estimator.

    Parameters are validated, X and y are converted to float arrays of a
    common dtype, and a ``self._onedal_regressor`` instance is created
    from the collected parameters and fitted. Afterwards
    ``_save_attributes`` is called and ``estimators_`` is populated.

    Parameters
    ----------
    X, y : training data and targets.
    sample_weight : optional per-sample weights.
    queue : forwarded to the oneDAL estimator's ``fit``.

    Returns
    -------
    self

    Raises
    ------
    ValueError
        If y is sparse.
    """
    if sp.issparse(y):
        raise ValueError(
            "sparse multilabel-indicator for y is not supported."
        )

    if sklearn_check_version("1.2"):
        self._validate_params()
    else:
        self._check_parameters()

    if sample_weight is not None:
        sample_weight = self.check_sample_weight(sample_weight, X)

    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=True)

    X = check_array(X, dtype=[np.float64, np.float32])
    y = np.asarray(y)
    y = np.atleast_1d(y)
    passed_column_vector = y.ndim == 2 and y.shape[1] == 1
    if passed_column_vector:
        warnings.warn(
            "A column-vector y was passed when a 1d array was"
            " expected. Please change the shape of y to "
            "(n_samples,), for example using ravel().",
            DataConversionWarning,
            stacklevel=2)
    # y is converted to the dtype chosen for X.
    y = check_array(y, ensure_2d=False, dtype=X.dtype)
    check_consistent_length(X, y)

    self.n_features_in_ = X.shape[1]
    if not sklearn_check_version('1.0'):
        self.n_features_ = self.n_features_in_
    rs_ = check_random_state(self.random_state)

    err = (
        'out_of_bag_error_r2|out_of_bag_error_prediction'
        if self.oob_score else 'none')

    onedal_params = {
        'n_estimators': self.n_estimators,
        'criterion': self.criterion,
        'max_depth': self.max_depth,
        'min_samples_split': self.min_samples_split,
        'min_samples_leaf': self.min_samples_leaf,
        'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
        'max_features': self.max_features,
        'max_leaf_nodes': self.max_leaf_nodes,
        'min_impurity_decrease': self.min_impurity_decrease,
        'bootstrap': self.bootstrap,
        'oob_score': self.oob_score,
        'n_jobs': self.n_jobs,
        'random_state': rs_,
        'verbose': self.verbose,
        'warm_start': self.warm_start,
        'error_metric_mode': err,
        'variable_importance_mode': 'mdi',
        'max_samples': self.max_samples,
    }
    if daal_check_version((2023, 'P', 101)):
        onedal_params['splitter_mode'] = "random"

    # Drop any previously built per-tree estimators.
    self._cached_estimators_ = None

    self._onedal_estimator = self._onedal_regressor(**onedal_params)
    self._onedal_estimator.fit(X, y, sample_weight, queue=queue)

    self._save_attributes()
    if sklearn_check_version("1.2"):
        self._estimator = ExtraTreeRegressor()
    self.estimators_ = self._estimators_
    return self
|
|
1189
|
-
|
|
1190
|
-
def _onedal_predict(self, X, queue=None):
    """Predict with the fitted oneDAL estimator.

    Parameters
    ----------
    X : array-like
        Input samples. They are passed through ``self._validate_X_predict``
        before being handed to the oneDAL estimator.
    queue : object, default=None
        Forwarded unchanged to ``self._onedal_estimator.predict``.
        NOTE(review): presumably the device queue used for offloading;
        confirm against the caller.

    Returns
    -------
    The value returned by ``self._onedal_estimator.predict``.
    """
    # Feature-name checking is gated on ``sklearn_check_version("1.0")``.
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)

    # NOTE(review): nesting was reconstructed from a whitespace-stripped
    # listing; validation is assumed to run regardless of the version gate.
    validated = self._validate_X_predict(X)
    estimator = self._onedal_estimator
    return estimator.predict(validated, queue=queue)
|
|
1195
|
-
|
|
1196
|
-
def fit(self, X, y, sample_weight=None):
    """
    Build a forest of trees from the training set (X, y).

    The call is handed to ``dispatch`` together with two candidate
    implementations: the oneDAL one (``_onedal_fit`` of this class) and
    the stock scikit-learn ``ExtraTreesRegressor.fit``.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. In the stock scikit-learn
        implementation the dtype is converted to ``dtype=np.float32``
        and a sparse matrix is converted into a sparse ``csc_matrix``.

    y : array-like of shape (n_samples,) or (n_samples, n_outputs)
        The target values (real numbers in regression).

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. If None, then samples are equally weighted.
        Splits that would create child nodes with net zero or negative
        weight are ignored while searching for a split in each node.

    Returns
    -------
    self : object
        This estimator; the return value of ``dispatch`` is discarded.
    """
    # Candidate implementations, keyed by backend name.
    backends = {
        'onedal': self.__class__._onedal_fit,
        'sklearn': sklearn_ExtraTreesRegressor.fit,
    }
    dispatch(self, 'fit', backends, X, y, sample_weight)
    return self
|
|
1227
|
-
|
|
1228
|
-
@wrap_output_data
def predict(self, X):
    """
    Predict regression target for X.

    The call is handed to ``dispatch`` together with two candidate
    implementations: the oneDAL one (``_onedal_predict`` of this class)
    and the stock scikit-learn ``ExtraTreesRegressor.predict``. In
    scikit-learn, the predicted regression target of an input sample is
    computed as the mean predicted regression targets of the trees in
    the forest.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. In the stock scikit-learn implementation the
        dtype is converted to ``dtype=np.float32`` and a sparse matrix
        is converted into a sparse ``csr_matrix``.

    Returns
    -------
    y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
        The predicted values.
        NOTE(review): the result passes through the ``wrap_output_data``
        decorator, defined outside this block; confirm the final
        container type against that helper.
    """
    return dispatch(self, 'predict', {
        'onedal': self.__class__._onedal_predict,
        'sklearn': sklearn_ExtraTreesRegressor.predict,
    }, X)
|
|
1254
|
-
|
|
1255
|
-
if sklearn_check_version('1.0'):
    # When the version gate passes, expose ``n_features_`` as a read-only
    # property wrapped by ``deprecated``; it simply forwards to
    # ``n_features_in_``. ``deprecated`` is applied on top of ``property``.
    @deprecated(
        "Attribute `n_features_` was deprecated in version 1.0 "
        "and will be removed in 1.2. "
        "Use `n_features_in_` instead.")
    @property
    def n_features_(self):
        # Alias kept for callers still using the old attribute name.
        return self.n_features_in_
|