scikit-learn-intelex 2025.0.0__py312-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- daal4py/__init__.py +73 -0
- daal4py/__main__.py +58 -0
- daal4py/_daal4py.cpython-312-x86_64-linux-gnu.so +0 -0
- daal4py/doc/third-party-programs.txt +424 -0
- daal4py/mb/__init__.py +19 -0
- daal4py/mb/model_builders.py +377 -0
- daal4py/mpi_transceiver.cpython-312-x86_64-linux-gnu.so +0 -0
- daal4py/sklearn/__init__.py +40 -0
- daal4py/sklearn/_n_jobs_support.py +242 -0
- daal4py/sklearn/_utils.py +241 -0
- daal4py/sklearn/cluster/__init__.py +20 -0
- daal4py/sklearn/cluster/dbscan.py +165 -0
- daal4py/sklearn/cluster/k_means.py +597 -0
- daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- daal4py/sklearn/decomposition/__init__.py +19 -0
- daal4py/sklearn/decomposition/_pca.py +524 -0
- daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
- daal4py/sklearn/ensemble/__init__.py +27 -0
- daal4py/sklearn/ensemble/_forest.py +1397 -0
- daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- daal4py/sklearn/linear_model/__init__.py +29 -0
- daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- daal4py/sklearn/linear_model/_linear.py +272 -0
- daal4py/sklearn/linear_model/_ridge.py +325 -0
- daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- daal4py/sklearn/linear_model/linear.py +17 -0
- daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- daal4py/sklearn/linear_model/ridge.py +17 -0
- daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
- daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- daal4py/sklearn/manifold/__init__.py +19 -0
- daal4py/sklearn/manifold/_t_sne.py +405 -0
- daal4py/sklearn/metrics/__init__.py +20 -0
- daal4py/sklearn/metrics/_pairwise.py +155 -0
- daal4py/sklearn/metrics/_ranking.py +210 -0
- daal4py/sklearn/model_selection/__init__.py +19 -0
- daal4py/sklearn/model_selection/_split.py +309 -0
- daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- daal4py/sklearn/monkeypatch/__init__.py +0 -0
- daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
- daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
- daal4py/sklearn/neighbors/__init__.py +21 -0
- daal4py/sklearn/neighbors/_base.py +503 -0
- daal4py/sklearn/neighbors/_classification.py +139 -0
- daal4py/sklearn/neighbors/_regression.py +74 -0
- daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- daal4py/sklearn/svm/__init__.py +19 -0
- daal4py/sklearn/svm/svm.py +734 -0
- daal4py/sklearn/utils/__init__.py +21 -0
- daal4py/sklearn/utils/base.py +75 -0
- daal4py/sklearn/utils/tests/test_utils.py +51 -0
- daal4py/sklearn/utils/validation.py +693 -0
- onedal/__init__.py +83 -0
- onedal/_config.py +53 -0
- onedal/_device_offload.py +229 -0
- onedal/_onedal_py_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_host.cpython-312-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_spmd_dpc.cpython-312-x86_64-linux-gnu.so +0 -0
- onedal/basic_statistics/__init__.py +20 -0
- onedal/basic_statistics/basic_statistics.py +107 -0
- onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- onedal/cluster/__init__.py +27 -0
- onedal/cluster/dbscan.py +110 -0
- onedal/cluster/kmeans.py +560 -0
- onedal/cluster/kmeans_init.py +115 -0
- onedal/cluster/tests/test_dbscan.py +125 -0
- onedal/cluster/tests/test_kmeans.py +88 -0
- onedal/cluster/tests/test_kmeans_init.py +93 -0
- onedal/common/_base.py +38 -0
- onedal/common/_estimator_checks.py +47 -0
- onedal/common/_mixin.py +62 -0
- onedal/common/_policy.py +59 -0
- onedal/common/_spmd_policy.py +30 -0
- onedal/common/hyperparameters.py +116 -0
- onedal/common/tests/test_policy.py +75 -0
- onedal/covariance/__init__.py +20 -0
- onedal/covariance/covariance.py +125 -0
- onedal/covariance/incremental_covariance.py +146 -0
- onedal/covariance/tests/test_covariance.py +50 -0
- onedal/covariance/tests/test_incremental_covariance.py +122 -0
- onedal/datatypes/__init__.py +19 -0
- onedal/datatypes/_data_conversion.py +95 -0
- onedal/datatypes/tests/test_data.py +235 -0
- onedal/decomposition/__init__.py +20 -0
- onedal/decomposition/incremental_pca.py +204 -0
- onedal/decomposition/pca.py +186 -0
- onedal/decomposition/tests/test_incremental_pca.py +198 -0
- onedal/ensemble/__init__.py +29 -0
- onedal/ensemble/forest.py +720 -0
- onedal/ensemble/tests/test_random_forest.py +97 -0
- onedal/linear_model/__init__.py +27 -0
- onedal/linear_model/incremental_linear_model.py +258 -0
- onedal/linear_model/linear_model.py +329 -0
- onedal/linear_model/logistic_regression.py +249 -0
- onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- onedal/linear_model/tests/test_linear_regression.py +149 -0
- onedal/linear_model/tests/test_logistic_regression.py +95 -0
- onedal/linear_model/tests/test_ridge.py +95 -0
- onedal/neighbors/__init__.py +19 -0
- onedal/neighbors/neighbors.py +778 -0
- onedal/neighbors/tests/test_knn_classification.py +49 -0
- onedal/primitives/__init__.py +27 -0
- onedal/primitives/get_tree.py +25 -0
- onedal/primitives/kernel_functions.py +153 -0
- onedal/primitives/tests/test_kernel_functions.py +159 -0
- onedal/spmd/__init__.py +25 -0
- onedal/spmd/_base.py +30 -0
- onedal/spmd/basic_statistics/__init__.py +20 -0
- onedal/spmd/basic_statistics/basic_statistics.py +30 -0
- onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
- onedal/spmd/cluster/__init__.py +28 -0
- onedal/spmd/cluster/dbscan.py +23 -0
- onedal/spmd/cluster/kmeans.py +56 -0
- onedal/spmd/covariance/__init__.py +20 -0
- onedal/spmd/covariance/covariance.py +26 -0
- onedal/spmd/covariance/incremental_covariance.py +82 -0
- onedal/spmd/decomposition/__init__.py +20 -0
- onedal/spmd/decomposition/incremental_pca.py +117 -0
- onedal/spmd/decomposition/pca.py +26 -0
- onedal/spmd/ensemble/__init__.py +19 -0
- onedal/spmd/ensemble/forest.py +28 -0
- onedal/spmd/linear_model/__init__.py +21 -0
- onedal/spmd/linear_model/incremental_linear_model.py +97 -0
- onedal/spmd/linear_model/linear_model.py +30 -0
- onedal/spmd/linear_model/logistic_regression.py +38 -0
- onedal/spmd/neighbors/__init__.py +19 -0
- onedal/spmd/neighbors/neighbors.py +75 -0
- onedal/svm/__init__.py +19 -0
- onedal/svm/svm.py +556 -0
- onedal/svm/tests/test_csr_svm.py +351 -0
- onedal/svm/tests/test_nusvc.py +204 -0
- onedal/svm/tests/test_nusvr.py +210 -0
- onedal/svm/tests/test_svc.py +168 -0
- onedal/svm/tests/test_svr.py +243 -0
- onedal/tests/test_common.py +41 -0
- onedal/tests/utils/_dataframes_support.py +168 -0
- onedal/tests/utils/_device_selection.py +107 -0
- onedal/utils/__init__.py +49 -0
- onedal/utils/_array_api.py +91 -0
- onedal/utils/validation.py +432 -0
- scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
- scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
- scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
- sklearnex/__init__.py +65 -0
- sklearnex/__main__.py +58 -0
- sklearnex/_config.py +98 -0
- sklearnex/_device_offload.py +121 -0
- sklearnex/_utils.py +109 -0
- sklearnex/basic_statistics/__init__.py +20 -0
- sklearnex/basic_statistics/basic_statistics.py +140 -0
- sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
- sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
- sklearnex/cluster/__init__.py +20 -0
- sklearnex/cluster/dbscan.py +192 -0
- sklearnex/cluster/k_means.py +383 -0
- sklearnex/cluster/tests/test_dbscan.py +38 -0
- sklearnex/cluster/tests/test_kmeans.py +153 -0
- sklearnex/conftest.py +73 -0
- sklearnex/covariance/__init__.py +19 -0
- sklearnex/covariance/incremental_covariance.py +368 -0
- sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
- sklearnex/decomposition/__init__.py +19 -0
- sklearnex/decomposition/pca.py +414 -0
- sklearnex/decomposition/tests/test_pca.py +58 -0
- sklearnex/dispatcher.py +543 -0
- sklearnex/doc/third-party-programs.txt +424 -0
- sklearnex/ensemble/__init__.py +29 -0
- sklearnex/ensemble/_forest.py +2016 -0
- sklearnex/ensemble/tests/test_forest.py +120 -0
- sklearnex/glob/__main__.py +72 -0
- sklearnex/glob/dispatcher.py +101 -0
- sklearnex/linear_model/__init__.py +32 -0
- sklearnex/linear_model/coordinate_descent.py +30 -0
- sklearnex/linear_model/incremental_linear.py +463 -0
- sklearnex/linear_model/incremental_ridge.py +418 -0
- sklearnex/linear_model/linear.py +302 -0
- sklearnex/linear_model/logistic_path.py +17 -0
- sklearnex/linear_model/logistic_regression.py +403 -0
- sklearnex/linear_model/ridge.py +24 -0
- sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
- sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- sklearnex/linear_model/tests/test_linear.py +142 -0
- sklearnex/linear_model/tests/test_logreg.py +134 -0
- sklearnex/manifold/__init__.py +19 -0
- sklearnex/manifold/t_sne.py +21 -0
- sklearnex/manifold/tests/test_tsne.py +26 -0
- sklearnex/metrics/__init__.py +23 -0
- sklearnex/metrics/pairwise.py +22 -0
- sklearnex/metrics/ranking.py +20 -0
- sklearnex/metrics/tests/test_metrics.py +39 -0
- sklearnex/model_selection/__init__.py +21 -0
- sklearnex/model_selection/split.py +22 -0
- sklearnex/model_selection/tests/test_model_selection.py +34 -0
- sklearnex/neighbors/__init__.py +27 -0
- sklearnex/neighbors/_lof.py +231 -0
- sklearnex/neighbors/common.py +310 -0
- sklearnex/neighbors/knn_classification.py +226 -0
- sklearnex/neighbors/knn_regression.py +203 -0
- sklearnex/neighbors/knn_unsupervised.py +170 -0
- sklearnex/neighbors/tests/test_neighbors.py +80 -0
- sklearnex/preview/__init__.py +17 -0
- sklearnex/preview/covariance/__init__.py +19 -0
- sklearnex/preview/covariance/covariance.py +133 -0
- sklearnex/preview/covariance/tests/test_covariance.py +66 -0
- sklearnex/preview/decomposition/__init__.py +19 -0
- sklearnex/preview/decomposition/incremental_pca.py +228 -0
- sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- sklearnex/preview/linear_model/__init__.py +19 -0
- sklearnex/preview/linear_model/ridge.py +419 -0
- sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- sklearnex/spmd/__init__.py +25 -0
- sklearnex/spmd/basic_statistics/__init__.py +20 -0
- sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
- sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- sklearnex/spmd/cluster/__init__.py +30 -0
- sklearnex/spmd/cluster/dbscan.py +50 -0
- sklearnex/spmd/cluster/kmeans.py +21 -0
- sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- sklearnex/spmd/covariance/__init__.py +20 -0
- sklearnex/spmd/covariance/covariance.py +21 -0
- sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- sklearnex/spmd/decomposition/__init__.py +20 -0
- sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- sklearnex/spmd/decomposition/pca.py +21 -0
- sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- sklearnex/spmd/ensemble/__init__.py +19 -0
- sklearnex/spmd/ensemble/forest.py +71 -0
- sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- sklearnex/spmd/linear_model/__init__.py +21 -0
- sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- sklearnex/spmd/linear_model/linear_model.py +21 -0
- sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
- sklearnex/spmd/neighbors/__init__.py +19 -0
- sklearnex/spmd/neighbors/neighbors.py +25 -0
- sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- sklearnex/svm/__init__.py +29 -0
- sklearnex/svm/_common.py +328 -0
- sklearnex/svm/nusvc.py +332 -0
- sklearnex/svm/nusvr.py +148 -0
- sklearnex/svm/svc.py +360 -0
- sklearnex/svm/svr.py +149 -0
- sklearnex/svm/tests/test_svm.py +93 -0
- sklearnex/tests/_utils.py +328 -0
- sklearnex/tests/_utils_spmd.py +198 -0
- sklearnex/tests/test_common.py +54 -0
- sklearnex/tests/test_config.py +43 -0
- sklearnex/tests/test_memory_usage.py +291 -0
- sklearnex/tests/test_monkeypatch.py +276 -0
- sklearnex/tests/test_n_jobs_support.py +103 -0
- sklearnex/tests/test_parallel.py +48 -0
- sklearnex/tests/test_patching.py +385 -0
- sklearnex/tests/test_run_to_run_stability.py +296 -0
- sklearnex/utils/__init__.py +19 -0
- sklearnex/utils/_array_api.py +82 -0
- sklearnex/utils/parallel.py +59 -0
- sklearnex/utils/tests/test_finite.py +89 -0
- sklearnex/utils/validation.py +17 -0
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
|
|
18
|
+
from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch
|
|
19
|
+
|
|
20
|
+
from ..._device_offload import support_input_format
|
|
21
|
+
from .._base import BaseEstimatorSPMD
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class KNeighborsClassifier(BaseEstimatorSPMD, KNeighborsClassifier_Batch):
    """SPMD variant of the batch k-nearest-neighbors classifier.

    ``fit``, ``predict`` and ``kneighbors`` forward to the implementation
    found through ``super()`` and hand ``queue`` over by keyword.
    ``predict_proba`` is explicitly rejected.  Every public method is wrapped
    with ``support_input_format()``.
    """

    @support_input_format()
    def fit(self, X, y, queue=None):
        """Train through the parent ``fit`` and return whatever it returns."""
        fitted = super().fit(X, y, queue=queue)
        return fitted

    @support_input_format()
    def predict(self, X, queue=None):
        """Return the parent ``predict`` result for ``X``."""
        labels = super().predict(X, queue=queue)
        return labels

    @support_input_format()
    def predict_proba(self, X, queue=None):
        """Always raise ``NotImplementedError``; arguments are not inspected."""
        raise NotImplementedError("predict_proba not supported in distributed mode.")

    @support_input_format()
    def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
        """Forward the neighbor query to the parent ``kneighbors``.

        ``X``, ``n_neighbors`` and ``return_distance`` are passed positionally
        in that order; ``queue`` is passed by keyword.
        """
        neighbors = super().kneighbors(X, n_neighbors, return_distance, queue=queue)
        return neighbors
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class KNeighborsRegressor(BaseEstimatorSPMD, KNeighborsRegressor_Batch):
    """SPMD variant of the batch k-nearest-neighbors regressor.

    Training is only allowed when a queue whose SYCL device is a GPU is
    supplied.  Every public method is wrapped with ``support_input_format()``.
    """

    @support_input_format()
    def fit(self, X, y, queue=None):
        """Train on ``X``/``y`` using the parent ``_fit``.

        Raises
        ------
        ValueError
            If ``queue`` is ``None`` or ``queue.sycl_device.is_gpu`` is false.
        """
        # Guard clause: reject the unsupported (no queue / non-GPU) case first.
        if queue is None or not queue.sycl_device.is_gpu:
            raise ValueError(
                "SPMD version of kNN is not implemented for "
                "CPU. Consider running it on GPU."
            )
        return super()._fit(X, y, queue=queue)

    @support_input_format()
    def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
        """Forward the neighbor query to the parent ``kneighbors``."""
        return super().kneighbors(X, n_neighbors, return_distance, queue=queue)

    @support_input_format()
    def predict(self, X, queue=None):
        """Return the result of ``self._predict_gpu`` for ``X``."""
        return self._predict_gpu(X, queue=queue)

    def _get_onedal_params(self, X, y=None):
        """Return the parent's parameters with ``responses`` requested.

        ``"|responses"`` is appended to ``params["result_option"]`` unless the
        substring ``"responses"`` is already present in it.
        """
        params = super()._get_onedal_params(X, y)
        if "responses" not in params["result_option"]:
            params["result_option"] += "|responses"
        return params
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class NearestNeighbors(BaseEstimatorSPMD):
    """SPMD nearest-neighbors estimator exposing ``fit`` and ``kneighbors``.

    Both methods forward to the implementation found through ``super()`` and
    are wrapped with ``support_input_format()``.
    """

    # NOTE(review): the only listed base is BaseEstimatorSPMD, so super().fit
    # and super().kneighbors must be provided by it (or its parents) - confirm,
    # since the sibling SPMD classes also inherit a batch estimator.

    @support_input_format()
    def fit(self, X, y, queue=None):
        """Train through the parent ``fit`` and return whatever it returns."""
        fitted = super().fit(X, y, queue=queue)
        return fitted

    @support_input_format()
    def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
        """Forward the neighbor query to the parent ``kneighbors``.

        ``X``, ``n_neighbors`` and ``return_distance`` are passed positionally
        in that order; ``queue`` is passed by keyword.
        """
        neighbors = super().kneighbors(X, n_neighbors, return_distance, queue=queue)
        return neighbors
