scikit-learn-intelex 2025.0.0__py311-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic; see the package's registry page for more details.
- daal4py/__init__.py +73 -0
- daal4py/__main__.py +58 -0
- daal4py/_daal4py.cpython-311-x86_64-linux-gnu.so +0 -0
- daal4py/doc/third-party-programs.txt +424 -0
- daal4py/mb/__init__.py +19 -0
- daal4py/mb/model_builders.py +377 -0
- daal4py/mpi_transceiver.cpython-311-x86_64-linux-gnu.so +0 -0
- daal4py/sklearn/__init__.py +40 -0
- daal4py/sklearn/_n_jobs_support.py +242 -0
- daal4py/sklearn/_utils.py +241 -0
- daal4py/sklearn/cluster/__init__.py +20 -0
- daal4py/sklearn/cluster/dbscan.py +165 -0
- daal4py/sklearn/cluster/k_means.py +597 -0
- daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- daal4py/sklearn/decomposition/__init__.py +19 -0
- daal4py/sklearn/decomposition/_pca.py +524 -0
- daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
- daal4py/sklearn/ensemble/__init__.py +27 -0
- daal4py/sklearn/ensemble/_forest.py +1397 -0
- daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- daal4py/sklearn/linear_model/__init__.py +29 -0
- daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- daal4py/sklearn/linear_model/_linear.py +272 -0
- daal4py/sklearn/linear_model/_ridge.py +325 -0
- daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- daal4py/sklearn/linear_model/linear.py +17 -0
- daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- daal4py/sklearn/linear_model/ridge.py +17 -0
- daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
- daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- daal4py/sklearn/manifold/__init__.py +19 -0
- daal4py/sklearn/manifold/_t_sne.py +405 -0
- daal4py/sklearn/metrics/__init__.py +20 -0
- daal4py/sklearn/metrics/_pairwise.py +155 -0
- daal4py/sklearn/metrics/_ranking.py +210 -0
- daal4py/sklearn/model_selection/__init__.py +19 -0
- daal4py/sklearn/model_selection/_split.py +309 -0
- daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- daal4py/sklearn/monkeypatch/__init__.py +0 -0
- daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
- daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
- daal4py/sklearn/neighbors/__init__.py +21 -0
- daal4py/sklearn/neighbors/_base.py +503 -0
- daal4py/sklearn/neighbors/_classification.py +139 -0
- daal4py/sklearn/neighbors/_regression.py +74 -0
- daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- daal4py/sklearn/svm/__init__.py +19 -0
- daal4py/sklearn/svm/svm.py +734 -0
- daal4py/sklearn/utils/__init__.py +21 -0
- daal4py/sklearn/utils/base.py +75 -0
- daal4py/sklearn/utils/tests/test_utils.py +51 -0
- daal4py/sklearn/utils/validation.py +693 -0
- onedal/__init__.py +83 -0
- onedal/_config.py +53 -0
- onedal/_device_offload.py +229 -0
- onedal/_onedal_py_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_host.cpython-311-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_spmd_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
- onedal/basic_statistics/__init__.py +20 -0
- onedal/basic_statistics/basic_statistics.py +107 -0
- onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- onedal/cluster/__init__.py +27 -0
- onedal/cluster/dbscan.py +110 -0
- onedal/cluster/kmeans.py +560 -0
- onedal/cluster/kmeans_init.py +115 -0
- onedal/cluster/tests/test_dbscan.py +125 -0
- onedal/cluster/tests/test_kmeans.py +88 -0
- onedal/cluster/tests/test_kmeans_init.py +93 -0
- onedal/common/_base.py +38 -0
- onedal/common/_estimator_checks.py +47 -0
- onedal/common/_mixin.py +62 -0
- onedal/common/_policy.py +59 -0
- onedal/common/_spmd_policy.py +30 -0
- onedal/common/hyperparameters.py +116 -0
- onedal/common/tests/test_policy.py +75 -0
- onedal/covariance/__init__.py +20 -0
- onedal/covariance/covariance.py +125 -0
- onedal/covariance/incremental_covariance.py +146 -0
- onedal/covariance/tests/test_covariance.py +50 -0
- onedal/covariance/tests/test_incremental_covariance.py +122 -0
- onedal/datatypes/__init__.py +19 -0
- onedal/datatypes/_data_conversion.py +95 -0
- onedal/datatypes/tests/test_data.py +235 -0
- onedal/decomposition/__init__.py +20 -0
- onedal/decomposition/incremental_pca.py +204 -0
- onedal/decomposition/pca.py +186 -0
- onedal/decomposition/tests/test_incremental_pca.py +198 -0
- onedal/ensemble/__init__.py +29 -0
- onedal/ensemble/forest.py +720 -0
- onedal/ensemble/tests/test_random_forest.py +97 -0
- onedal/linear_model/__init__.py +27 -0
- onedal/linear_model/incremental_linear_model.py +258 -0
- onedal/linear_model/linear_model.py +329 -0
- onedal/linear_model/logistic_regression.py +249 -0
- onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- onedal/linear_model/tests/test_linear_regression.py +149 -0
- onedal/linear_model/tests/test_logistic_regression.py +95 -0
- onedal/linear_model/tests/test_ridge.py +95 -0
- onedal/neighbors/__init__.py +19 -0
- onedal/neighbors/neighbors.py +778 -0
- onedal/neighbors/tests/test_knn_classification.py +49 -0
- onedal/primitives/__init__.py +27 -0
- onedal/primitives/get_tree.py +25 -0
- onedal/primitives/kernel_functions.py +153 -0
- onedal/primitives/tests/test_kernel_functions.py +159 -0
- onedal/spmd/__init__.py +25 -0
- onedal/spmd/_base.py +30 -0
- onedal/spmd/basic_statistics/__init__.py +20 -0
- onedal/spmd/basic_statistics/basic_statistics.py +30 -0
- onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
- onedal/spmd/cluster/__init__.py +28 -0
- onedal/spmd/cluster/dbscan.py +23 -0
- onedal/spmd/cluster/kmeans.py +56 -0
- onedal/spmd/covariance/__init__.py +20 -0
- onedal/spmd/covariance/covariance.py +26 -0
- onedal/spmd/covariance/incremental_covariance.py +82 -0
- onedal/spmd/decomposition/__init__.py +20 -0
- onedal/spmd/decomposition/incremental_pca.py +117 -0
- onedal/spmd/decomposition/pca.py +26 -0
- onedal/spmd/ensemble/__init__.py +19 -0
- onedal/spmd/ensemble/forest.py +28 -0
- onedal/spmd/linear_model/__init__.py +21 -0
- onedal/spmd/linear_model/incremental_linear_model.py +97 -0
- onedal/spmd/linear_model/linear_model.py +30 -0
- onedal/spmd/linear_model/logistic_regression.py +38 -0
- onedal/spmd/neighbors/__init__.py +19 -0
- onedal/spmd/neighbors/neighbors.py +75 -0
- onedal/svm/__init__.py +19 -0
- onedal/svm/svm.py +556 -0
- onedal/svm/tests/test_csr_svm.py +351 -0
- onedal/svm/tests/test_nusvc.py +204 -0
- onedal/svm/tests/test_nusvr.py +210 -0
- onedal/svm/tests/test_svc.py +168 -0
- onedal/svm/tests/test_svr.py +243 -0
- onedal/tests/test_common.py +41 -0
- onedal/tests/utils/_dataframes_support.py +168 -0
- onedal/tests/utils/_device_selection.py +107 -0
- onedal/utils/__init__.py +49 -0
- onedal/utils/_array_api.py +91 -0
- onedal/utils/validation.py +432 -0
- scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
- scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
- scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
- sklearnex/__init__.py +65 -0
- sklearnex/__main__.py +58 -0
- sklearnex/_config.py +98 -0
- sklearnex/_device_offload.py +121 -0
- sklearnex/_utils.py +109 -0
- sklearnex/basic_statistics/__init__.py +20 -0
- sklearnex/basic_statistics/basic_statistics.py +140 -0
- sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
- sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
- sklearnex/cluster/__init__.py +20 -0
- sklearnex/cluster/dbscan.py +192 -0
- sklearnex/cluster/k_means.py +383 -0
- sklearnex/cluster/tests/test_dbscan.py +38 -0
- sklearnex/cluster/tests/test_kmeans.py +153 -0
- sklearnex/conftest.py +73 -0
- sklearnex/covariance/__init__.py +19 -0
- sklearnex/covariance/incremental_covariance.py +368 -0
- sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
- sklearnex/decomposition/__init__.py +19 -0
- sklearnex/decomposition/pca.py +414 -0
- sklearnex/decomposition/tests/test_pca.py +58 -0
- sklearnex/dispatcher.py +543 -0
- sklearnex/doc/third-party-programs.txt +424 -0
- sklearnex/ensemble/__init__.py +29 -0
- sklearnex/ensemble/_forest.py +2016 -0
- sklearnex/ensemble/tests/test_forest.py +120 -0
- sklearnex/glob/__main__.py +72 -0
- sklearnex/glob/dispatcher.py +101 -0
- sklearnex/linear_model/__init__.py +32 -0
- sklearnex/linear_model/coordinate_descent.py +30 -0
- sklearnex/linear_model/incremental_linear.py +463 -0
- sklearnex/linear_model/incremental_ridge.py +418 -0
- sklearnex/linear_model/linear.py +302 -0
- sklearnex/linear_model/logistic_path.py +17 -0
- sklearnex/linear_model/logistic_regression.py +403 -0
- sklearnex/linear_model/ridge.py +24 -0
- sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
- sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- sklearnex/linear_model/tests/test_linear.py +142 -0
- sklearnex/linear_model/tests/test_logreg.py +134 -0
- sklearnex/manifold/__init__.py +19 -0
- sklearnex/manifold/t_sne.py +21 -0
- sklearnex/manifold/tests/test_tsne.py +26 -0
- sklearnex/metrics/__init__.py +23 -0
- sklearnex/metrics/pairwise.py +22 -0
- sklearnex/metrics/ranking.py +20 -0
- sklearnex/metrics/tests/test_metrics.py +39 -0
- sklearnex/model_selection/__init__.py +21 -0
- sklearnex/model_selection/split.py +22 -0
- sklearnex/model_selection/tests/test_model_selection.py +34 -0
- sklearnex/neighbors/__init__.py +27 -0
- sklearnex/neighbors/_lof.py +231 -0
- sklearnex/neighbors/common.py +310 -0
- sklearnex/neighbors/knn_classification.py +226 -0
- sklearnex/neighbors/knn_regression.py +203 -0
- sklearnex/neighbors/knn_unsupervised.py +170 -0
- sklearnex/neighbors/tests/test_neighbors.py +80 -0
- sklearnex/preview/__init__.py +17 -0
- sklearnex/preview/covariance/__init__.py +19 -0
- sklearnex/preview/covariance/covariance.py +133 -0
- sklearnex/preview/covariance/tests/test_covariance.py +66 -0
- sklearnex/preview/decomposition/__init__.py +19 -0
- sklearnex/preview/decomposition/incremental_pca.py +228 -0
- sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- sklearnex/preview/linear_model/__init__.py +19 -0
- sklearnex/preview/linear_model/ridge.py +419 -0
- sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- sklearnex/spmd/__init__.py +25 -0
- sklearnex/spmd/basic_statistics/__init__.py +20 -0
- sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
- sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- sklearnex/spmd/cluster/__init__.py +30 -0
- sklearnex/spmd/cluster/dbscan.py +50 -0
- sklearnex/spmd/cluster/kmeans.py +21 -0
- sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- sklearnex/spmd/covariance/__init__.py +20 -0
- sklearnex/spmd/covariance/covariance.py +21 -0
- sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- sklearnex/spmd/decomposition/__init__.py +20 -0
- sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- sklearnex/spmd/decomposition/pca.py +21 -0
- sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- sklearnex/spmd/ensemble/__init__.py +19 -0
- sklearnex/spmd/ensemble/forest.py +71 -0
- sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- sklearnex/spmd/linear_model/__init__.py +21 -0
- sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- sklearnex/spmd/linear_model/linear_model.py +21 -0
- sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
- sklearnex/spmd/neighbors/__init__.py +19 -0
- sklearnex/spmd/neighbors/neighbors.py +25 -0
- sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- sklearnex/svm/__init__.py +29 -0
- sklearnex/svm/_common.py +328 -0
- sklearnex/svm/nusvc.py +332 -0
- sklearnex/svm/nusvr.py +148 -0
- sklearnex/svm/svc.py +360 -0
- sklearnex/svm/svr.py +149 -0
- sklearnex/svm/tests/test_svm.py +93 -0
- sklearnex/tests/_utils.py +328 -0
- sklearnex/tests/_utils_spmd.py +198 -0
- sklearnex/tests/test_common.py +54 -0
- sklearnex/tests/test_config.py +43 -0
- sklearnex/tests/test_memory_usage.py +291 -0
- sklearnex/tests/test_monkeypatch.py +276 -0
- sklearnex/tests/test_n_jobs_support.py +103 -0
- sklearnex/tests/test_parallel.py +48 -0
- sklearnex/tests/test_patching.py +385 -0
- sklearnex/tests/test_run_to_run_stability.py +296 -0
- sklearnex/utils/__init__.py +19 -0
- sklearnex/utils/_array_api.py +82 -0
- sklearnex/utils/parallel.py +59 -0
- sklearnex/utils/tests/test_finite.py +89 -0
- sklearnex/utils/validation.py +17 -0
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
# daal4py Model builders API
|
|
18
|
+
|
|
19
|
+
from typing import Literal, Optional
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
import daal4py as d4p
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
from pandas import DataFrame
|
|
27
|
+
from pandas.core.dtypes.cast import find_common_type
|
|
28
|
+
|
|
29
|
+
pandas_is_imported = True
|
|
30
|
+
except (ImportError, ModuleNotFoundError):
|
|
31
|
+
pandas_is_imported = False
|
|
32
|
+
|
|
33
|
+
from sklearn.utils.metaestimators import available_if
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def parse_dtype(dt):
    """Translate a floating-point dtype into daal4py's ``fptype`` keyword value.

    Parameters
    ----------
    dt : dtype-like or None
        The dtype to translate; compared against ``np.double`` first and then
        ``np.single``.

    Returns
    -------
    str
        ``"double"`` for double precision, ``"float"`` for single precision.

    Raises
    ------
    ValueError
        If ``dt`` matches neither supported floating-point type.
    """
    supported = ((np.double, "double"), (np.single, "float"))
    for numpy_type, fptype_name in supported:
        if dt == numpy_type:
            return fptype_name
    raise ValueError(f"Input array has unexpected dtype = {dt}")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def getFPType(X):
    """Return the daal4py ``fptype`` string (``"double"``/``"float"``) for ``X``.

    For a pandas ``DataFrame`` (only recognised when pandas could be imported)
    the per-column dtypes are first reduced to one common dtype. Any other
    input is classified by its ``dtype`` attribute, or by ``None`` when it has
    none, which makes ``parse_dtype`` raise ``ValueError``.
    """
    if pandas_is_imported and isinstance(X, DataFrame):
        common_dtype = find_common_type(X.dtypes.tolist())
        return parse_dtype(common_dtype)
    return parse_dtype(getattr(X, "dtype", None))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class GBTDAALBaseModel:
    """Shared conversion and prediction logic for daal4py GBT model wrappers.

    A concrete wrapper is populated from a trained XGBoost, LightGBM or
    CatBoost model through :meth:`_convert_model`. Wrappers are matched by
    class name: ``GBTDAALClassifier``, ``GBTDAALRegressor`` or
    ``GBTDAALModel``. A successful conversion sets ``daal_model_`` and
    ``n_features_in_``. It also sets ``n_classes_``: always for LightGBM and
    XGBoost sources, and for CatBoost only when class parameters are present.
    """

    # Maps (module, class name) of a supported external model to:
    #   * the name of the only GBTDAAL* class allowed to wrap it,
    #   * the name of the conversion method to call,
    #   * an accessor extracting the object that conversion method expects
    #     (the sklearn-style wrappers hold the booster instead of being one).
    _SUPPORTED_MODELS = {
        ("lightgbm.sklearn", "LGBMClassifier"): (
            "GBTDAALClassifier",
            "_convert_model_from_lightgbm",
            lambda model: model.booster_,
        ),
        ("xgboost.sklearn", "XGBClassifier"): (
            "GBTDAALClassifier",
            "_convert_model_from_xgboost",
            lambda model: model.get_booster(),
        ),
        ("catboost.core", "CatBoostClassifier"): (
            "GBTDAALClassifier",
            "_convert_model_from_catboost",
            lambda model: model,
        ),
        ("lightgbm.sklearn", "LGBMRegressor"): (
            "GBTDAALRegressor",
            "_convert_model_from_lightgbm",
            lambda model: model.booster_,
        ),
        ("xgboost.sklearn", "XGBRegressor"): (
            "GBTDAALRegressor",
            "_convert_model_from_xgboost",
            lambda model: model.get_booster(),
        ),
        ("catboost.core", "CatBoostRegressor"): (
            "GBTDAALRegressor",
            "_convert_model_from_catboost",
            lambda model: model,
        ),
        ("lightgbm.basic", "Booster"): (
            "GBTDAALModel",
            "_convert_model_from_lightgbm",
            lambda model: model,
        ),
        ("xgboost.core", "Booster"): (
            "GBTDAALModel",
            "_convert_model_from_xgboost",
            lambda model: model,
        ),
        ("catboost.core", "CatBoost"): (
            "GBTDAALModel",
            "_convert_model_from_catboost",
            lambda model: model,
        ),
    }

    def __init__(self):
        # Library the wrapped model came from; None until assigned by the caller
        # (``convert_model`` sets it after conversion).
        self.model_type: Optional[Literal["xgboost", "catboost", "lightgbm"]] = None

    @property
    def _is_regression(self):
        """True when an attached ``daal_model_`` is a daal4py GBT regression model."""
        return hasattr(self, "daal_model_") and isinstance(
            self.daal_model_, d4p.gbt_regression_model
        )

    def _get_params_from_lightgbm(self, params):
        """Set ``n_classes_`` and ``n_features_in_`` from LightGBM model parameters."""
        self.n_classes_ = params["num_tree_per_iteration"]
        objective_fun = params["objective"]
        if self.n_classes_ <= 2:
            # Binary objectives report one tree per iteration (nClasses == 1),
            # but the model still separates two classes.
            if "binary" in objective_fun:
                self.n_classes_ = 2

        self.n_features_in_ = params["max_feature_idx"] + 1

    def _get_params_from_xgboost(self, params):
        """Set ``n_classes_`` and ``n_features_in_`` from XGBoost model parameters."""
        self.n_classes_ = int(params["learner"]["learner_model_param"]["num_class"])
        objective_fun = params["learner"]["learner_train_param"]["objective"]
        if self.n_classes_ <= 2:
            # Binary objectives store num_class <= 2; normalise to two classes.
            if objective_fun in ["binary:logistic", "binary:logitraw"]:
                self.n_classes_ = 2

        self.n_features_in_ = int(params["learner"]["learner_model_param"]["num_feature"])

    def _get_params_from_catboost(self, params):
        """Set ``n_features_in_`` (and ``n_classes_`` when available) from CatBoost params."""
        if "class_params" in params["model_info"]:
            self.n_classes_ = len(params["model_info"]["class_params"]["class_to_label"])
        self.n_features_in_ = len(params["features_info"]["float_features"])

    def _convert_model_from_lightgbm(self, booster):
        """Attach a daal4py model built from a LightGBM booster."""
        lgbm_params = d4p.get_lightgbm_params(booster)
        self.daal_model_ = d4p.get_gbt_model_from_lightgbm(booster, lgbm_params)
        self._get_params_from_lightgbm(lgbm_params)

    def _convert_model_from_xgboost(self, booster):
        """Attach a daal4py model built from an XGBoost booster."""
        xgb_params = d4p.get_xgboost_params(booster)
        self.daal_model_ = d4p.get_gbt_model_from_xgboost(booster, xgb_params)
        self._get_params_from_xgboost(xgb_params)

    def _convert_model_from_catboost(self, booster):
        """Attach a daal4py model built from a CatBoost model."""
        catboost_params = d4p.get_catboost_params(booster)
        self.daal_model_ = d4p.get_gbt_model_from_catboost(booster)
        self._get_params_from_catboost(catboost_params)

    def _convert_model(self, model):
        """Convert a supported external GBT model into this wrapper in place.

        Raises
        ------
        TypeError
            If the model type is unknown ("Unknown model format ...").
            Also raised if this wrapper class is not the one allowed for that
            model type ("Only <Class> can be created from ..."). ``convert_model``
            relies on that message prefix to pick the right wrapper.
        """
        submodule_name = model.__class__.__module__
        class_name = model.__class__.__name__
        self_class_name = self.__class__.__name__

        entry = self._SUPPORTED_MODELS.get((submodule_name, class_name))
        if entry is None:
            raise TypeError(f"Unknown model format {submodule_name}.{class_name}")

        expected_class_name, converter_name, get_booster = entry
        if self_class_name != expected_class_name:
            raise TypeError(
                f"Only {expected_class_name} can be created from "
                f"{submodule_name}.{class_name} (got {self_class_name})"
            )
        # Extract the booster only after the class check, as the original chain did.
        getattr(self, converter_name)(get_booster(model))

    def _check_predict_input(self, X):
        """Raise ValueError if no model is attached or ``X`` has the wrong width.

        The fitted-ness check runs first. An unfitted instance then reports the
        missing ``daal_model_``, instead of failing with AttributeError on
        ``n_features_in_``.
        """
        if not hasattr(self, "daal_model_"):
            raise ValueError(
                (
                    "The class {} instance does not have 'daal_model_' attribute set. "
                    "Call 'fit' with appropriate arguments before using this method."
                ).format(type(self).__name__)
            )
        if X.shape[1] != self.n_features_in_:
            raise ValueError("Shape of input is different from what was seen in `fit`")

    @staticmethod
    def _shap_results_to_compute(pred_contribs, pred_interactions):
        """Return daal4py's ``resultsToCompute`` value for the requested SHAP output."""
        if pred_contribs:
            return "shapContributions"
        if pred_interactions:
            return "shapInteractions"
        return ""

    @staticmethod
    def _reshape_shap(prediction, n_features, pred_contribs):
        """Reshape flat SHAP output.

        Contributions become ``(n_rows, n_features + 1)``; interactions become
        ``(n_rows, n_features + 1, n_features + 1)``.
        """
        width = n_features + 1
        flat = prediction.ravel()
        if pred_contribs:
            return flat.reshape((-1, width))
        return flat.reshape((-1, width, width))

    def _predict_classification(
        self, X, fptype, resultsToEvaluate, pred_contribs=False, pred_interactions=False
    ):
        """Run classification prediction, falling back for older daal4py builds.

        Returns class labels (int64) for ``"computeClassLabels"``, otherwise the
        probabilities array, or SHAP values when requested.
        """
        self._check_predict_input(X)

        # Prediction
        try:
            return self._predict_classification_with_results_to_compute(
                X, fptype, resultsToEvaluate, pred_contribs, pred_interactions
            )
        except TypeError as e:
            if "unexpected keyword argument 'resultsToCompute'" not in str(e):
                # unknown type error
                raise
            if pred_contribs or pred_interactions:
                # SHAP values requested, but not supported by this version
                raise TypeError(
                    f"{'pred_contribs' if pred_contribs else 'pred_interactions'} not supported by this version of daal4py"
                ) from e
        except RuntimeError as e:
            if "Method is not implemented" not in str(e):
                raise
            if pred_contribs or pred_interactions:
                raise NotImplementedError(
                    f"{'pred_contribs' if pred_contribs else 'pred_interactions'} is not implemented for classification models"
                ) from e

        # fallback to calculation without `resultsToCompute`
        predict_algo = d4p.gbt_classification_prediction(
            nClasses=self.n_classes_,
            fptype=fptype,
            resultsToEvaluate=resultsToEvaluate,
        )
        predict_result = predict_algo.compute(X, self.daal_model_)

        if resultsToEvaluate == "computeClassLabels":
            return predict_result.prediction.ravel().astype(np.int64, copy=False)
        return predict_result.probabilities

    def _predict_classification_with_results_to_compute(
        self,
        X,
        fptype,
        resultsToEvaluate,
        pred_contribs=False,
        pred_interactions=False,
    ):
        """Assume daal4py supports the resultsToCompute kwarg"""
        predict_algo = d4p.gbt_classification_prediction(
            nClasses=self.n_classes_,
            fptype=fptype,
            resultsToCompute=self._shap_results_to_compute(pred_contribs, pred_interactions),
            resultsToEvaluate=resultsToEvaluate,
        )
        predict_result = predict_algo.compute(X, self.daal_model_)

        if pred_contribs or pred_interactions:
            return self._reshape_shap(predict_result.prediction, X.shape[1], pred_contribs)
        if resultsToEvaluate == "computeClassLabels":
            return predict_result.prediction.ravel().astype(np.int64, copy=False)
        return predict_result.probabilities

    def _predict_regression(
        self, X, fptype, pred_contribs=False, pred_interactions=False
    ):
        """Run regression prediction (or SHAP), falling back for older daal4py builds."""
        self._check_predict_input(X)

        try:
            return self._predict_regression_with_results_to_compute(
                X, fptype, pred_contribs, pred_interactions
            )
        except TypeError as e:
            if "unexpected keyword argument 'resultsToCompute'" not in str(e):
                # unknown type error
                raise
            if pred_contribs or pred_interactions:
                # SHAP values requested, but not supported by this version
                raise TypeError(
                    f"{'pred_contribs' if pred_contribs else 'pred_interactions'} not supported by this version of daal4py"
                ) from e

        # fallback to calculation without `resultsToCompute`
        predict_algo = d4p.gbt_regression_prediction(fptype=fptype)
        predict_result = predict_algo.compute(X, self.daal_model_)
        return predict_result.prediction.ravel()

    def _predict_regression_with_results_to_compute(
        self, X, fptype, pred_contribs=False, pred_interactions=False
    ):
        """Assume daal4py supports the resultsToCompute kwarg"""
        predict_algo = d4p.gbt_regression_prediction(
            fptype=fptype,
            resultsToCompute=self._shap_results_to_compute(pred_contribs, pred_interactions),
        )
        predict_result = predict_algo.compute(X, self.daal_model_)

        if pred_contribs or pred_interactions:
            return self._reshape_shap(predict_result.prediction, X.shape[1], pred_contribs)
        return predict_result.prediction.ravel()
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class GBTDAALModel(GBTDAALBaseModel):
    """daal4py GBT model wrapping a raw XGBoost/LightGBM/CatBoost booster.

    Whether it acts as a regressor or a classifier is decided by the type of
    the attached daal4py model (``_is_regression``).
    """

    def predict(self, X, pred_contribs=False, pred_interactions=False):
        """Predict regression targets or class labels for ``X``.

        ``pred_contribs`` / ``pred_interactions`` request SHAP output. For
        classification this is only accepted when ``model_type`` is
        ``"xgboost"``; otherwise ``NotImplementedError`` is raised.
        """
        fptype = getFPType(X)
        if self._is_regression:
            return self._predict_regression(X, fptype, pred_contribs, pred_interactions)

        wants_shap = pred_contribs or pred_interactions
        if wants_shap and self.model_type != "xgboost":
            option_name = "pred_contribs" if pred_contribs else "pred_interactions"
            raise NotImplementedError(
                f"{option_name} is not implemented for classification models"
            )
        return self._predict_classification(
            X, fptype, "computeClassLabels", pred_contribs, pred_interactions
        )

    def _check_proba(self):
        # ``available_if`` exposes ``predict_proba`` only when this is truthy.
        is_regressor = self._is_regression
        return not is_regressor

    @available_if(_check_proba)
    def predict_proba(self, X):
        """Return class probabilities for ``X`` (non-regression models only)."""
        return self._predict_classification(X, getFPType(X), "computeClassProbabilities")
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def convert_model(model):
    """Convert a trained XGBoost, LightGBM or CatBoost model to a daal4py model.

    A raw booster becomes a ``GBTDAALModel``. For sklearn-style estimators
    the ``TypeError`` raised by ``GBTDAALModel`` names the required wrapper,
    and the conversion is delegated to ``GBTDAALRegressor`` or
    ``GBTDAALClassifier`` accordingly. Any other ``TypeError`` is re-raised.

    The returned object's ``model_type`` is set to the first of
    ``"xgboost"``, ``"lightgbm"``, ``"catboost"`` found in ``str(type(model))``.
    It is left untouched when none matches.
    """
    try:
        gbm = GBTDAALModel()
        gbm._convert_model(model)
    except TypeError as err:
        reason = str(err)
        if "Only GBTDAALRegressor can be created" in reason:
            gbm = d4p.sklearn.ensemble.GBTDAALRegressor.convert_model(model)
        elif "Only GBTDAALClassifier can be created" in reason:
            gbm = d4p.sklearn.ensemble.GBTDAALClassifier.convert_model(model)
        else:
            raise

    type_repr = str(type(model))
    detected = next(
        (lib for lib in ("xgboost", "lightgbm", "catboost") if lib in type_repr),
        None,
    )
    if detected is not None:
        gbm.model_type = detected

    return gbm
