scikit_learn_intelex-2025.0.0-py310-none-manylinux_2_28_x86_64.whl
This diff shows the content of publicly released package versions as they appear in the supported public registries; it is provided for informational purposes only.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic.
- daal4py/__init__.py +73 -0
- daal4py/__main__.py +58 -0
- daal4py/_daal4py.cpython-310-x86_64-linux-gnu.so +0 -0
- daal4py/doc/third-party-programs.txt +424 -0
- daal4py/mb/__init__.py +19 -0
- daal4py/mb/model_builders.py +377 -0
- daal4py/mpi_transceiver.cpython-310-x86_64-linux-gnu.so +0 -0
- daal4py/sklearn/__init__.py +40 -0
- daal4py/sklearn/_n_jobs_support.py +242 -0
- daal4py/sklearn/_utils.py +241 -0
- daal4py/sklearn/cluster/__init__.py +20 -0
- daal4py/sklearn/cluster/dbscan.py +165 -0
- daal4py/sklearn/cluster/k_means.py +597 -0
- daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- daal4py/sklearn/decomposition/__init__.py +19 -0
- daal4py/sklearn/decomposition/_pca.py +524 -0
- daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
- daal4py/sklearn/ensemble/GBTDAAL.py +318 -0
- daal4py/sklearn/ensemble/__init__.py +27 -0
- daal4py/sklearn/ensemble/_forest.py +1397 -0
- daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- daal4py/sklearn/linear_model/__init__.py +29 -0
- daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- daal4py/sklearn/linear_model/_linear.py +272 -0
- daal4py/sklearn/linear_model/_ridge.py +325 -0
- daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- daal4py/sklearn/linear_model/linear.py +17 -0
- daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- daal4py/sklearn/linear_model/ridge.py +17 -0
- daal4py/sklearn/linear_model/tests/test_linear.py +196 -0
- daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- daal4py/sklearn/manifold/__init__.py +19 -0
- daal4py/sklearn/manifold/_t_sne.py +405 -0
- daal4py/sklearn/metrics/__init__.py +20 -0
- daal4py/sklearn/metrics/_pairwise.py +155 -0
- daal4py/sklearn/metrics/_ranking.py +210 -0
- daal4py/sklearn/model_selection/__init__.py +19 -0
- daal4py/sklearn/model_selection/_split.py +309 -0
- daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- daal4py/sklearn/monkeypatch/__init__.py +0 -0
- daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- daal4py/sklearn/monkeypatch/tests/test_patching.py +87 -0
- daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +118 -0
- daal4py/sklearn/neighbors/__init__.py +21 -0
- daal4py/sklearn/neighbors/_base.py +503 -0
- daal4py/sklearn/neighbors/_classification.py +139 -0
- daal4py/sklearn/neighbors/_regression.py +74 -0
- daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- daal4py/sklearn/svm/__init__.py +19 -0
- daal4py/sklearn/svm/svm.py +734 -0
- daal4py/sklearn/utils/__init__.py +21 -0
- daal4py/sklearn/utils/base.py +75 -0
- daal4py/sklearn/utils/tests/test_utils.py +51 -0
- daal4py/sklearn/utils/validation.py +693 -0
- onedal/__init__.py +83 -0
- onedal/_config.py +53 -0
- onedal/_device_offload.py +229 -0
- onedal/_onedal_py_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_host.cpython-310-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_spmd_dpc.cpython-310-x86_64-linux-gnu.so +0 -0
- onedal/basic_statistics/__init__.py +20 -0
- onedal/basic_statistics/basic_statistics.py +107 -0
- onedal/basic_statistics/incremental_basic_statistics.py +160 -0
- onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
- onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
- onedal/cluster/__init__.py +27 -0
- onedal/cluster/dbscan.py +110 -0
- onedal/cluster/kmeans.py +560 -0
- onedal/cluster/kmeans_init.py +115 -0
- onedal/cluster/tests/test_dbscan.py +125 -0
- onedal/cluster/tests/test_kmeans.py +88 -0
- onedal/cluster/tests/test_kmeans_init.py +93 -0
- onedal/common/_base.py +38 -0
- onedal/common/_estimator_checks.py +47 -0
- onedal/common/_mixin.py +62 -0
- onedal/common/_policy.py +59 -0
- onedal/common/_spmd_policy.py +30 -0
- onedal/common/hyperparameters.py +116 -0
- onedal/common/tests/test_policy.py +75 -0
- onedal/covariance/__init__.py +20 -0
- onedal/covariance/covariance.py +125 -0
- onedal/covariance/incremental_covariance.py +146 -0
- onedal/covariance/tests/test_covariance.py +50 -0
- onedal/covariance/tests/test_incremental_covariance.py +122 -0
- onedal/datatypes/__init__.py +19 -0
- onedal/datatypes/_data_conversion.py +95 -0
- onedal/datatypes/tests/test_data.py +235 -0
- onedal/decomposition/__init__.py +20 -0
- onedal/decomposition/incremental_pca.py +204 -0
- onedal/decomposition/pca.py +186 -0
- onedal/decomposition/tests/test_incremental_pca.py +198 -0
- onedal/ensemble/__init__.py +29 -0
- onedal/ensemble/forest.py +720 -0
- onedal/ensemble/tests/test_random_forest.py +97 -0
- onedal/linear_model/__init__.py +27 -0
- onedal/linear_model/incremental_linear_model.py +258 -0
- onedal/linear_model/linear_model.py +329 -0
- onedal/linear_model/logistic_regression.py +249 -0
- onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
- onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
- onedal/linear_model/tests/test_linear_regression.py +149 -0
- onedal/linear_model/tests/test_logistic_regression.py +95 -0
- onedal/linear_model/tests/test_ridge.py +95 -0
- onedal/neighbors/__init__.py +19 -0
- onedal/neighbors/neighbors.py +778 -0
- onedal/neighbors/tests/test_knn_classification.py +49 -0
- onedal/primitives/__init__.py +27 -0
- onedal/primitives/get_tree.py +25 -0
- onedal/primitives/kernel_functions.py +153 -0
- onedal/primitives/tests/test_kernel_functions.py +159 -0
- onedal/spmd/__init__.py +25 -0
- onedal/spmd/_base.py +30 -0
- onedal/spmd/basic_statistics/__init__.py +20 -0
- onedal/spmd/basic_statistics/basic_statistics.py +30 -0
- onedal/spmd/basic_statistics/incremental_basic_statistics.py +69 -0
- onedal/spmd/cluster/__init__.py +28 -0
- onedal/spmd/cluster/dbscan.py +23 -0
- onedal/spmd/cluster/kmeans.py +56 -0
- onedal/spmd/covariance/__init__.py +20 -0
- onedal/spmd/covariance/covariance.py +26 -0
- onedal/spmd/covariance/incremental_covariance.py +82 -0
- onedal/spmd/decomposition/__init__.py +20 -0
- onedal/spmd/decomposition/incremental_pca.py +117 -0
- onedal/spmd/decomposition/pca.py +26 -0
- onedal/spmd/ensemble/__init__.py +19 -0
- onedal/spmd/ensemble/forest.py +28 -0
- onedal/spmd/linear_model/__init__.py +21 -0
- onedal/spmd/linear_model/incremental_linear_model.py +97 -0
- onedal/spmd/linear_model/linear_model.py +30 -0
- onedal/spmd/linear_model/logistic_regression.py +38 -0
- onedal/spmd/neighbors/__init__.py +19 -0
- onedal/spmd/neighbors/neighbors.py +75 -0
- onedal/svm/__init__.py +19 -0
- onedal/svm/svm.py +556 -0
- onedal/svm/tests/test_csr_svm.py +351 -0
- onedal/svm/tests/test_nusvc.py +204 -0
- onedal/svm/tests/test_nusvr.py +210 -0
- onedal/svm/tests/test_svc.py +168 -0
- onedal/svm/tests/test_svr.py +243 -0
- onedal/tests/test_common.py +41 -0
- onedal/tests/utils/_dataframes_support.py +168 -0
- onedal/tests/utils/_device_selection.py +107 -0
- onedal/utils/__init__.py +49 -0
- onedal/utils/_array_api.py +91 -0
- onedal/utils/validation.py +432 -0
- scikit_learn_intelex-2025.0.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.0.0.dist-info/METADATA +231 -0
- scikit_learn_intelex-2025.0.0.dist-info/RECORD +278 -0
- scikit_learn_intelex-2025.0.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.0.0.dist-info/top_level.txt +3 -0
- sklearnex/__init__.py +65 -0
- sklearnex/__main__.py +58 -0
- sklearnex/_config.py +98 -0
- sklearnex/_device_offload.py +121 -0
- sklearnex/_utils.py +109 -0
- sklearnex/basic_statistics/__init__.py +20 -0
- sklearnex/basic_statistics/basic_statistics.py +140 -0
- sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
- sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +384 -0
- sklearnex/cluster/__init__.py +20 -0
- sklearnex/cluster/dbscan.py +192 -0
- sklearnex/cluster/k_means.py +383 -0
- sklearnex/cluster/tests/test_dbscan.py +38 -0
- sklearnex/cluster/tests/test_kmeans.py +153 -0
- sklearnex/conftest.py +73 -0
- sklearnex/covariance/__init__.py +19 -0
- sklearnex/covariance/incremental_covariance.py +368 -0
- sklearnex/covariance/tests/test_incremental_covariance.py +226 -0
- sklearnex/decomposition/__init__.py +19 -0
- sklearnex/decomposition/pca.py +414 -0
- sklearnex/decomposition/tests/test_pca.py +58 -0
- sklearnex/dispatcher.py +543 -0
- sklearnex/doc/third-party-programs.txt +424 -0
- sklearnex/ensemble/__init__.py +29 -0
- sklearnex/ensemble/_forest.py +2016 -0
- sklearnex/ensemble/tests/test_forest.py +120 -0
- sklearnex/glob/__main__.py +72 -0
- sklearnex/glob/dispatcher.py +101 -0
- sklearnex/linear_model/__init__.py +32 -0
- sklearnex/linear_model/coordinate_descent.py +30 -0
- sklearnex/linear_model/incremental_linear.py +463 -0
- sklearnex/linear_model/incremental_ridge.py +418 -0
- sklearnex/linear_model/linear.py +302 -0
- sklearnex/linear_model/logistic_path.py +17 -0
- sklearnex/linear_model/logistic_regression.py +403 -0
- sklearnex/linear_model/ridge.py +24 -0
- sklearnex/linear_model/tests/test_incremental_linear.py +203 -0
- sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
- sklearnex/linear_model/tests/test_linear.py +142 -0
- sklearnex/linear_model/tests/test_logreg.py +134 -0
- sklearnex/manifold/__init__.py +19 -0
- sklearnex/manifold/t_sne.py +21 -0
- sklearnex/manifold/tests/test_tsne.py +26 -0
- sklearnex/metrics/__init__.py +23 -0
- sklearnex/metrics/pairwise.py +22 -0
- sklearnex/metrics/ranking.py +20 -0
- sklearnex/metrics/tests/test_metrics.py +39 -0
- sklearnex/model_selection/__init__.py +21 -0
- sklearnex/model_selection/split.py +22 -0
- sklearnex/model_selection/tests/test_model_selection.py +34 -0
- sklearnex/neighbors/__init__.py +27 -0
- sklearnex/neighbors/_lof.py +231 -0
- sklearnex/neighbors/common.py +310 -0
- sklearnex/neighbors/knn_classification.py +226 -0
- sklearnex/neighbors/knn_regression.py +203 -0
- sklearnex/neighbors/knn_unsupervised.py +170 -0
- sklearnex/neighbors/tests/test_neighbors.py +80 -0
- sklearnex/preview/__init__.py +17 -0
- sklearnex/preview/covariance/__init__.py +19 -0
- sklearnex/preview/covariance/covariance.py +133 -0
- sklearnex/preview/covariance/tests/test_covariance.py +66 -0
- sklearnex/preview/decomposition/__init__.py +19 -0
- sklearnex/preview/decomposition/incremental_pca.py +228 -0
- sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
- sklearnex/preview/linear_model/__init__.py +19 -0
- sklearnex/preview/linear_model/ridge.py +419 -0
- sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
- sklearnex/spmd/__init__.py +25 -0
- sklearnex/spmd/basic_statistics/__init__.py +20 -0
- sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
- sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
- sklearnex/spmd/cluster/__init__.py +30 -0
- sklearnex/spmd/cluster/dbscan.py +50 -0
- sklearnex/spmd/cluster/kmeans.py +21 -0
- sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
- sklearnex/spmd/covariance/__init__.py +20 -0
- sklearnex/spmd/covariance/covariance.py +21 -0
- sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- sklearnex/spmd/decomposition/__init__.py +20 -0
- sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- sklearnex/spmd/decomposition/pca.py +21 -0
- sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- sklearnex/spmd/ensemble/__init__.py +19 -0
- sklearnex/spmd/ensemble/forest.py +71 -0
- sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- sklearnex/spmd/linear_model/__init__.py +21 -0
- sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- sklearnex/spmd/linear_model/linear_model.py +21 -0
- sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
- sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +166 -0
- sklearnex/spmd/neighbors/__init__.py +19 -0
- sklearnex/spmd/neighbors/neighbors.py +25 -0
- sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- sklearnex/svm/__init__.py +29 -0
- sklearnex/svm/_common.py +328 -0
- sklearnex/svm/nusvc.py +332 -0
- sklearnex/svm/nusvr.py +148 -0
- sklearnex/svm/svc.py +360 -0
- sklearnex/svm/svr.py +149 -0
- sklearnex/svm/tests/test_svm.py +93 -0
- sklearnex/tests/_utils.py +328 -0
- sklearnex/tests/_utils_spmd.py +198 -0
- sklearnex/tests/test_common.py +54 -0
- sklearnex/tests/test_config.py +43 -0
- sklearnex/tests/test_memory_usage.py +291 -0
- sklearnex/tests/test_monkeypatch.py +276 -0
- sklearnex/tests/test_n_jobs_support.py +103 -0
- sklearnex/tests/test_parallel.py +48 -0
- sklearnex/tests/test_patching.py +385 -0
- sklearnex/tests/test_run_to_run_stability.py +296 -0
- sklearnex/utils/__init__.py +19 -0
- sklearnex/utils/_array_api.py +82 -0
- sklearnex/utils/parallel.py +59 -0
- sklearnex/utils/tests/test_finite.py +89 -0
- sklearnex/utils/validation.py +17 -0
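
Most of the listing above is the sklearnex patching layer (sklearnex/dispatcher.py, sklearnex/glob/dispatcher.py, sklearnex/tests/test_monkeypatch.py). For orientation, a minimal sketch of how the package is typically used through its documented patch_sklearn()/unpatch_sklearn() entry points; the estimator choice and data below are illustrative only:

import numpy as np

from sklearnex import patch_sklearn, unpatch_sklearn

patch_sklearn()  # swap supported scikit-learn estimators for oneDAL-backed ones

# Imports performed after patching resolve to the accelerated implementations.
from sklearn.cluster import DBSCAN

X = np.random.rand(200, 2)
DBSCAN(eps=0.3, min_samples=10).fit(X)

unpatch_sklearn()  # restore the stock scikit-learn classes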
daal4py/sklearn/_utils.py
@@ -0,0 +1,241 @@
+# ==============================================================================
+# Copyright 2014 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import functools
+import os
+import sys
+import warnings
+from typing import Any, Callable, Tuple
+
+import numpy as np
+from numpy.lib.recfunctions import require_fields
+from sklearn import __version__ as sklearn_version
+
+from daal4py import _get__daal_link_version__ as dv
+
+DaalVersionTuple = Tuple[int, str, int]
+
+import logging
+
+try:
+    from packaging.version import Version
+except ImportError:
+    from distutils.version import LooseVersion as Version
+
+try:
+    from pandas import DataFrame
+    from pandas.core.dtypes.cast import find_common_type
+
+    pandas_is_imported = True
+except (ImportError, ModuleNotFoundError):
+    pandas_is_imported = False
+
+
+def set_idp_sklearn_verbose():
+    logLevel = os.environ.get("IDP_SKLEARN_VERBOSE")
+    try:
+        if logLevel is not None:
+            logging.basicConfig(
+                stream=sys.stdout,
+                format="%(levelname)s: %(message)s",
+                level=logLevel.upper(),
+            )
+    except Exception:
+        warnings.warn(
+            'Unknown level "{}" for logging.\n'
+            'Please, use one of "CRITICAL", "ERROR", '
+            '"WARNING", "INFO", "DEBUG".'.format(logLevel)
+        )
+
+
+def get_daal_version() -> DaalVersionTuple:
+    return int(dv()[0:4]), str(dv()[10:11]), int(dv()[4:8])
+
+
+@functools.lru_cache(maxsize=256, typed=False)
+def daal_check_version(
+    required_version: Tuple[Any, ...],
+    daal_version: Tuple[Any, ...] = get_daal_version(),
+) -> bool:
+    """Check daal version provided as (MAJOR, STATUS, MINOR+PATCH)
+
+    This function also accepts a list or tuple of daal versions. It will return true if
+    any version in the list/tuple is <= `daal_version`.
+    """
+    if isinstance(required_version[0], (list, tuple)):
+        # a list of version candidates was provided, recursively check if any is <= daal_version
+        return any(
+            map(lambda ver: daal_check_version(ver, daal_version), required_version)
+        )
+
+    major_required, status_required, patch_required = required_version
+    major, status, patch = daal_version
+
+    if status != status_required:
+        return False
+
+    if major_required < major:
+        return True
+    if major == major_required:
+        return patch_required <= patch
+
+    return False
+
+
+@functools.lru_cache(maxsize=256, typed=False)
+def sklearn_check_version(ver):
+    if hasattr(Version(ver), "base_version"):
+        base_sklearn_version = Version(sklearn_version).base_version
+        res = bool(Version(base_sklearn_version) >= Version(ver))
+    else:
+        # packaging module not available
+        res = bool(Version(sklearn_version) >= Version(ver))
+    return res
+
+
+def parse_dtype(dt):
+    if dt == np.double:
+        return "double"
+    if dt == np.single:
+        return "float"
+    raise ValueError(f"Input array has unexpected dtype = {dt}")
+
+
+def getFPType(X):
+    if pandas_is_imported:
+        if isinstance(X, DataFrame):
+            dt = find_common_type(X.dtypes.tolist())
+            return parse_dtype(dt)
+
+    dt = getattr(X, "dtype", None)
+    return parse_dtype(dt)
+
+
+def make2d(X):
+    if np.isscalar(X):
+        X = np.asarray(X)[np.newaxis, np.newaxis]
+    elif isinstance(X, np.ndarray) and X.ndim == 1:
+        X = X.reshape((X.size, 1))
+    return X
+
+
+def get_patch_message(s):
+    if s == "daal":
+        message = "running accelerated version on CPU"
+
+    elif s == "sklearn":
+        message = "fallback to original Scikit-learn"
+    elif s == "sklearn_after_daal":
+        message = "failed to run accelerated version, fallback to original Scikit-learn"
+    else:
+        raise ValueError(
+            f"Invalid input - expected one of 'daal','sklearn',"
+            f" 'sklearn_after_daal', got {s}"
+        )
+    return message
+
+
+def is_DataFrame(X):
+    if pandas_is_imported:
+        return isinstance(X, DataFrame)
+    else:
+        return False
+
+
+def get_dtype(X):
+    if pandas_is_imported:
+        return find_common_type(list(X.dtypes)) if is_DataFrame(X) else X.dtype
+    else:
+        return getattr(X, "dtype", None)
+
+
+def get_number_of_types(dataframe):
+    dtypes = getattr(dataframe, "dtypes", None)
+    try:
+        return len(set(dtypes))
+    except TypeError:
+        return 1
+
+
+def check_tree_nodes(tree_nodes):
+    def convert_to_old_tree_nodes(tree_nodes):
+        # conversion from sklearn>=1.3 tree nodes format to previous format:
+        # removal of 'missing_go_to_left' field from node dtype
+        new_field = "missing_go_to_left"
+        new_dtype = tree_nodes.dtype
+        old_dtype = np.dtype(
+            [
+                (key, value[0])
+                for key, value in new_dtype.fields.items()
+                if key != new_field
+            ]
+        )
+        return require_fields(tree_nodes, old_dtype)
+
+    if sklearn_check_version("1.3"):
+        return tree_nodes
+    else:
+        return convert_to_old_tree_nodes(tree_nodes)
+
+
+class PatchingConditionsChain:
+    def __init__(self, scope_name):
+        self.scope_name = scope_name
+        self.patching_is_enabled = True
+        self.messages = []
+        self.logger = logging.getLogger("sklearnex")
+
+    def _iter_conditions(self, conditions_and_messages):
+        result = []
+        for condition, message in conditions_and_messages:
+            result.append(condition)
+            if not condition:
+                self.messages.append(message)
+        return result
+
+    def and_conditions(self, conditions_and_messages, conditions_merging=all):
+        self.patching_is_enabled &= conditions_merging(
+            self._iter_conditions(conditions_and_messages)
+        )
+        return self.patching_is_enabled
+
+    def and_condition(self, condition, message):
+        return self.and_conditions([(condition, message)])
+
+    def or_conditions(self, conditions_and_messages, conditions_merging=all):
+        self.patching_is_enabled |= conditions_merging(
+            self._iter_conditions(conditions_and_messages)
+        )
+        return self.patching_is_enabled
+
+    def write_log(self):
+        if self.patching_is_enabled:
+            self.logger.info(f"{self.scope_name}: {get_patch_message('daal')}")
+        else:
+            self.logger.debug(
+                f"{self.scope_name}: debugging for the patch is enabled to track"
+                " the usage of Intel® oneAPI Data Analytics Library (oneDAL)"
+            )
+            for message in self.messages:
+                self.logger.debug(
+                    f"{self.scope_name}: patching failed with cause - {message}"
+                )
+            self.logger.info(f"{self.scope_name}: {get_patch_message('sklearn')}")
+
+    def get_status(self, logs=False):
+        if logs:
+            self.write_log()
+        return self.patching_is_enabled
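
The _utils.py hunk above provides the version checks (daal_check_version, sklearn_check_version) and the PatchingConditionsChain that the patched estimators use to decide between the oneDAL path and stock scikit-learn. A minimal sketch of the intended call pattern, based only on the code in this hunk; the scope name and the two conditions are illustrative:

import numpy as np

from daal4py.sklearn._utils import PatchingConditionsChain, sklearn_check_version

X = np.random.rand(100, 3)

# Each failing condition records its message; and_conditions() returns the
# combined status, and write_log() reports it under the "sklearnex" logger.
patching_status = PatchingConditionsChain("sklearn.cluster.ExampleEstimator.fit")
dal_ready = patching_status.and_conditions(
    [
        (X.ndim == 2, "Input must be two-dimensional."),
        (sklearn_check_version("1.0"), "scikit-learn >= 1.0 is required."),
    ]
)
patching_status.write_log()

if dal_ready:
    ...  # run the oneDAL-backed implementation
else:
    ...  # fall back to stock scikit-learn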
daal4py/sklearn/cluster/__init__.py
@@ -0,0 +1,20 @@
+# ==============================================================================
+# Copyright 2014 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from .dbscan import DBSCAN
+from .k_means import KMeans
+
+__all__ = ["KMeans", "DBSCAN"]
daal4py/sklearn/cluster/dbscan.py
@@ -0,0 +1,165 @@
+# ==============================================================================
+# Copyright 2014 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numbers
+
+import numpy as np
+from scipy import sparse as sp
+from sklearn.cluster import DBSCAN as DBSCAN_original
+from sklearn.utils import check_array
+from sklearn.utils.validation import _check_sample_weight
+
+import daal4py
+
+from .._n_jobs_support import control_n_jobs
+from .._utils import PatchingConditionsChain, getFPType, make2d, sklearn_check_version
+
+if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
+    from sklearn.utils import check_scalar
+
+
+def _daal_dbscan(X, eps=0.5, min_samples=5, sample_weight=None):
+    ww = make2d(sample_weight) if sample_weight is not None else None
+    XX = make2d(X)
+
+    fpt = getFPType(XX)
+    alg = daal4py.dbscan(
+        method="defaultDense",
+        fptype=fpt,
+        epsilon=float(eps),
+        minObservations=int(min_samples),
+        memorySavingMode=False,
+        resultsToCompute="computeCoreIndices",
+    )
+
+    daal_res = alg.compute(XX, ww)
+    assignments = daal_res.assignments.ravel()
+    if daal_res.coreIndices is not None:
+        core_ind = daal_res.coreIndices.ravel()
+    else:
+        core_ind = np.array([], dtype=np.intc)
+
+    return (core_ind, assignments)
+
+
+@control_n_jobs(decorated_methods=["fit"])
+class DBSCAN(DBSCAN_original):
+    __doc__ = DBSCAN_original.__doc__
+
+    if sklearn_check_version("1.2"):
+        _parameter_constraints: dict = {**DBSCAN_original._parameter_constraints}
+
+    def __init__(
+        self,
+        eps=0.5,
+        min_samples=5,
+        metric="euclidean",
+        metric_params=None,
+        algorithm="auto",
+        leaf_size=30,
+        p=None,
+        n_jobs=None,
+    ):
+        self.eps = eps
+        self.min_samples = min_samples
+        self.metric = metric
+        self.metric_params = metric_params
+        self.algorithm = algorithm
+        self.leaf_size = leaf_size
+        self.p = p
+        self.n_jobs = n_jobs
+
+    def fit(self, X, y=None, sample_weight=None):
+        if sklearn_check_version("1.2"):
+            self._validate_params()
+        elif sklearn_check_version("1.1"):
+            check_scalar(
+                self.eps,
+                "eps",
+                target_type=numbers.Real,
+                min_val=0.0,
+                include_boundaries="neither",
+            )
+            check_scalar(
+                self.min_samples,
+                "min_samples",
+                target_type=numbers.Integral,
+                min_val=1,
+                include_boundaries="left",
+            )
+            check_scalar(
+                self.leaf_size,
+                "leaf_size",
+                target_type=numbers.Integral,
+                min_val=1,
+                include_boundaries="left",
+            )
+            if self.p is not None:
+                check_scalar(
+                    self.p,
+                    "p",
+                    target_type=numbers.Real,
+                    min_val=0.0,
+                    include_boundaries="left",
+                )
+            if self.n_jobs is not None:
+                check_scalar(self.n_jobs, "n_jobs", target_type=numbers.Integral)
+        else:
+            if self.eps <= 0.0:
+                raise ValueError(f"eps == {self.eps}, must be > 0.0.")
+
+        if sklearn_check_version("1.0"):
+            self._check_feature_names(X, reset=True)
+
+        if sample_weight is not None:
+            sample_weight = _check_sample_weight(sample_weight, X)
+
+        _patching_status = PatchingConditionsChain("sklearn.cluster.DBSCAN.fit")
+        _dal_ready = _patching_status.and_conditions(
+            [
+                (
+                    self.algorithm in ["auto", "brute"],
+                    f"'{self.algorithm}' algorithm is not supported. "
+                    "Only 'auto' and 'brute' algorithms are supported",
+                ),
+                (
+                    self.metric == "euclidean"
+                    or (self.metric == "minkowski" and self.p == 2),
+                    f"'{self.metric}' (p={self.p}) metric is not supported. "
+                    "Only 'euclidean' or 'minkowski' with p=2 metrics are supported.",
+                ),
+                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
+            ]
+        )
+
+        _patching_status.write_log()
+        if _dal_ready:
+            X = check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
+            core_ind, assignments = _daal_dbscan(
+                X, self.eps, self.min_samples, sample_weight=sample_weight
+            )
+            self.core_sample_indices_ = core_ind
+            self.labels_ = assignments
+            self.components_ = np.take(X, core_ind, axis=0)
+            self.n_features_in_ = X.shape[1]
+            return self
+        return super().fit(X, y, sample_weight=sample_weight)
+
+    def fit_predict(self, X, y=None, sample_weight=None):
+        return super().fit_predict(X, y, sample_weight)
+
+    fit.__doc__ = DBSCAN_original.fit.__doc__
+    fit_predict.__doc__ = DBSCAN_original.fit_predict.__doc__
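
Taken together with the cluster/__init__.py hunk, the class above is a drop-in replacement for sklearn.cluster.DBSCAN that routes fit() through daal4py when the algorithm and metric conditions hold and the input is dense. A minimal usage sketch with synthetic data; the eps/min_samples values are arbitrary:

import numpy as np

from daal4py.sklearn.cluster import DBSCAN  # re-exported by the __init__.py hunk above

rng = np.random.RandomState(0)
X = np.vstack([rng.normal(0.0, 0.2, size=(50, 2)), rng.normal(3.0, 0.2, size=(50, 2))])

# metric="euclidean" and algorithm="auto" satisfy the PatchingConditionsChain
# in fit(), so this call should take the daal4py path; otherwise it falls back
# to the stock scikit-learn implementation.
model = DBSCAN(eps=0.5, min_samples=5).fit(X)
print(np.unique(model.labels_))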