|
onedal/svm/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from .svm import SVC, SVR, NuSVC, NuSVR, SVMtype
|
|
18
|
+
|
|
19
|
+
__all__ = ["SVC", "SVR", "NuSVC", "NuSVR", "SVMtype"]
|
onedal/svm/svm.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
from abc import ABCMeta, abstractmethod
|
|
18
|
+
from enum import Enum
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
from scipy import sparse as sp
|
|
22
|
+
|
|
23
|
+
from onedal import _backend
|
|
24
|
+
|
|
25
|
+
from ..common._estimator_checks import _check_is_fitted
|
|
26
|
+
from ..common._mixin import ClassifierMixin, RegressorMixin
|
|
27
|
+
from ..common._policy import _get_policy
|
|
28
|
+
from ..datatypes import _convert_to_supported, from_table, to_table
|
|
29
|
+
from ..utils import (
|
|
30
|
+
_check_array,
|
|
31
|
+
_check_n_features,
|
|
32
|
+
_check_X_y,
|
|
33
|
+
_column_or_1d,
|
|
34
|
+
_validate_targets,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SVMtype(Enum):
    """Tags identifying which SVM formulation an estimator implements.

    The integer values are fixed: they are part of the enum's observable
    behavior (``SVMtype(2)`` must keep resolving to ``nu_svc``).
    """

    # C-parameterized classification.
    c_svc = 0
    # Epsilon-parameterized regression.
    epsilon_svr = 1
    # Nu-parameterized classification.
    nu_svc = 2
    # Nu-parameterized regression.
    nu_svr = 3
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class BaseSVM(metaclass=ABCMeta):
    """Common fit / predict / decision-function logic for the SVM estimators.

    Concrete estimators pick the backend ``module`` (an object exposing
    ``train``, ``infer`` and ``model``) and pass it to the private ``_fit``,
    ``_predict`` and ``_decision_function`` methods defined here.
    """

    @abstractmethod
    def __init__(
        self,
        C,
        nu,
        epsilon,
        kernel="rbf",
        *,
        degree,
        gamma,
        coef0,
        tol,
        shrinking,
        cache_size,
        max_iter,
        tau,
        class_weight,
        decision_function_shape,
        break_ties,
        algorithm,
        svm_type=None,
        **kwargs,
    ):
        """Store every hyperparameter on the instance unchanged.

        ``**kwargs`` is accepted but not used by this body.
        """
        self.C = C
        self.nu = nu
        self.epsilon = epsilon
        self.kernel = kernel
        self.degree = degree
        self.coef0 = coef0
        self.gamma = gamma
        self.tol = tol
        self.shrinking = shrinking
        self.cache_size = cache_size
        self.max_iter = max_iter
        self.tau = tau
        self.class_weight = class_weight
        self.decision_function_shape = decision_function_shape
        self.break_ties = break_ties
        self.algorithm = algorithm
        self.svm_type = svm_type

    def _validate_targets(self, y, dtype):
        """Default target handling: no classes, no class weights.

        Resets ``class_weight_`` and ``classes_`` to ``None`` and returns
        ``y`` passed through ``_column_or_1d`` and cast to ``dtype``.
        """
        self.class_weight_ = None
        self.classes_ = None
        return _column_or_1d(y, warn=True).astype(dtype, copy=False)

    def _get_onedal_params(self, data):
        """Build the parameter dictionary handed to the backend module.

        ``max_iter == -1`` is translated to 10000.  ``fptype`` is ``"float"``
        when ``data.dtype`` is ``float32`` and ``"double"`` otherwise.

        Side effect: sets ``self.n_iter_`` (see the TODO below).
        """
        max_iter = 10000 if self.max_iter == -1 else self.max_iter
        # TODO: remove this workaround
        # when oneDAL SVM starts support of 'n_iterations' result
        self.n_iter_ = 1 if max_iter < 1 else max_iter
        class_count = 0 if self.classes_ is None else len(self.classes_)
        return {
            "fptype": "float" if data.dtype == np.float32 else "double",
            "method": self.algorithm,
            "kernel": self.kernel,
            "c": self.C,
            "nu": self.nu,
            "epsilon": self.epsilon,
            "class_count": class_count,
            "accuracy_threshold": self.tol,
            "max_iteration_count": int(max_iter),
            "scale": self._scale_,
            "sigma": self._sigma_,
            "shift": self.coef0,
            "degree": self.degree,
            "tau": self.tau,
            "shrinking": self.shrinking,
            "cache_size": self.cache_size,
        }

    def _fit(self, X, y, sample_weight, module, queue):
        """Validate the inputs, train through ``module.train`` and store results.

        Sets ``dual_coef_``, ``support_vectors_``, ``intercept_``,
        ``support_``, ``n_features_in_``, ``shape_fit_``, ``_gamma``,
        ``_onedal_model`` and, when ``classes_`` is set, ``_n_support``.
        Returns ``self``.

        Raises
        ------
        ValueError
            If ``decision_function_shape`` is not ``'ovr'``, ``'ovo'`` or
            ``None``, or if ``gamma`` is a string other than ``'scale'`` or
            ``'auto'`` (checked only for non-linear kernels).
        """
        if hasattr(self, "decision_function_shape"):
            if self.decision_function_shape not in ("ovr", "ovo", None):
                raise ValueError(
                    "decision_function_shape must be either 'ovr' or 'ovo', "
                    f"got {self.decision_function_shape}."
                )

        X, y = _check_X_y(
            X,
            y,
            dtype=[np.float64, np.float32],
            force_all_finite=True,
            accept_sparse="csr",
        )
        y = self._validate_targets(y, X.dtype)
        if sample_weight is not None and len(sample_weight) > 0:
            sample_weight = _check_array(
                sample_weight,
                accept_sparse=False,
                ensure_2d=False,
                dtype=X.dtype,
                order="C",
            )
        elif self.class_weight is not None:
            # Class weights are applied through per-sample weights, so start
            # from uniform weights when none were supplied.
            sample_weight = np.ones(X.shape[0], dtype=X.dtype)

        if sample_weight is not None:
            if self.class_weight_ is not None:
                # Scale a private copy: after validation the array may still
                # be the caller's own buffer, and the in-place multiplication
                # below must not leak back into it.
                sample_weight = np.array(sample_weight, copy=True)
                # Position i of class_weight_ is matched against y == i.
                for i, v in enumerate(self.class_weight_):
                    sample_weight[y == i] *= v
            data = (X, y, sample_weight)
        else:
            data = (X, y)
        self._sparse = sp.issparse(X)

        if self.kernel == "linear":
            self._scale_, self._sigma_ = 1.0, 1.0
            # NOTE: this overwrites the constructor-supplied coef0.
            self.coef0 = 0.0
        else:
            if isinstance(self.gamma, str):
                if self.gamma == "scale":
                    if sp.issparse(X):
                        # var = E[X^2] - E[X]^2
                        X_sc = (X.multiply(X)).mean() - (X.mean()) ** 2
                    else:
                        X_sc = X.var()
                    # Fall back to 1.0 when the data has zero variance.
                    _gamma = 1.0 / (X.shape[1] * X_sc) if X_sc != 0 else 1.0
                elif self.gamma == "auto":
                    _gamma = 1.0 / X.shape[1]
                else:
                    raise ValueError(
                        "When 'gamma' is a string, it should be either 'scale' or "
                        "'auto'. Got '{}' instead.".format(self.gamma)
                    )
            else:
                _gamma = self.gamma
            self._scale_, self._sigma_ = _gamma, np.sqrt(0.5 / _gamma)

        policy = _get_policy(queue, *data)
        # NOTE(review): the converted X only feeds the parameter dictionary;
        # the tables passed to train() are built from the unconverted `data`
        # tuple - confirm this is intended.
        X = _convert_to_supported(policy, X)
        params = self._get_onedal_params(X)
        result = module.train(policy, params, *to_table(*data))

        if self._sparse:
            self.dual_coef_ = sp.csr_matrix(from_table(result.coeffs).T)
            self.support_vectors_ = sp.csr_matrix(from_table(result.support_vectors))
        else:
            self.dual_coef_ = from_table(result.coeffs).T
            self.support_vectors_ = from_table(result.support_vectors)

        self.intercept_ = from_table(result.biases).ravel()
        self.support_ = from_table(result.support_indices).ravel().astype("int")
        self.n_features_in_ = X.shape[1]
        self.shape_fit_ = X.shape

        if getattr(self, "classes_", None) is not None:
            # Count the support vectors belonging to each class index.
            indices = y.take(self.support_, axis=0)
            self._n_support = np.array(
                [np.sum(indices == i) for i, _ in enumerate(self.classes_)]
            )
        self._gamma = self._scale_

        self._onedal_model = result.model
        return self

    def _create_model(self, module):
        """Build a backend model object from the fitted attributes.

        Fills ``support_vectors``, ``coeffs`` and ``biases`` from
        ``support_vectors_``, ``dual_coef_.T`` and ``intercept_``; for the two
        classification SVM types the class responses are set to 0 and 1.
        """
        m = module.model()

        m.support_vectors = to_table(self.support_vectors_)
        m.coeffs = to_table(self.dual_coef_.T)
        m.biases = to_table(self.intercept_)

        if self.svm_type is SVMtype.c_svc or self.svm_type is SVMtype.nu_svc:
            m.first_class_response, m.second_class_response = 0, 1
        return m

    def _predict(self, X, module, queue):
        """Predict for ``X`` with the fitted model.

        With ``break_ties`` set, ``decision_function_shape == 'ovr'`` and more
        than two classes, the result is the argmax over the decision function;
        otherwise it is ``result.responses`` from ``module.infer``.

        Raises
        ------
        ValueError
            If ``break_ties`` is combined with ``'ovo'``, if the support
            vector count no longer matches ``_n_support`` (dense classifiers),
            or if sparse input is given to a model trained on dense data.
        """
        _check_is_fitted(self)
        if self.break_ties and self.decision_function_shape == "ovo":
            raise ValueError(
                "break_ties must be False when decision_function_shape is 'ovo'"
            )

        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
            sv = self.support_vectors_
            if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
                raise ValueError(
                    "The internal representation "
                    f"of {self.__class__.__name__} was altered"
                )

        if (
            self.break_ties
            and self.decision_function_shape == "ovr"
            and len(self.classes_) > 2
        ):
            # Use the private implementation with the caller's module and
            # queue so the requested queue is honoured on this path too.
            y = np.argmax(self._decision_function(X, module, queue), axis=1)
        else:
            X = _check_array(
                X,
                dtype=[np.float64, np.float32],
                force_all_finite=True,
                accept_sparse="csr",
            )
            _check_n_features(self, X, False)

            # A model trained on sparse data always infers on sorted CSR.
            if self._sparse and not sp.isspmatrix(X):
                X = sp.csr_matrix(X)
            if self._sparse:
                X.sort_indices()

            if sp.issparse(X) and not self._sparse and not callable(self.kernel):
                raise ValueError(
                    "cannot use sparse input in %r trained on dense data"
                    % type(self).__name__
                )

            policy = _get_policy(queue, X)
            X = _convert_to_supported(policy, X)
            params = self._get_onedal_params(X)

            # Reuse the trained backend model when present; otherwise rebuild
            # one from the fitted attributes.
            if hasattr(self, "_onedal_model"):
                model = self._onedal_model
            else:
                model = self._create_model(module)
            result = module.infer(policy, params, model, to_table(X))
            y = from_table(result.responses)
        return y

    def _ovr_decision_function(self, predictions, confidences, n_classes):
        """Combine pairwise results into one score per class.

        Column ``k`` of ``predictions`` / ``confidences`` belongs to the k-th
        class pair ``(i, j)`` with ``i < j`` in nested-loop order.  A
        prediction of 0 is a vote for ``i`` and 1 a vote for ``j``.  The
        summed confidences are squashed into (-1/3, 1/3) before being added to
        the votes, so they only separate classes with equal vote counts.

        Returns an array of shape ``(n_samples, n_classes)``.
        """
        n_samples = predictions.shape[0]
        votes = np.zeros((n_samples, n_classes))
        sum_of_confidences = np.zeros((n_samples, n_classes))

        k = 0
        for i in range(n_classes):
            for j in range(i + 1, n_classes):
                sum_of_confidences[:, i] -= confidences[:, k]
                sum_of_confidences[:, j] += confidences[:, k]
                votes[predictions[:, k] == 0, i] += 1
                votes[predictions[:, k] == 1, j] += 1
                k += 1

        transformed_confidences = sum_of_confidences / (
            3 * (np.abs(sum_of_confidences) + 1)
        )
        return votes + transformed_confidences

    def _decision_function(self, X, module, queue):
        """Evaluate the decision function for ``X`` through ``module.infer``.

        For two classes the result is flattened.  For more than two classes
        with ``decision_function_shape == 'ovr'`` the pairwise values are
        combined by ``_ovr_decision_function``.

        Raises
        ------
        ValueError
            If sparse input is given to a model trained on dense data, or if
            the support vector count no longer matches ``_n_support`` (dense
            classifiers).
        """
        _check_is_fitted(self)
        X = _check_array(
            X, dtype=[np.float64, np.float32], force_all_finite=False, accept_sparse="csr"
        )
        _check_n_features(self, X, False)

        # A model trained on sparse data always infers on sorted CSR.
        if self._sparse and not sp.isspmatrix(X):
            X = sp.csr_matrix(X)
        if self._sparse:
            X.sort_indices()

        if sp.issparse(X) and not self._sparse and not callable(self.kernel):
            raise ValueError(
                "cannot use sparse input in %r trained on dense data"
                % type(self).__name__
            )

        if module in [_backend.svm.classification, _backend.svm.nu_classification]:
            sv = self.support_vectors_
            if not self._sparse and sv.size > 0 and self._n_support.sum() != sv.shape[0]:
                raise ValueError(
                    "The internal representation "
                    f"of {self.__class__.__name__} was altered"
                )

        policy = _get_policy(queue, X)
        X = _convert_to_supported(policy, X)
        params = self._get_onedal_params(X)

        if hasattr(self, "_onedal_model"):
            model = self._onedal_model
        else:
            model = self._create_model(module)
        result = module.infer(policy, params, model, to_table(X))
        decision_function = from_table(result.decision_function)

        if len(self.classes_) == 2:
            decision_function = decision_function.ravel()

        if self.decision_function_shape == "ovr" and len(self.classes_) > 2:
            # Negative values count as prediction 1 (True); the confidences
            # are the negated decision values.
            decision_function = self._ovr_decision_function(
                decision_function < 0, -decision_function, len(self.classes_)
            )
        return decision_function
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
class SVR(RegressorMixin, BaseSVM):
    """
    Epsilon--Support Vector Regression.

    Trains and predicts through ``_backend.svm.regression`` and tags itself
    with ``SVMtype.epsilon_svr``.
    """

    def __init__(
        self,
        C=1.0,
        epsilon=0.1,
        kernel="rbf",
        *,
        degree=3,
        gamma="scale",
        coef0=0.0,
        tol=1e-3,
        shrinking=True,
        cache_size=200.0,
        max_iter=-1,
        tau=1e-12,
        algorithm="thunder",
        **kwargs,
    ):
        """Forward the hyperparameters to the parent constructor.

        ``**kwargs`` is accepted but not forwarded.
        """
        # Values this class always passes for the settings it does not expose.
        fixed_settings = {
            "nu": 0.5,
            "class_weight": None,
            "decision_function_shape": None,
            "break_ties": False,
        }
        super().__init__(
            C=C,
            epsilon=epsilon,
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            coef0=coef0,
            tol=tol,
            shrinking=shrinking,
            cache_size=cache_size,
            max_iter=max_iter,
            tau=tau,
            algorithm=algorithm,
            **fixed_settings,
        )
        self.svm_type = SVMtype.epsilon_svr

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train via the parent ``_fit`` with the regression backend module."""
        backend_module = _backend.svm.regression
        return super()._fit(X, y, sample_weight, backend_module, queue)

    def predict(self, X, queue=None):
        """Return the flattened result of the parent ``_predict`` for ``X``."""
        responses = super()._predict(X, _backend.svm.regression, queue)
        return responses.ravel()
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
class SVC(ClassifierMixin, BaseSVM):
    """
    C-Support Vector Classification.

    Estimator built on ``BaseSVM`` that marks itself as ``SVMtype.c_svc``
    and routes training and inference through the
    ``_backend.svm.classification`` module.
    """

    def __init__(
        self,
        C=1.0,
        kernel="rbf",
        *,
        degree=3,
        gamma="scale",
        coef0=0.0,
        tol=1e-3,
        shrinking=True,
        cache_size=200.0,
        max_iter=-1,
        tau=1e-12,
        class_weight=None,
        decision_function_shape="ovr",
        break_ties=False,
        algorithm="thunder",
        **kwargs,
    ):
        # NOTE: ``kwargs`` is accepted for signature compatibility only; this
        # constructor does not forward or store it.
        super().__init__(
            # Caller-controlled hyper-parameters.
            C=C,
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            coef0=coef0,
            tol=tol,
            shrinking=shrinking,
            cache_size=cache_size,
            max_iter=max_iter,
            tau=tau,
            class_weight=class_weight,
            decision_function_shape=decision_function_shape,
            break_ties=break_ties,
            algorithm=algorithm,
            # Values this estimator does not expose; passed as constants.
            nu=0.5,
            epsilon=0.0,
        )
        self.svm_type = SVMtype.c_svc

    def _validate_targets(self, y, dtype):
        """Run the module-level ``_validate_targets`` helper and keep its outputs.

        Stores the second and third returned values as ``class_weight_`` and
        ``classes_`` on the estimator and returns the first one.
        """
        validated = _validate_targets(y, self.class_weight, dtype)
        y, self.class_weight_, self.classes_ = validated
        return y

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train via ``BaseSVM._fit`` with the classification backend module."""
        backend_module = _backend.svm.classification
        return super()._fit(X, y, sample_weight, backend_module, queue)

    def predict(self, X, queue=None):
        """Predict labels by using the backend output as indices into ``classes_``."""
        raw = super()._predict(X, _backend.svm.classification, queue)
        if len(self.classes_) == 2:
            raw = raw.ravel()
        # The backend output is cast to integer positions before the lookup.
        positions = np.asarray(raw, dtype=np.intp)
        return self.classes_.take(positions).ravel()

    def decision_function(self, X, queue=None):
        """Delegate to ``BaseSVM._decision_function`` with the classification backend."""
        backend_module = _backend.svm.classification
        return super()._decision_function(X, backend_module, queue)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
class NuSVR(RegressorMixin, BaseSVM):
    """
    Nu-Support Vector Regression.

    Estimator built on ``BaseSVM`` that marks itself as ``SVMtype.nu_svr``
    and routes training and inference through the
    ``_backend.svm.nu_regression`` module.
    """

    def __init__(
        self,
        nu=0.5,
        C=1.0,
        kernel="rbf",
        *,
        degree=3,
        gamma="scale",
        coef0=0.0,
        tol=1e-3,
        shrinking=True,
        cache_size=200.0,
        max_iter=-1,
        tau=1e-12,
        algorithm="thunder",
        **kwargs,
    ):
        # NOTE: ``kwargs`` is accepted for signature compatibility only; this
        # constructor does not forward or store it.
        super().__init__(
            # Caller-controlled hyper-parameters.
            nu=nu,
            C=C,
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            coef0=coef0,
            tol=tol,
            shrinking=shrinking,
            cache_size=cache_size,
            max_iter=max_iter,
            tau=tau,
            algorithm=algorithm,
            # Values this estimator does not expose; passed as constants.
            epsilon=0.0,
            class_weight=None,
            decision_function_shape=None,
            break_ties=False,
        )
        self.svm_type = SVMtype.nu_svr

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train via ``BaseSVM._fit`` with the nu-regression backend module."""
        backend_module = _backend.svm.nu_regression
        return super()._fit(X, y, sample_weight, backend_module, queue)

    def predict(self, X, queue=None):
        """Return the result of ``BaseSVM._predict`` for ``X``, flattened with ``ravel()``."""
        return super()._predict(X, _backend.svm.nu_regression, queue).ravel()
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
class NuSVC(ClassifierMixin, BaseSVM):
    """
    Nu-Support Vector Classification.

    Estimator built on ``BaseSVM`` that marks itself as ``SVMtype.nu_svc``
    and routes training and inference through the
    ``_backend.svm.nu_classification`` module.
    """

    def __init__(
        self,
        nu=0.5,
        kernel="rbf",
        *,
        degree=3,
        gamma="scale",
        coef0=0.0,
        tol=1e-3,
        shrinking=True,
        cache_size=200.0,
        max_iter=-1,
        tau=1e-12,
        class_weight=None,
        decision_function_shape="ovr",
        break_ties=False,
        algorithm="thunder",
        **kwargs,
    ):
        # NOTE: ``kwargs`` is accepted for signature compatibility only; this
        # constructor does not forward or store it.
        super().__init__(
            # Caller-controlled hyper-parameters.
            nu=nu,
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            coef0=coef0,
            tol=tol,
            shrinking=shrinking,
            cache_size=cache_size,
            max_iter=max_iter,
            tau=tau,
            class_weight=class_weight,
            decision_function_shape=decision_function_shape,
            break_ties=break_ties,
            algorithm=algorithm,
            # Values this estimator does not expose; passed as constants.
            C=1.0,
            epsilon=0.0,
        )
        self.svm_type = SVMtype.nu_svc

    def _validate_targets(self, y, dtype):
        """Run the module-level ``_validate_targets`` helper and keep its outputs.

        Stores the second and third returned values as ``class_weight_`` and
        ``classes_`` on the estimator and returns the first one.
        """
        validated = _validate_targets(y, self.class_weight, dtype)
        y, self.class_weight_, self.classes_ = validated
        return y

    def fit(self, X, y, sample_weight=None, queue=None):
        """Train via ``BaseSVM._fit`` with the nu-classification backend module."""
        backend_module = _backend.svm.nu_classification
        return super()._fit(X, y, sample_weight, backend_module, queue)

    def predict(self, X, queue=None):
        """Predict labels by using the backend output as indices into ``classes_``."""
        raw = super()._predict(X, _backend.svm.nu_classification, queue)
        if len(self.classes_) == 2:
            raw = raw.ravel()
        # The backend output is cast to integer positions before the lookup.
        positions = np.asarray(raw, dtype=np.intp)
        return self.classes_.take(positions).ravel()

    def decision_function(self, X, queue=None):
        """Delegate to ``BaseSVM._decision_function`` with the nu-classification backend."""
        backend_module = _backend.svm.nu_classification
        return super()._decision_function(X, backend_module, queue)
|