|
|
Binary file
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2014 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from .monkeypatch.dispatcher import _get_map_of_algorithms as sklearn_patch_map
|
|
18
|
+
from .monkeypatch.dispatcher import _patch_names as sklearn_patch_names
|
|
19
|
+
from .monkeypatch.dispatcher import disable as unpatch_sklearn
|
|
20
|
+
from .monkeypatch.dispatcher import enable as patch_sklearn
|
|
21
|
+
from .monkeypatch.dispatcher import patch_is_enabled as sklearn_is_patched
|
|
22
|
+
|
|
23
|
+
# Public names exported by ``from daal4py.sklearn import *``. Each entry must
# be a name bound in this module or an existing submodule; otherwise the
# star-import fails. "tree" was dropped because this package ships no
# ``tree`` submodule.
__all__ = [
    "cluster",
    "decomposition",
    "ensemble",
    "linear_model",
    "manifold",
    "metrics",
    "model_selection",
    "neighbors",
    "patch_sklearn",
    "sklearn_is_patched",
    "sklearn_patch_map",
    "sklearn_patch_names",
    "svm",
    "unpatch_sklearn",
    "utils",
]
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
import threading
|
|
19
|
+
from functools import wraps
|
|
20
|
+
from inspect import Parameter, signature
|
|
21
|
+
from multiprocessing import cpu_count
|
|
22
|
+
from numbers import Integral
|
|
23
|
+
from warnings import warn
|
|
24
|
+
|
|
25
|
+
import threadpoolctl
|
|
26
|
+
|
|
27
|
+
from daal4py import daalinit as set_n_threads
|
|
28
|
+
from daal4py import num_threads as get_n_threads
|
|
29
|
+
|
|
30
|
+
from ._utils import sklearn_check_version
|
|
31
|
+
|
|
32
|
+
if sklearn_check_version("1.2"):
|
|
33
|
+
from sklearn.utils._param_validation import validate_parameter_constraints
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# NOTE: the threadpoolctl controller is created once, at import time, and
# shared by every call of get_suggested_n_threads. Creating it per call
# would repeat its initialization overhead on each decorated method call.
threadpool_controller = threadpoolctl.ThreadpoolController()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_suggested_n_threads(n_cpus):
    """
    Return the thread limit imposed by an upper parallelization context.

    When `n_jobs` is set in an upper parallelization context, the native
    threadpools are usually limited to about `n_logical_cpus` // `n_jobs`.
    This inspects the threadpools reported by the module-level threadpoolctl
    controller (MKL excluded) and returns the smallest informative limit.

    Parameters
    ----------
    n_cpus : int
        Number of CPUs. A threadpool reporting exactly this many threads is
        running with its default and is therefore ignored.

    Returns
    -------
    int or None
        The smallest informative thread limit, or None if limit is not set.
    """
    # internal API name -> configured thread count; MKL is skipped entirely
    per_api = {}
    for controller in threadpool_controller.lib_controllers:
        api_name = controller.internal_api
        if api_name != "mkl":
            per_api[api_name] = controller.get_num_threads()

    def _is_informative(api_name, count):
        # openBLAS is limited to 24, 64 or 128 threads by default depending on
        # SW/HW configuration, so exactly those values carry no information.
        if api_name == "openblas" and count in (24, 64, 128):
            return False
        # a count equal to n_cpus is just the default value
        return count != n_cpus

    limits = [
        count for api_name, count in per_api.items() if _is_informative(api_name, count)
    ]
    return min(limits) if limits else None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _run_with_n_jobs(method):
