scikit-learn-intelex 2025.4.0__py313-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- daal4py/__init__.py +73 -0
- daal4py/__main__.py +58 -0
- daal4py/_daal4py.cpython-313-x86_64-linux-gnu.so +0 -0
- daal4py/doc/third-party-programs.txt +424 -0
- daal4py/mb/__init__.py +19 -0
- daal4py/mb/model_builders.py +377 -0
- daal4py/mpi_transceiver.cpython-313-x86_64-linux-gnu.so +0 -0
- daal4py/sklearn/__init__.py +40 -0
- daal4py/sklearn/_n_jobs_support.py +248 -0
- daal4py/sklearn/_utils.py +245 -0
- daal4py/sklearn/cluster/__init__.py +20 -0
- daal4py/sklearn/cluster/dbscan.py +165 -0
- daal4py/sklearn/cluster/k_means.py +597 -0
- daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- daal4py/sklearn/decomposition/__init__.py +19 -0
- daal4py/sklearn/decomposition/_pca.py +524 -0
- daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
- daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
- daal4py/sklearn/ensemble/__init__.py +27 -0
- daal4py/sklearn/ensemble/_forest.py +1397 -0
- daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- daal4py/sklearn/linear_model/__init__.py +29 -0
- daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- daal4py/sklearn/linear_model/_linear.py +272 -0
- daal4py/sklearn/linear_model/_ridge.py +325 -0
- daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- daal4py/sklearn/linear_model/linear.py +17 -0
- daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- daal4py/sklearn/linear_model/ridge.py +17 -0
- daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
- daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- daal4py/sklearn/manifold/__init__.py +19 -0
- daal4py/sklearn/manifold/_t_sne.py +405 -0
- daal4py/sklearn/metrics/__init__.py +20 -0
- daal4py/sklearn/metrics/_pairwise.py +236 -0
- daal4py/sklearn/metrics/_ranking.py +210 -0
- daal4py/sklearn/model_selection/__init__.py +19 -0
- daal4py/sklearn/model_selection/_split.py +309 -0
- daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- daal4py/sklearn/monkeypatch/__init__.py +0 -0
- daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- daal4py/sklearn/neighbors/__init__.py +21 -0
- daal4py/sklearn/neighbors/_base.py +503 -0
- daal4py/sklearn/neighbors/_classification.py +139 -0
- daal4py/sklearn/neighbors/_regression.py +74 -0
- daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- daal4py/sklearn/svm/__init__.py +19 -0
- daal4py/sklearn/svm/svm.py +734 -0
- daal4py/sklearn/utils/__init__.py +21 -0
- daal4py/sklearn/utils/base.py +75 -0
- daal4py/sklearn/utils/tests/test_utils.py +51 -0
- daal4py/sklearn/utils/validation.py +696 -0
- onedal/__init__.py +83 -0
- onedal/_config.py +54 -0
- onedal/_device_offload.py +204 -0
- onedal/_onedal_py_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_host.cpython-313-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_spmd_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
- onedal/basic_statistics/__init__.py +20 -0
- onedal/basic_statistics/basic_statistics.py +107 -0
- onedal/basic_statistics/incremental_basic_statistics.py +175 -0
- onedal/basic_statistics/tests/test_basic_statistics.py +242 -0
- onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
- onedal/basic_statistics/tests/utils.py +50 -0
- onedal/cluster/__init__.py +27 -0
- onedal/cluster/dbscan.py +105 -0
- onedal/cluster/kmeans.py +557 -0
- onedal/cluster/kmeans_init.py +112 -0
- onedal/cluster/tests/test_dbscan.py +125 -0
- onedal/cluster/tests/test_kmeans.py +88 -0
- onedal/cluster/tests/test_kmeans_init.py +93 -0
- onedal/common/_base.py +38 -0
- onedal/common/_estimator_checks.py +47 -0
- onedal/common/_mixin.py +62 -0
- onedal/common/_policy.py +55 -0
- onedal/common/_spmd_policy.py +30 -0
- onedal/common/hyperparameters.py +125 -0
- onedal/common/tests/test_policy.py +76 -0
- onedal/common/tests/test_sycl.py +128 -0
- onedal/covariance/__init__.py +20 -0
- onedal/covariance/covariance.py +122 -0
- onedal/covariance/incremental_covariance.py +161 -0
- onedal/covariance/tests/test_covariance.py +50 -0
- onedal/covariance/tests/test_incremental_covariance.py +190 -0
- onedal/datatypes/__init__.py +19 -0
- onedal/datatypes/_data_conversion.py +121 -0
- onedal/datatypes/tests/common.py +126 -0
- onedal/datatypes/tests/test_data.py +475 -0
- onedal/decomposition/__init__.py +20 -0
- onedal/decomposition/incremental_pca.py +214 -0
- onedal/decomposition/pca.py +186 -0
- onedal/decomposition/tests/test_incremental_pca.py +285 -0
- onedal/ensemble/__init__.py +29 -0
- onedal/ensemble/forest.py +736 -0
- onedal/ensemble/tests/test_random_forest.py +97 -0
- onedal/linear_model/__init__.py +27 -0
- onedal/linear_model/incremental_linear_model.py +292 -0
- onedal/linear_model/linear_model.py +325 -0
- onedal/linear_model/logistic_regression.py +247 -0
- onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
- onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
- onedal/linear_model/tests/test_linear_regression.py +259 -0
- onedal/linear_model/tests/test_logistic_regression.py +95 -0
- onedal/linear_model/tests/test_ridge.py +95 -0
- onedal/neighbors/__init__.py +19 -0
- onedal/neighbors/neighbors.py +763 -0
- onedal/neighbors/tests/test_knn_classification.py +49 -0
- onedal/primitives/__init__.py +27 -0
- onedal/primitives/get_tree.py +25 -0
- onedal/primitives/kernel_functions.py +152 -0
- onedal/primitives/tests/test_kernel_functions.py +159 -0
- onedal/spmd/__init__.py +25 -0
- onedal/spmd/_base.py +30 -0
- onedal/spmd/basic_statistics/__init__.py +20 -0
- onedal/spmd/basic_statistics/basic_statistics.py +30 -0
- onedal/spmd/basic_statistics/incremental_basic_statistics.py +71 -0
- onedal/spmd/cluster/__init__.py +28 -0
- onedal/spmd/cluster/dbscan.py +23 -0
- onedal/spmd/cluster/kmeans.py +56 -0
- onedal/spmd/covariance/__init__.py +20 -0
- onedal/spmd/covariance/covariance.py +26 -0
- onedal/spmd/covariance/incremental_covariance.py +83 -0
- onedal/spmd/decomposition/__init__.py +20 -0
- onedal/spmd/decomposition/incremental_pca.py +124 -0
- onedal/spmd/decomposition/pca.py +26 -0
- onedal/spmd/ensemble/__init__.py +19 -0
- onedal/spmd/ensemble/forest.py +28 -0
- onedal/spmd/linear_model/__init__.py +21 -0
- onedal/spmd/linear_model/incremental_linear_model.py +101 -0
- onedal/spmd/linear_model/linear_model.py +30 -0
- onedal/spmd/linear_model/logistic_regression.py +38 -0
- onedal/spmd/neighbors/__init__.py +19 -0
- onedal/spmd/neighbors/neighbors.py +75 -0
- onedal/svm/__init__.py +19 -0
- onedal/svm/svm.py +556 -0
- onedal/svm/tests/test_csr_svm.py +351 -0
- onedal/svm/tests/test_nusvc.py +204 -0
- onedal/svm/tests/test_nusvr.py +210 -0
- onedal/svm/tests/test_svc.py +176 -0
- onedal/svm/tests/test_svr.py +243 -0
- onedal/tests/test_common.py +57 -0
- onedal/tests/utils/_dataframes_support.py +162 -0
- onedal/tests/utils/_device_selection.py +102 -0
- onedal/utils/__init__.py +49 -0
- onedal/utils/_array_api.py +81 -0
- onedal/utils/_dpep_helpers.py +56 -0
- onedal/utils/tests/test_validation.py +142 -0
- onedal/utils/validation.py +464 -0
- scikit_learn_intelex-2025.4.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.4.0.dist-info/METADATA +190 -0
- scikit_learn_intelex-2025.4.0.dist-info/RECORD +282 -0
- scikit_learn_intelex-2025.4.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.4.0.dist-info/top_level.txt +3 -0
- sklearnex/__init__.py +66 -0
- sklearnex/__main__.py +58 -0
- sklearnex/_config.py +116 -0
- sklearnex/_device_offload.py +126 -0
- sklearnex/_utils.py +177 -0
- sklearnex/basic_statistics/__init__.py +20 -0
- sklearnex/basic_statistics/basic_statistics.py +261 -0
- sklearnex/basic_statistics/incremental_basic_statistics.py +352 -0
- sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
- sklearnex/cluster/__init__.py +20 -0
- sklearnex/cluster/dbscan.py +197 -0
- sklearnex/cluster/k_means.py +397 -0
- sklearnex/cluster/tests/test_dbscan.py +38 -0
- sklearnex/cluster/tests/test_kmeans.py +157 -0
- sklearnex/conftest.py +82 -0
- sklearnex/covariance/__init__.py +19 -0
- sklearnex/covariance/incremental_covariance.py +405 -0
- sklearnex/covariance/tests/test_incremental_covariance.py +287 -0
- sklearnex/decomposition/__init__.py +19 -0
- sklearnex/decomposition/pca.py +427 -0
- sklearnex/decomposition/tests/test_pca.py +58 -0
- sklearnex/dispatcher.py +534 -0
- sklearnex/doc/third-party-programs.txt +424 -0
- sklearnex/ensemble/__init__.py +29 -0
- sklearnex/ensemble/_forest.py +2029 -0
- sklearnex/ensemble/tests/test_forest.py +140 -0
- sklearnex/glob/__main__.py +72 -0
- sklearnex/glob/dispatcher.py +101 -0
- sklearnex/linear_model/__init__.py +32 -0
- sklearnex/linear_model/coordinate_descent.py +30 -0
- sklearnex/linear_model/incremental_linear.py +495 -0
- sklearnex/linear_model/incremental_ridge.py +432 -0
- sklearnex/linear_model/linear.py +346 -0
- sklearnex/linear_model/logistic_regression.py +415 -0
- sklearnex/linear_model/ridge.py +390 -0
- sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
- sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
- sklearnex/linear_model/tests/test_linear.py +142 -0
- sklearnex/linear_model/tests/test_logreg.py +134 -0
- sklearnex/linear_model/tests/test_ridge.py +256 -0
- sklearnex/manifold/__init__.py +19 -0
- sklearnex/manifold/t_sne.py +26 -0
- sklearnex/manifold/tests/test_tsne.py +250 -0
- sklearnex/metrics/__init__.py +23 -0
- sklearnex/metrics/pairwise.py +22 -0
- sklearnex/metrics/ranking.py +20 -0
- sklearnex/metrics/tests/test_metrics.py +39 -0
- sklearnex/model_selection/__init__.py +21 -0
- sklearnex/model_selection/split.py +22 -0
- sklearnex/model_selection/tests/test_model_selection.py +34 -0
- sklearnex/neighbors/__init__.py +27 -0
- sklearnex/neighbors/_lof.py +236 -0
- sklearnex/neighbors/common.py +310 -0
- sklearnex/neighbors/knn_classification.py +231 -0
- sklearnex/neighbors/knn_regression.py +207 -0
- sklearnex/neighbors/knn_unsupervised.py +178 -0
- sklearnex/neighbors/tests/test_neighbors.py +82 -0
- sklearnex/preview/__init__.py +17 -0
- sklearnex/preview/covariance/__init__.py +19 -0
- sklearnex/preview/covariance/covariance.py +142 -0
- sklearnex/preview/covariance/tests/test_covariance.py +66 -0
- sklearnex/preview/decomposition/__init__.py +19 -0
- sklearnex/preview/decomposition/incremental_pca.py +244 -0
- sklearnex/preview/decomposition/tests/test_incremental_pca.py +336 -0
- sklearnex/spmd/__init__.py +25 -0
- sklearnex/spmd/basic_statistics/__init__.py +20 -0
- sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
- sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +306 -0
- sklearnex/spmd/cluster/__init__.py +30 -0
- sklearnex/spmd/cluster/dbscan.py +50 -0
- sklearnex/spmd/cluster/kmeans.py +21 -0
- sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +173 -0
- sklearnex/spmd/covariance/__init__.py +20 -0
- sklearnex/spmd/covariance/covariance.py +21 -0
- sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- sklearnex/spmd/decomposition/__init__.py +20 -0
- sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- sklearnex/spmd/decomposition/pca.py +21 -0
- sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- sklearnex/spmd/ensemble/__init__.py +19 -0
- sklearnex/spmd/ensemble/forest.py +71 -0
- sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- sklearnex/spmd/linear_model/__init__.py +21 -0
- sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- sklearnex/spmd/linear_model/linear_model.py +21 -0
- sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +331 -0
- sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
- sklearnex/spmd/neighbors/__init__.py +19 -0
- sklearnex/spmd/neighbors/neighbors.py +25 -0
- sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- sklearnex/svm/__init__.py +29 -0
- sklearnex/svm/_common.py +339 -0
- sklearnex/svm/nusvc.py +371 -0
- sklearnex/svm/nusvr.py +170 -0
- sklearnex/svm/svc.py +399 -0
- sklearnex/svm/svr.py +167 -0
- sklearnex/svm/tests/test_svm.py +93 -0
- sklearnex/tests/test_common.py +491 -0
- sklearnex/tests/test_config.py +123 -0
- sklearnex/tests/test_hyperparameters.py +43 -0
- sklearnex/tests/test_memory_usage.py +347 -0
- sklearnex/tests/test_monkeypatch.py +269 -0
- sklearnex/tests/test_n_jobs_support.py +108 -0
- sklearnex/tests/test_parallel.py +48 -0
- sklearnex/tests/test_patching.py +377 -0
- sklearnex/tests/test_run_to_run_stability.py +326 -0
- sklearnex/tests/utils/__init__.py +48 -0
- sklearnex/tests/utils/base.py +436 -0
- sklearnex/tests/utils/spmd.py +198 -0
- sklearnex/utils/__init__.py +19 -0
- sklearnex/utils/_array_api.py +82 -0
- sklearnex/utils/parallel.py +59 -0
- sklearnex/utils/tests/test_validation.py +238 -0
- sklearnex/utils/validation.py +208 -0
|
@@ -0,0 +1,597 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2014 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numbers
|
|
18
|
+
import warnings
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
from scipy import sparse as sp
|
|
22
|
+
from sklearn.cluster import KMeans as KMeans_original
|
|
23
|
+
from sklearn.cluster._kmeans import _labels_inertia
|
|
24
|
+
from sklearn.exceptions import ConvergenceWarning
|
|
25
|
+
from sklearn.utils import check_array, check_random_state
|
|
26
|
+
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
|
|
27
|
+
from sklearn.utils.extmath import row_norms
|
|
28
|
+
from sklearn.utils.sparsefuncs import mean_variance_axis
|
|
29
|
+
from sklearn.utils.validation import (
|
|
30
|
+
_deprecate_positional_args,
|
|
31
|
+
_num_samples,
|
|
32
|
+
check_is_fitted,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
import daal4py
|
|
36
|
+
|
|
37
|
+
from .._n_jobs_support import control_n_jobs
|
|
38
|
+
from .._utils import PatchingConditionsChain, getFPType, sklearn_check_version
|
|
39
|
+
|
|
40
|
+
if sklearn_check_version("1.1"):
|
|
41
|
+
from sklearn.utils.validation import _check_sample_weight, _is_arraylike_not_scalar
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _validate_center_shape(X, n_centers, centers):
    """Raise ``ValueError`` unless ``centers`` is compatible with ``X``.

    The row count of ``centers`` must equal ``n_centers`` and its column
    count must equal ``X.shape[1]``. The row count is checked first, so
    only the first mismatch found is reported.
    """
    if centers.shape[0] != n_centers:
        detail = f"match the number of clusters {n_centers}."
    elif centers.shape[1] != X.shape[1]:
        detail = f"match the number of features of the data {X.shape[1]}."
    else:
        return None
    raise ValueError(
        f"The shape of the initial centers {centers.shape} does not " + detail
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _tolerance(X, rtol):
    """Convert a relative tolerance into an absolute one for ``X``.

    The absolute tolerance is the mean of the per-feature variances of
    ``X`` multiplied by ``rtol``. A zero ``rtol`` is returned unchanged
    without inspecting ``X``. Sparse inputs get their variances from
    ``mean_variance_axis``; dense inputs from ``np.var``.
    """
    if rtol == 0.0:
        return rtol
    if not sp.issparse(X):
        average_variance = np.var(X, axis=0).mean()
    else:
        per_feature = mean_variance_axis(X, axis=0)[1]
        average_variance = np.mean(per_feature)
    return average_variance * rtol
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _daal4py_compute_starting_centroids(
    X, X_fptype, nClusters, cluster_centers_0, verbose, random_state
):
    """Produce the starting centroids for one daal4py k-means run.

    Parameters
    ----------
    X : dense array or sparse matrix
        Training data; sparsity selects the CSR or dense daal4py method.
    X_fptype : str
        daal4py floating-point type name for ``X``.
    nClusters : int
        Number of centroids to produce.
    cluster_centers_0 : str, array-like or callable
        Initialization spec: ``"k-means++"``, ``"random"``,
        ``"deterministic"``, an object exposing ``__array__``, or a callable
        invoked as ``cluster_centers_0(X, nClusters, random_state)``.
    verbose : bool
        When true, print a message once initialization is done.
    random_state : object with ``randint``
        Seed source for the randomized strategies (presumably the
        ``RandomState`` produced by ``check_random_state`` in the caller).

    Returns
    -------
    tuple
        ``(deterministic, centroids)``. ``deterministic`` is True for an
        explicit array and for the ``"deterministic"`` strategy.

    Raises
    ------
    ValueError
        If the spec is not recognized or explicit centers have a wrong shape.
    """
    use_csr = sp.issparse(X)
    init_name = cluster_centers_0 if isinstance(cluster_centers_0, str) else None

    def _seeded_daal_init(method, add_trials):
        # Each randomized init draws a fresh seed for daal4py's MT19937 engine.
        seed = random_state.randint(np.iinfo("i").max)
        engine = daal4py.engines_mt19937(
            fptype=X_fptype, method="defaultDense", seed=seed
        )
        options = {"fptype": X_fptype, "method": method, "engine": engine}
        if add_trials:
            # k-means++ candidate count: 2 + int(log(nClusters)).
            options["nTrials"] = 2 + int(np.log(nClusters))
        return daal4py.kmeans_init(nClusters, **options).compute(X).centroids

    if init_name == "k-means++":
        is_deterministic = False
        start = _seeded_daal_init(
            "plusPlusCSR" if use_csr else "plusPlusDense", True
        )
    elif init_name == "random":
        is_deterministic = False
        start = _seeded_daal_init("randomCSR" if use_csr else "randomDense", False)
    elif hasattr(cluster_centers_0, "__array__"):
        is_deterministic = True
        start = np.ascontiguousarray(cluster_centers_0, dtype=X.dtype)
        _validate_center_shape(X, nClusters, start)
    elif callable(cluster_centers_0):
        is_deterministic = False
        produced = cluster_centers_0(X, nClusters, random_state)
        start = np.ascontiguousarray(produced, dtype=X.dtype)
        _validate_center_shape(X, nClusters, start)
    elif init_name == "deterministic":
        is_deterministic = True
        initializer = daal4py.kmeans_init(
            nClusters,
            fptype=X_fptype,
            method="lloydCSR" if use_csr else "defaultDense",
        )
        start = initializer.compute(X).centroids
    else:
        raise ValueError(
            "init should be either 'k-means++', 'random', a ndarray or a "
            f"callable, got '{cluster_centers_0}' instead."
        )

    if verbose:
        print("Initialization complete")
    return is_deterministic, start
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _daal4py_kmeans_compatibility(
    nClusters,
    maxIterations,
    fptype="double",
    method="lloydDense",
    accuracyThreshold=0.0,
    resultsToEvaluate="computeCentroids",
    gamma=1.0,
):
    """Build a ``daal4py.kmeans`` algorithm object from the given settings.

    Every argument is forwarded by keyword to ``daal4py.kmeans``. The
    returned object has not been run; callers invoke ``.compute`` on it.
    """
    settings = {
        "nClusters": nClusters,
        "maxIterations": maxIterations,
        "fptype": fptype,
        "resultsToEvaluate": resultsToEvaluate,
        "accuracyThreshold": accuracyThreshold,
        "method": method,
        "gamma": gamma,
    }
    return daal4py.kmeans(**settings)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _daal4py_k_means_predict(
    X, nClusters, centroids, resultsToEvaluate="computeAssignments"
):
    """Evaluate fixed ``centroids`` on ``X`` with daal4py.

    The algorithm is built with ``maxIterations=0`` (presumably so the
    centroids are only evaluated, not refined) and the CSR or dense method
    depending on whether ``X`` is sparse.

    Returns
    -------
    tuple
        ``(assignments[:, 0], objectiveFunction[0, 0])`` taken from the
        daal4py result, i.e. the label per sample and the objective value.
    """
    fptype = getFPType(X)
    algo_method = "lloydCSR" if sp.issparse(X) else "defaultDense"
    predictor = _daal4py_kmeans_compatibility(
        nClusters=nClusters,
        maxIterations=0,
        fptype=fptype,
        resultsToEvaluate=resultsToEvaluate,
        method=algo_method,
    )
    outcome = predictor.compute(X, centroids)
    labels = outcome.assignments[:, 0]
    objective = outcome.objectiveFunction[0, 0]
    return labels, objective
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _daal4py_k_means_fit(
    X, nClusters, numIterations, tol, cluster_centers_0, n_init, verbose, random_state
):
    """Run daal4py k-means ``n_init`` times and keep the lowest-inertia fit.

    Parameters
    ----------
    X : dense array or sparse matrix
        Training data.
    nClusters : int
        Number of clusters.
    numIterations : int
        Maximum iterations per run; must be non-negative.
    tol : float
        Relative tolerance, converted to an absolute one via ``_tolerance``.
    cluster_centers_0 : str, array-like or callable
        Initialization spec passed to ``_daal4py_compute_starting_centroids``.
    n_init : int, "auto" or "warn"
        Number of runs. ``"auto"`` resolves to 1 for ``"k-means++"`` or an
        explicit array of centers, and to 10 otherwise. ``"warn"`` resolves
        to 1 for ``"k-means++"`` and to 10 otherwise.
    verbose : bool
        Print per-run inertia when true.
    random_state : object with ``randint``
        Seed source forwarded to the initialization step.

    Returns
    -------
    tuple
        ``(cluster_centers, labels, inertia, n_iter)`` of the best run; the
        labels and inertia are recomputed against the best centers.

    Raises
    ------
    ValueError
        If ``numIterations`` is negative or the resolved ``n_init`` is < 1.
    """
    if numIterations < 0:
        raise ValueError("Wrong iterations number")

    def is_string(s, target_str):
        return isinstance(s, str) and s == target_str

    default_n_init = 10
    if n_init in ["auto", "warn"]:
        if n_init == "warn" and sklearn_check_version("1.2"):
            warnings.warn(
                "The default value of `n_init` will change from "
                f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`"
                " explicitly to suppress the warning",
                FutureWarning,
            )
        if is_string(cluster_centers_0, "k-means++"):
            n_init = 1
        elif n_init == "auto" and hasattr(cluster_centers_0, "__array__"):
            # scikit-learn resolves n_init="auto" to a single run when explicit
            # centers are given. Resolving it to 10 here only led to the
            # misleading "performing only one init ... n_init=10" warning below.
            n_init = 1
        else:
            n_init = default_n_init
    if n_init < 1:
        # Without at least one run there are no centers to predict with.
        raise ValueError(f"n_init should be > 0, got {n_init} instead.")

    X_fptype = getFPType(X)
    abs_tol = _tolerance(X, tol)  # tol is relative tolerance
    is_sparse = sp.issparse(X)
    method = "lloydCSR" if is_sparse else "defaultDense"
    best_inertia, best_cluster_centers = None, None
    best_n_iter = -1
    # One algorithm object is reused for every run; only the starting
    # centroids change between runs.
    kmeans_algo = _daal4py_kmeans_compatibility(
        nClusters=nClusters,
        maxIterations=numIterations,
        accuracyThreshold=abs_tol,
        fptype=X_fptype,
        resultsToEvaluate="computeCentroids",
        method=method,
    )

    for k in range(n_init):
        deterministic, starting_centroids_ = _daal4py_compute_starting_centroids(
            X, X_fptype, nClusters, cluster_centers_0, verbose, random_state
        )

        res = kmeans_algo.compute(X, starting_centroids_)

        inertia = res.objectiveFunction[0, 0]
        if verbose:
            print(f"Iteration {k}, inertia {inertia}.")

        if best_inertia is None or inertia < best_inertia:
            best_cluster_centers = res.centroids
            if n_init > 1:
                # NOTE(review): copied when more runs follow, presumably so the
                # kept centers are not tied to a daal4py result buffer.
                best_cluster_centers = best_cluster_centers.copy()
            best_inertia = inertia
            best_n_iter = int(res.nIterations[0, 0])
        if deterministic and n_init != 1:
            # Explicit or deterministic inits give the same start every time,
            # so extra runs are pointless.
            warnings.warn(
                "Explicit initial center position passed: "
                "performing only one init in k-means instead of n_init=%d" % n_init,
                RuntimeWarning,
                stacklevel=2,
            )
            break

    # Final pass: labels and exact objective for the chosen centers.
    flag_compute = "computeAssignments|computeExactObjectiveFunction"
    best_labels, best_inertia = _daal4py_k_means_predict(
        X, nClusters, best_cluster_centers, flag_compute
    )

    distinct_clusters = np.unique(best_labels).size
    if distinct_clusters < nClusters:
        # Warning kept for scikit-learn's
        # "test_kmeans_warns_less_centers_than_unique_points" test.
        warnings.warn(
            "Number of distinct clusters ({}) found smaller than "
            "n_clusters ({}). Possibly due to duplicate points "
            "in X.".format(distinct_clusters, nClusters),
            ConvergenceWarning,
            stacklevel=2,
        )

    return best_cluster_centers, best_labels, best_inertia, best_n_iter
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _fit(self, X, y=None, sample_weight=None):
    """Fit k-means on ``X``, using daal4py when the patching conditions hold.

    Input validation follows the installed scikit-learn version (branches on
    ``sklearn_check_version``). If every condition recorded in the
    ``PatchingConditionsChain`` is satisfied, training is done by
    ``_daal4py_k_means_fit``. Otherwise the call is forwarded to
    ``super(KMeans, self).fit``.

    Parameters
    ----------
    X : array-like or sparse matrix
        Training data.
    y : ignored
        Forwarded unchanged to the fallback ``fit``.
    sample_weight : array-like, number or None
        Per-sample weights. The daal4py path requires weights of length
        ``n_samples`` that are all close to one; anything else falls back.

    Returns
    -------
    self
    """
    init = self.init
    if sklearn_check_version("1.1"):
        # scikit-learn >= 1.1: rely on the estimator's own validation helpers.
        if sklearn_check_version("1.2"):
            self._validate_params()

        # C-ordered float64/float32 data (CSR allowed), honoring copy_x.
        X = self._validate_data(
            X,
            accept_sparse="csr",
            dtype=[np.float64, np.float32],
            order="C",
            copy=self.copy_x,
            accept_large_sparse=False,
        )

        if sklearn_check_version("1.2"):
            self._check_params_vs_input(X)
        else:
            self._check_params(X)

        random_state = check_random_state(self.random_state)
        # NOTE(review): scikit-learn's helper is expected to turn None into an
        # array of ones, so the weight checks below always run on this path.
        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
        self._n_threads = _openmp_effective_n_threads()

        # Validate init array
        init_is_array_like = _is_arraylike_not_scalar(init)
        if init_is_array_like:
            init = check_array(init, dtype=X.dtype, copy=True, order="C")
            self._validate_center_shape(X, init)
    else:
        # scikit-learn < 1.1: replicate the legacy manual parameter checks.
        if hasattr(self, "precompute_distances"):
            if self.precompute_distances != "deprecated":
                warnings.warn(
                    "'precompute_distances' was deprecated in version "
                    "0.23 and will be removed in 1.0 (renaming of 0.25)."
                    " It has no effect",
                    FutureWarning,
                )

        self._n_threads = None
        if hasattr(self, "n_jobs"):
            if self.n_jobs != "deprecated":
                warnings.warn(
                    "'n_jobs' was deprecated in version 0.23 and will be"
                    " removed in 1.0 (renaming of 0.25).",
                    FutureWarning,
                )
                self._n_threads = self.n_jobs
        self._n_threads = _openmp_effective_n_threads(self._n_threads)

        if self.n_init <= 0:
            raise ValueError(f"n_init should be > 0, got {self.n_init} instead.")

        random_state = check_random_state(self.random_state)
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=True)

        if self.max_iter <= 0:
            raise ValueError(f"max_iter should be > 0, got {self.max_iter} instead.")

        # NOTE(review): `algorithm` is only validated here; it is not read
        # again later in this function.
        algorithm = self.algorithm
        # NOTE(review): this branch lives inside the "< 1.1" path, so the
        # "1.2" check below can never be true here (appears to be dead code).
        if sklearn_check_version("1.2"):
            if algorithm == "elkan" and self.n_clusters == 1:
                warnings.warn(
                    "algorithm='elkan' doesn't make sense for a single "
                    "cluster. Using 'full' instead.",
                    RuntimeWarning,
                )
                algorithm = "lloyd"

            if algorithm == "auto" or algorithm == "full":
                warnings.warn(
                    "algorithm= {'auto','full'} is deprecated" "Using 'lloyd' instead.",
                    RuntimeWarning,
                )
                algorithm = "lloyd" if self.n_clusters == 1 else "elkan"

            if algorithm not in ["lloyd", "full", "elkan"]:
                raise ValueError(
                    "Algorithm must be 'auto','lloyd', 'full' or 'elkan',"
                    "got {}".format(str(algorithm))
                )
        else:
            if algorithm == "elkan" and self.n_clusters == 1:
                warnings.warn(
                    "algorithm='elkan' doesn't make sense for a single "
                    "cluster. Using 'full' instead.",
                    RuntimeWarning,
                )
                algorithm = "full"

            if algorithm == "auto":
                algorithm = "full" if self.n_clusters == 1 else "elkan"

            if algorithm not in ["full", "elkan"]:
                raise ValueError(
                    "Algorithm must be 'auto', 'full' or 'elkan', got"
                    " {}".format(str(algorithm))
                )

    X_len = _num_samples(X)

    # daal4py path requires no more clusters than samples.
    _patching_status = PatchingConditionsChain("sklearn.cluster.KMeans.fit")
    _dal_ready = _patching_status.and_conditions(
        [
            (
                self.n_clusters <= X_len,
                "The number of clusters is larger than the number of samples in X.",
            )
        ]
    )

    if _dal_ready and sample_weight is not None:
        # Scalars are broadcast to a full weight vector; daal4py is only used
        # when the weights are effectively all ones.
        if isinstance(sample_weight, numbers.Number):
            sample_weight = np.full(X_len, sample_weight, dtype=np.float64)
        else:
            sample_weight = np.asarray(sample_weight)
        _dal_ready = _patching_status.and_conditions(
            [
                (
                    sample_weight.shape == (X_len,),
                    "Sample weights do not have the same length as X.",
                ),
                (
                    np.allclose(sample_weight, np.ones_like(sample_weight)),
                    "Sample weights are not ones.",
                ),
            ]
        )

    _patching_status.write_log()
    if _dal_ready:
        # NOTE(review): for scikit-learn >= 1.1 X was already validated above;
        # this check appears to matter mainly for the "< 1.1" path.
        X = check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
        self.n_features_in_ = X.shape[1]
        (
            self.cluster_centers_,
            self.labels_,
            self.inertia_,
            self.n_iter_,
        ) = _daal4py_k_means_fit(
            X,
            self.n_clusters,
            self.max_iter,
            self.tol,
            init,
            self.n_init,
            self.verbose,
            random_state,
        )
        if sklearn_check_version("1.1"):
            # Output feature count = number of rows of cluster_centers_.
            self._n_features_out = self.cluster_centers_.shape[0]
    else:
        # Fallback to the parent implementation (KMeans is defined elsewhere
        # in this module — presumably the patched estimator class).
        super(KMeans, self).fit(X, y=y, sample_weight=sample_weight)
    return self
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def _daal4py_check_test_data(self, X):
    """Validate prediction-time input against the fitted estimator.

    Checks feature names (scikit-learn >= 1.0), runs ``check_array`` to get
    a float64/float32 dense array or CSR matrix, and verifies that the
    column count matches ``self.n_features_in_``.

    Returns
    -------
    The validated ``X``.

    Raises
    ------
    ValueError
        If the number of columns differs from ``self.n_features_in_``.
    """
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)
    validated = check_array(
        X,
        accept_sparse="csr",
        dtype=[np.float64, np.float32],
        accept_large_sparse=False,
    )
    n_cols = validated.shape[1]
    if n_cols == self.n_features_in_:
        return validated
    raise ValueError(
        f"X has {n_cols} features, "
        f"but Kmeans is expecting {self.n_features_in_} features as input"
    )
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def _predict(self, X, sample_weight=None):
|
|
432
|
+
check_is_fitted(self)
|
|
433
|
+
|
|
434
|
+
X = _daal4py_check_test_data(self, X)
|
|
435
|
+
|
|
436
|
+
if (
|
|
437
|
+
sklearn_check_version("1.3")
|
|
438
|
+
and isinstance(sample_weight, str)
|
|
439
|
+
and sample_weight == "deprecated"
|
|
440
|
+
):
|
|
441
|
+
sample_weight = None
|
|
442
|
+
|
|
443
|
+
_patching_status = PatchingConditionsChain("sklearn.cluster.KMeans.predict")
|
|
444
|
+
_patching_status.and_conditions(
|
|
445
|
+
[
|
|
446
|
+
(sample_weight is None, "Sample weights are not supported."),
|
|
447
|
+
(hasattr(X, "__array__"), "X does not have '__array__' attribute."),
|
|
448
|
+
]
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
# CSR array is introduced in scipy 1.11, this requires an initial attribute check
|
|
452
|
+
if hasattr(sp, "csr_array"):
|
|
453
|
+
_dal_ready = _patching_status.or_conditions(
|
|
454
|
+
[
|
|
455
|
+
(
|
|
456
|
+
sp.isspmatrix_csr(X) or isinstance(X, sp.csr_array),
|
|
457
|
+
"X is not csr sparse.",
|
|
458
|
+
)
|
|
459
|
+
]
|
|
460
|
+
)
|
|
461
|
+
else:
|
|
462
|
+
_dal_ready = _patching_status.or_conditions(
|
|
463
|
+
[(sp.isspmatrix_csr(X), "X is not csr sparse.")]
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
_patching_status.write_log()
|
|
467
|
+
if _dal_ready:
|
|
468
|
+
return _daal4py_k_means_predict(X, self.n_clusters, self.cluster_centers_)[0]
|
|
469
|
+
if sklearn_check_version("1.2"):
|
|
470
|
+
if sklearn_check_version("1.3") and sample_weight is not None:
|
|
471
|
+
warnings.warn(
|
|
472
|
+
"'sample_weight' was deprecated in version 1.3 and "
|
|
473
|
+
"will be removed in 1.5.",
|
|
474
|
+
FutureWarning,
|
|
475
|
+
)
|
|
476
|
+
return _labels_inertia(X, sample_weight, self.cluster_centers_)[0]
|
|
477
|
+
else:
|
|
478
|
+
x_squared_norms = row_norms(X, squared=True)
|
|
479
|
+
return _labels_inertia(X, sample_weight, x_squared_norms, self.cluster_centers_)[
|
|
480
|
+
0
|
|
481
|
+
]
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
@control_n_jobs(decorated_methods=["fit", "predict"])
|
|
485
|
+
class KMeans(KMeans_original):
|
|
486
|
+
__doc__ = KMeans_original.__doc__
|
|
487
|
+
|
|
488
|
+
if sklearn_check_version("1.2"):
|
|
489
|
+
_parameter_constraints: dict = {**KMeans_original._parameter_constraints}
|
|
490
|
+
|
|
491
|
+
@_deprecate_positional_args
|
|
492
|
+
def __init__(
|
|
493
|
+
self,
|
|
494
|
+
n_clusters=8,
|
|
495
|
+
*,
|
|
496
|
+
init="k-means++",
|
|
497
|
+
n_init="auto" if sklearn_check_version("1.4") else "warn",
|
|
498
|
+
max_iter=300,
|
|
499
|
+
tol=1e-4,
|
|
500
|
+
verbose=0,
|
|
501
|
+
random_state=None,
|
|
502
|
+
copy_x=True,
|
|
503
|
+
algorithm="lloyd",
|
|
504
|
+
):
|
|
505
|
+
super(KMeans, self).__init__(
|
|
506
|
+
n_clusters=n_clusters,
|
|
507
|
+
init=init,
|
|
508
|
+
max_iter=max_iter,
|
|
509
|
+
tol=tol,
|
|
510
|
+
n_init=n_init,
|
|
511
|
+
verbose=verbose,
|
|
512
|
+
random_state=random_state,
|
|
513
|
+
copy_x=copy_x,
|
|
514
|
+
algorithm=algorithm,
|
|
515
|
+
)
|
|
516
|
+
|
|
517
|
+
elif sklearn_check_version("1.0"):
|
|
518
|
+
|
|
519
|
+
@_deprecate_positional_args
|
|
520
|
+
def __init__(
|
|
521
|
+
self,
|
|
522
|
+
n_clusters=8,
|
|
523
|
+
*,
|
|
524
|
+
init="k-means++",
|
|
525
|
+
n_init=10,
|
|
526
|
+
max_iter=300,
|
|
527
|
+
tol=1e-4,
|
|
528
|
+
verbose=0,
|
|
529
|
+
random_state=None,
|
|
530
|
+
copy_x=True,
|
|
531
|
+
algorithm="lloyd" if sklearn_check_version("1.1") else "auto",
|
|
532
|
+
):
|
|
533
|
+
super(KMeans, self).__init__(
|
|
534
|
+
n_clusters=n_clusters,
|
|
535
|
+
init=init,
|
|
536
|
+
max_iter=max_iter,
|
|
537
|
+
tol=tol,
|
|
538
|
+
n_init=n_init,
|
|
539
|
+
verbose=verbose,
|
|
540
|
+
random_state=random_state,
|
|
541
|
+
copy_x=copy_x,
|
|
542
|
+
algorithm=algorithm,
|
|
543
|
+
)
|
|
544
|
+
|
|
545
|
+
else:
|
|
546
|
+
|
|
547
|
+
@_deprecate_positional_args
|
|
548
|
+
def __init__(
|
|
549
|
+
self,
|
|
550
|
+
n_clusters=8,
|
|
551
|
+
*,
|
|
552
|
+
init="k-means++",
|
|
553
|
+
n_init=10,
|
|
554
|
+
max_iter=300,
|
|
555
|
+
tol=1e-4,
|
|
556
|
+
precompute_distances="deprecated",
|
|
557
|
+
verbose=0,
|
|
558
|
+
random_state=None,
|
|
559
|
+
copy_x=True,
|
|
560
|
+
n_jobs="deprecated",
|
|
561
|
+
algorithm="auto",
|
|
562
|
+
):
|
|
563
|
+
super(KMeans, self).__init__(
|
|
564
|
+
n_clusters=n_clusters,
|
|
565
|
+
init=init,
|
|
566
|
+
max_iter=max_iter,
|
|
567
|
+
tol=tol,
|
|
568
|
+
precompute_distances=precompute_distances,
|
|
569
|
+
n_init=n_init,
|
|
570
|
+
verbose=verbose,
|
|
571
|
+
random_state=random_state,
|
|
572
|
+
copy_x=copy_x,
|
|
573
|
+
n_jobs=n_jobs,
|
|
574
|
+
algorithm=algorithm,
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
def fit(self, X, y=None, sample_weight=None):
|
|
578
|
+
return _fit(self, X, y=y, sample_weight=sample_weight)
|
|
579
|
+
|
|
580
|
+
if sklearn_check_version("1.5"):
|
|
581
|
+
|
|
582
|
+
def predict(self, X):
|
|
583
|
+
return _predict(self, X)
|
|
584
|
+
|
|
585
|
+
else:
|
|
586
|
+
|
|
587
|
+
def predict(
|
|
588
|
+
self, X, sample_weight="deprecated" if sklearn_check_version("1.3") else None
|
|
589
|
+
):
|
|
590
|
+
return _predict(self, X, sample_weight=sample_weight)
|
|
591
|
+
|
|
592
|
+
def fit_predict(self, X, y=None, sample_weight=None):
|
|
593
|
+
return super().fit_predict(X, y, sample_weight)
|
|
594
|
+
|
|
595
|
+
fit.__doc__ = KMeans_original.fit.__doc__
|
|
596
|
+
predict.__doc__ = KMeans_original.predict.__doc__
|
|
597
|
+
fit_predict.__doc__ = KMeans_original.fit_predict.__doc__
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2020 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from sklearn.cluster import DBSCAN as DBSCAN_SKLEARN
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn.cluster import DBSCAN as DBSCAN_DAAL
|
|
22
|
+
|
|
23
|
+
METRIC = ("euclidean",)
|
|
24
|
+
USE_WEIGHTS = (True, False)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def generate_data(
|
|
28
|
+
low: int, high: int, samples_number: int, sample_dimension: tuple
|
|
29
|
+
) -> tuple:
|
|
30
|
+
generator = np.random.RandomState()
|
|
31
|
+
table_size = (samples_number, sample_dimension)
|
|
32
|
+
return generator.uniform(low=low, high=high, size=table_size), generator.uniform(
|
|
33
|
+
size=samples_number
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def check_labels_equals(left_labels: np.ndarray, right_labels: np.ndarray) -> bool:
|
|
38
|
+
if left_labels.shape != right_labels.shape:
|
|
39
|
+
raise Exception("Shapes not equals")
|
|
40
|
+
if len(left_labels.shape) != 1:
|
|
41
|
+
raise Exception("Shapes size not equals 1")
|
|
42
|
+
if len(set(left_labels)) != len(set(right_labels)):
|
|
43
|
+
raise Exception("Clusters count not equals")
|
|
44
|
+
dict_checker = {}
|
|
45
|
+
for index_sample in range(left_labels.shape[0]):
|
|
46
|
+
if left_labels[index_sample] not in dict_checker:
|
|
47
|
+
dict_checker[left_labels[index_sample]] = right_labels[index_sample]
|
|
48
|
+
elif dict_checker[left_labels[index_sample]] != right_labels[index_sample]:
|
|
49
|
+
raise Exception("Wrong clustering")
|
|
50
|
+
return True
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _test_dbscan_big_data_numpy_gen(
|
|
54
|
+
eps: float,
|
|
55
|
+
min_samples: int,
|
|
56
|
+
metric: str,
|
|
57
|
+
use_weights: bool,
|
|
58
|
+
low=-100.0,
|
|
59
|
+
high=100.0,
|
|
60
|
+
samples_number=1000,
|
|
61
|
+
sample_dimension=4,
|
|
62
|
+
):
|
|
63
|
+
data, weights = generate_data(
|
|
64
|
+
low=low,
|
|
65
|
+
high=high,
|
|
66
|
+
samples_number=samples_number,
|
|
67
|
+
sample_dimension=sample_dimension,
|
|
68
|
+
)
|
|
69
|
+
if use_weights is False:
|
|
70
|
+
weights = None
|
|
71
|
+
initialized_daal_dbscan = DBSCAN_DAAL(
|
|
72
|
+
eps=eps, min_samples=min_samples, metric=metric
|
|
73
|
+
).fit(X=data, sample_weight=weights)
|
|
74
|
+
initialized_sklearn_dbscan = DBSCAN_SKLEARN(
|
|
75
|
+
metric=metric, eps=eps, min_samples=min_samples
|
|
76
|
+
).fit(X=data, sample_weight=weights)
|
|
77
|
+
check_labels_equals(
|
|
78
|
+
initialized_daal_dbscan.labels_, initialized_sklearn_dbscan.labels_
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@pytest.mark.parametrize("metric", METRIC)
|
|
83
|
+
@pytest.mark.parametrize("use_weights", USE_WEIGHTS)
|
|
84
|
+
def test_dbscan_big_data_numpy_gen(metric, use_weights: bool):
|
|
85
|
+
eps = 35.0
|
|
86
|
+
min_samples = 6
|
|
87
|
+
_test_dbscan_big_data_numpy_gen(
|
|
88
|
+
eps=eps, min_samples=min_samples, metric=metric, use_weights=use_weights
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _test_across_grid_parameter_numpy_gen(metric, use_weights: bool):
|
|
93
|
+
eps_begin = 0.05
|
|
94
|
+
eps_end = 0.5
|
|
95
|
+
eps_step = 0.05
|
|
96
|
+
min_samples_begin = 5
|
|
97
|
+
min_samples_end = 15
|
|
98
|
+
min_samples_step = 1
|
|
99
|
+
for eps in np.arange(eps_begin, eps_end, eps_step):
|
|
100
|
+
for min_samples in range(min_samples_begin, min_samples_end, min_samples_step):
|
|
101
|
+
_test_dbscan_big_data_numpy_gen(
|
|
102
|
+
eps=eps, min_samples=min_samples, metric=metric, use_weights=use_weights
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.parametrize("metric", METRIC)
|
|
107
|
+
@pytest.mark.parametrize("use_weights", USE_WEIGHTS)
|
|
108
|
+
def test_across_grid_parameter_numpy_gen(metric, use_weights: bool):
|
|
109
|
+
_test_across_grid_parameter_numpy_gen(metric=metric, use_weights=use_weights)
|