|
|
69
|
+
"""
|
|
70
|
+
Decorator for running of methods containing oneDAL kernels with 'n_jobs'.
|
|
71
|
+
|
|
72
|
+
Outside actual call of decorated method, this decorator:
|
|
73
|
+
- checks correctness of passed 'n_jobs',
|
|
74
|
+
- deducts actual number of threads to use,
|
|
75
|
+
- sets and resets this number for oneDAL environment.
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
@wraps(method)
|
|
79
|
+
def method_wrapper(self, *args, **kwargs):
|
|
80
|
+
# threading parallel backend branch
|
|
81
|
+
if not isinstance(threading.current_thread(), threading._MainThread):
|
|
82
|
+
warn(
|
|
83
|
+
"'Threading' parallel backend is not supported by "
|
|
84
|
+
"Intel(R) Extension for Scikit-learn*. "
|
|
85
|
+
"Falling back to usage of all available threads."
|
|
86
|
+
)
|
|
87
|
+
result = method(self, *args, **kwargs)
|
|
88
|
+
return result
|
|
89
|
+
# multiprocess parallel backends branch
|
|
90
|
+
# preemptive validation of n_jobs parameter is required
|
|
91
|
+
# because '_run_with_n_jobs' decorator is applied on top of method
|
|
92
|
+
# where validation takes place
|
|
93
|
+
if sklearn_check_version("1.2") and hasattr(self, "_parameter_constraints"):
|
|
94
|
+
validate_parameter_constraints(
|
|
95
|
+
parameter_constraints={"n_jobs": self._parameter_constraints["n_jobs"]},
|
|
96
|
+
params={"n_jobs": self.n_jobs},
|
|
97
|
+
caller_name=self.__class__.__name__,
|
|
98
|
+
)
|
|
99
|
+
# search for specified n_jobs
|
|
100
|
+
n_jobs = self.n_jobs
|
|
101
|
+
n_cpus = cpu_count()
|
|
102
|
+
# receive n_threads limitation from upper parallelism context
|
|
103
|
+
# using `threadpoolctl.ThreadpoolController`
|
|
104
|
+
n_threads = get_suggested_n_threads(n_cpus)
|
|
105
|
+
# get real `n_jobs` number of threads for oneDAL
|
|
106
|
+
# using sklearn rules and `n_threads` from upper parallelism context
|
|
107
|
+
if n_jobs is None or n_jobs == 0:
|
|
108
|
+
if n_threads is None:
|
|
109
|
+
# default branch with no setting for n_jobs
|
|
110
|
+
return method(self, *args, **kwargs)
|
|
111
|
+
else:
|
|
112
|
+
n_jobs = n_threads
|
|
113
|
+
elif n_jobs < 0:
|
|
114
|
+
if n_threads is None:
|
|
115
|
+
n_jobs = max(1, n_cpus + n_jobs + 1)
|
|
116
|
+
else:
|
|
117
|
+
n_jobs = max(1, n_threads + n_jobs + 1)
|
|
118
|
+
# branch with set n_jobs
|
|
119
|
+
old_n_threads = get_n_threads()
|
|
120
|
+
if n_jobs != old_n_threads:
|
|
121
|
+
logger = logging.getLogger("sklearnex")
|
|
122
|
+
cl = self.__class__
|
|
123
|
+
logger.debug(
|
|
124
|
+
f"{cl.__module__}.{cl.__name__}.{method.__name__}: "
|
|
125
|
+
f"setting {n_jobs} threads (previous - {old_n_threads})"
|
|
126
|
+
)
|
|
127
|
+
set_n_threads(n_jobs)
|
|
128
|
+
result = method(self, *args, **kwargs)
|
|
129
|
+
if n_jobs != old_n_threads:
|
|
130
|
+
set_n_threads(old_n_threads)
|
|
131
|
+
return result
|
|
132
|
+
|
|
133
|
+
return method_wrapper
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def control_n_jobs(decorated_methods=()):
    """
    Decorator for controlling the 'n_jobs' parameter in an estimator class.

    This decorator is designed to be applied to both estimators with and without
    native support for the 'n_jobs' parameter in the original Scikit-learn APIs.
    When applied to an estimator without 'n_jobs' support in
    its original '__init__' method, this decorator adds the 'n_jobs' parameter.

    Additionally, this decorator allows for fine-grained control over which methods
    should be executed with the 'n_jobs' parameter. The methods specified in
    the 'decorated_methods' argument will run with 'n_jobs',
    while all other methods remain unaffected.

    Parameters
    ----------
    decorated_methods : iterable of str, default=()
        Names of the methods to be executed with 'n_jobs'. Each name must
        exist on the class; a missing method raises AttributeError.

    Returns
    -------
    callable
        A class decorator that modifies the class in place and returns it.

    Example
    -------
    @control_n_jobs(decorated_methods=['fit', 'predict'])
    class MyEstimator:

        def __init__(self, *args, **kwargs):
            # Your original __init__ implementation here

        def fit(self, *args, **kwargs):
            # Your original fit implementation here

        def predict(self, *args, **kwargs):
            # Your original predict implementation here

        def other_method(self, *args, **kwargs):
            # Methods not listed in decorated_methods will not be affected by 'n_jobs'
            pass
    """
    # Materialize once so any iterable (list, tuple, generator) works and
    # later changes to the caller's object cannot affect the decorator.
    method_names = list(decorated_methods)

    def class_wrapper(original_class):
        original_class._n_jobs_supported_onedal_methods = list(method_names)

        original_init = original_class.__init__

        if sklearn_check_version("1.2") and hasattr(
            original_class, "_parameter_constraints"
        ):
            parameter_constraints = original_class._parameter_constraints
            if "n_jobs" not in parameter_constraints:
                # Bind a new dict on this class instead of mutating in place.
                # The attribute may be inherited (e.g. from the upstream
                # sklearn estimator), and in-place mutation would leak the
                # constraint into that parent class.
                original_class._parameter_constraints = {
                    **parameter_constraints,
                    "n_jobs": [Integral, None],
                }

        @wraps(original_init)
        def init_with_n_jobs(self, *args, n_jobs=None, **kwargs):
            original_init(self, *args, **kwargs)
            self.n_jobs = n_jobs

        # add "n_jobs" parameter to signature of wrapped init
        # if estimator doesn't originally support it
        sig = signature(original_init)
        if "n_jobs" not in sig.parameters:
            params = list(sig.parameters.values())
            n_jobs_param = Parameter(
                name="n_jobs", kind=Parameter.KEYWORD_ONLY, default=None
            )
            # A keyword-only parameter may not follow **kwargs (Signature
            # raises ValueError). If a VAR_KEYWORD parameter is present it is
            # always last, so insert n_jobs just before it.
            if params and params[-1].kind is Parameter.VAR_KEYWORD:
                params.insert(len(params) - 1, n_jobs_param)
            else:
                params.append(n_jobs_param)
            init_with_n_jobs.__signature__ = sig.replace(parameters=params)
        original_class.__init__ = init_with_n_jobs

        # add n_jobs to __doc__ string if needed
        if (
            isinstance(original_class.__doc__, str)
            and "n_jobs : int" not in original_class.__doc__
        ):
            parameters_doc_tail = "\n    Attributes"
            n_jobs_doc = """
    n_jobs : int, default=None
        The number of jobs to use in parallel for the computation.
        ``None`` means using all physical cores
        unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all logical cores.
        See :term:`Glossary <n_jobs>` for more details.
"""
            original_class.__doc__ = original_class.__doc__.replace(
                parameters_doc_tail, n_jobs_doc + parameters_doc_tail
            )

        # decorate methods to be run with applied n_jobs parameter
        for method_name in method_names:
            # if method doesn't exist, we want it to raise an Exception
            method = getattr(original_class, method_name)
            if not hasattr(method, "__onedal_n_jobs_decorated__"):
                decorated_method = _run_with_n_jobs(method)
                # sign decorated method for testing and other purposes
                decorated_method.__onedal_n_jobs_decorated__ = True
                setattr(original_class, method_name, decorated_method)
            else:
                warn(
                    f"{original_class.__name__}.{method_name} already has "
                    "oneDAL n_jobs support and will not be decorated."
                )

        return original_class

    return class_wrapper
|