scikit-learn-intelex 2025.4.0__py313-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- daal4py/__init__.py +73 -0
- daal4py/__main__.py +58 -0
- daal4py/_daal4py.cpython-313-x86_64-linux-gnu.so +0 -0
- daal4py/doc/third-party-programs.txt +424 -0
- daal4py/mb/__init__.py +19 -0
- daal4py/mb/model_builders.py +377 -0
- daal4py/mpi_transceiver.cpython-313-x86_64-linux-gnu.so +0 -0
- daal4py/sklearn/__init__.py +40 -0
- daal4py/sklearn/_n_jobs_support.py +248 -0
- daal4py/sklearn/_utils.py +245 -0
- daal4py/sklearn/cluster/__init__.py +20 -0
- daal4py/sklearn/cluster/dbscan.py +165 -0
- daal4py/sklearn/cluster/k_means.py +597 -0
- daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
- daal4py/sklearn/decomposition/__init__.py +19 -0
- daal4py/sklearn/decomposition/_pca.py +524 -0
- daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
- daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
- daal4py/sklearn/ensemble/__init__.py +27 -0
- daal4py/sklearn/ensemble/_forest.py +1397 -0
- daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
- daal4py/sklearn/linear_model/__init__.py +29 -0
- daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
- daal4py/sklearn/linear_model/_linear.py +272 -0
- daal4py/sklearn/linear_model/_ridge.py +325 -0
- daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
- daal4py/sklearn/linear_model/linear.py +17 -0
- daal4py/sklearn/linear_model/logistic_loss.py +195 -0
- daal4py/sklearn/linear_model/logistic_path.py +1026 -0
- daal4py/sklearn/linear_model/ridge.py +17 -0
- daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
- daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
- daal4py/sklearn/manifold/__init__.py +19 -0
- daal4py/sklearn/manifold/_t_sne.py +405 -0
- daal4py/sklearn/metrics/__init__.py +20 -0
- daal4py/sklearn/metrics/_pairwise.py +236 -0
- daal4py/sklearn/metrics/_ranking.py +210 -0
- daal4py/sklearn/model_selection/__init__.py +19 -0
- daal4py/sklearn/model_selection/_split.py +309 -0
- daal4py/sklearn/model_selection/tests/test_split.py +56 -0
- daal4py/sklearn/monkeypatch/__init__.py +0 -0
- daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
- daal4py/sklearn/monkeypatch/tests/_models_info.py +161 -0
- daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
- daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
- daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
- daal4py/sklearn/neighbors/__init__.py +21 -0
- daal4py/sklearn/neighbors/_base.py +503 -0
- daal4py/sklearn/neighbors/_classification.py +139 -0
- daal4py/sklearn/neighbors/_regression.py +74 -0
- daal4py/sklearn/neighbors/_unsupervised.py +55 -0
- daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
- daal4py/sklearn/svm/__init__.py +19 -0
- daal4py/sklearn/svm/svm.py +734 -0
- daal4py/sklearn/utils/__init__.py +21 -0
- daal4py/sklearn/utils/base.py +75 -0
- daal4py/sklearn/utils/tests/test_utils.py +51 -0
- daal4py/sklearn/utils/validation.py +696 -0
- onedal/__init__.py +83 -0
- onedal/_config.py +54 -0
- onedal/_device_offload.py +204 -0
- onedal/_onedal_py_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_host.cpython-313-x86_64-linux-gnu.so +0 -0
- onedal/_onedal_py_spmd_dpc.cpython-313-x86_64-linux-gnu.so +0 -0
- onedal/basic_statistics/__init__.py +20 -0
- onedal/basic_statistics/basic_statistics.py +107 -0
- onedal/basic_statistics/incremental_basic_statistics.py +175 -0
- onedal/basic_statistics/tests/test_basic_statistics.py +242 -0
- onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
- onedal/basic_statistics/tests/utils.py +50 -0
- onedal/cluster/__init__.py +27 -0
- onedal/cluster/dbscan.py +105 -0
- onedal/cluster/kmeans.py +557 -0
- onedal/cluster/kmeans_init.py +112 -0
- onedal/cluster/tests/test_dbscan.py +125 -0
- onedal/cluster/tests/test_kmeans.py +88 -0
- onedal/cluster/tests/test_kmeans_init.py +93 -0
- onedal/common/_base.py +38 -0
- onedal/common/_estimator_checks.py +47 -0
- onedal/common/_mixin.py +62 -0
- onedal/common/_policy.py +55 -0
- onedal/common/_spmd_policy.py +30 -0
- onedal/common/hyperparameters.py +125 -0
- onedal/common/tests/test_policy.py +76 -0
- onedal/common/tests/test_sycl.py +128 -0
- onedal/covariance/__init__.py +20 -0
- onedal/covariance/covariance.py +122 -0
- onedal/covariance/incremental_covariance.py +161 -0
- onedal/covariance/tests/test_covariance.py +50 -0
- onedal/covariance/tests/test_incremental_covariance.py +190 -0
- onedal/datatypes/__init__.py +19 -0
- onedal/datatypes/_data_conversion.py +121 -0
- onedal/datatypes/tests/common.py +126 -0
- onedal/datatypes/tests/test_data.py +475 -0
- onedal/decomposition/__init__.py +20 -0
- onedal/decomposition/incremental_pca.py +214 -0
- onedal/decomposition/pca.py +186 -0
- onedal/decomposition/tests/test_incremental_pca.py +285 -0
- onedal/ensemble/__init__.py +29 -0
- onedal/ensemble/forest.py +736 -0
- onedal/ensemble/tests/test_random_forest.py +97 -0
- onedal/linear_model/__init__.py +27 -0
- onedal/linear_model/incremental_linear_model.py +292 -0
- onedal/linear_model/linear_model.py +325 -0
- onedal/linear_model/logistic_regression.py +247 -0
- onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
- onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
- onedal/linear_model/tests/test_linear_regression.py +259 -0
- onedal/linear_model/tests/test_logistic_regression.py +95 -0
- onedal/linear_model/tests/test_ridge.py +95 -0
- onedal/neighbors/__init__.py +19 -0
- onedal/neighbors/neighbors.py +763 -0
- onedal/neighbors/tests/test_knn_classification.py +49 -0
- onedal/primitives/__init__.py +27 -0
- onedal/primitives/get_tree.py +25 -0
- onedal/primitives/kernel_functions.py +152 -0
- onedal/primitives/tests/test_kernel_functions.py +159 -0
- onedal/spmd/__init__.py +25 -0
- onedal/spmd/_base.py +30 -0
- onedal/spmd/basic_statistics/__init__.py +20 -0
- onedal/spmd/basic_statistics/basic_statistics.py +30 -0
- onedal/spmd/basic_statistics/incremental_basic_statistics.py +71 -0
- onedal/spmd/cluster/__init__.py +28 -0
- onedal/spmd/cluster/dbscan.py +23 -0
- onedal/spmd/cluster/kmeans.py +56 -0
- onedal/spmd/covariance/__init__.py +20 -0
- onedal/spmd/covariance/covariance.py +26 -0
- onedal/spmd/covariance/incremental_covariance.py +83 -0
- onedal/spmd/decomposition/__init__.py +20 -0
- onedal/spmd/decomposition/incremental_pca.py +124 -0
- onedal/spmd/decomposition/pca.py +26 -0
- onedal/spmd/ensemble/__init__.py +19 -0
- onedal/spmd/ensemble/forest.py +28 -0
- onedal/spmd/linear_model/__init__.py +21 -0
- onedal/spmd/linear_model/incremental_linear_model.py +101 -0
- onedal/spmd/linear_model/linear_model.py +30 -0
- onedal/spmd/linear_model/logistic_regression.py +38 -0
- onedal/spmd/neighbors/__init__.py +19 -0
- onedal/spmd/neighbors/neighbors.py +75 -0
- onedal/svm/__init__.py +19 -0
- onedal/svm/svm.py +556 -0
- onedal/svm/tests/test_csr_svm.py +351 -0
- onedal/svm/tests/test_nusvc.py +204 -0
- onedal/svm/tests/test_nusvr.py +210 -0
- onedal/svm/tests/test_svc.py +176 -0
- onedal/svm/tests/test_svr.py +243 -0
- onedal/tests/test_common.py +57 -0
- onedal/tests/utils/_dataframes_support.py +162 -0
- onedal/tests/utils/_device_selection.py +102 -0
- onedal/utils/__init__.py +49 -0
- onedal/utils/_array_api.py +81 -0
- onedal/utils/_dpep_helpers.py +56 -0
- onedal/utils/tests/test_validation.py +142 -0
- onedal/utils/validation.py +464 -0
- scikit_learn_intelex-2025.4.0.dist-info/LICENSE.txt +202 -0
- scikit_learn_intelex-2025.4.0.dist-info/METADATA +190 -0
- scikit_learn_intelex-2025.4.0.dist-info/RECORD +282 -0
- scikit_learn_intelex-2025.4.0.dist-info/WHEEL +5 -0
- scikit_learn_intelex-2025.4.0.dist-info/top_level.txt +3 -0
- sklearnex/__init__.py +66 -0
- sklearnex/__main__.py +58 -0
- sklearnex/_config.py +116 -0
- sklearnex/_device_offload.py +126 -0
- sklearnex/_utils.py +177 -0
- sklearnex/basic_statistics/__init__.py +20 -0
- sklearnex/basic_statistics/basic_statistics.py +261 -0
- sklearnex/basic_statistics/incremental_basic_statistics.py +352 -0
- sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
- sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
- sklearnex/cluster/__init__.py +20 -0
- sklearnex/cluster/dbscan.py +197 -0
- sklearnex/cluster/k_means.py +397 -0
- sklearnex/cluster/tests/test_dbscan.py +38 -0
- sklearnex/cluster/tests/test_kmeans.py +157 -0
- sklearnex/conftest.py +82 -0
- sklearnex/covariance/__init__.py +19 -0
- sklearnex/covariance/incremental_covariance.py +405 -0
- sklearnex/covariance/tests/test_incremental_covariance.py +287 -0
- sklearnex/decomposition/__init__.py +19 -0
- sklearnex/decomposition/pca.py +427 -0
- sklearnex/decomposition/tests/test_pca.py +58 -0
- sklearnex/dispatcher.py +534 -0
- sklearnex/doc/third-party-programs.txt +424 -0
- sklearnex/ensemble/__init__.py +29 -0
- sklearnex/ensemble/_forest.py +2029 -0
- sklearnex/ensemble/tests/test_forest.py +140 -0
- sklearnex/glob/__main__.py +72 -0
- sklearnex/glob/dispatcher.py +101 -0
- sklearnex/linear_model/__init__.py +32 -0
- sklearnex/linear_model/coordinate_descent.py +30 -0
- sklearnex/linear_model/incremental_linear.py +495 -0
- sklearnex/linear_model/incremental_ridge.py +432 -0
- sklearnex/linear_model/linear.py +346 -0
- sklearnex/linear_model/logistic_regression.py +415 -0
- sklearnex/linear_model/ridge.py +390 -0
- sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
- sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
- sklearnex/linear_model/tests/test_linear.py +142 -0
- sklearnex/linear_model/tests/test_logreg.py +134 -0
- sklearnex/linear_model/tests/test_ridge.py +256 -0
- sklearnex/manifold/__init__.py +19 -0
- sklearnex/manifold/t_sne.py +26 -0
- sklearnex/manifold/tests/test_tsne.py +250 -0
- sklearnex/metrics/__init__.py +23 -0
- sklearnex/metrics/pairwise.py +22 -0
- sklearnex/metrics/ranking.py +20 -0
- sklearnex/metrics/tests/test_metrics.py +39 -0
- sklearnex/model_selection/__init__.py +21 -0
- sklearnex/model_selection/split.py +22 -0
- sklearnex/model_selection/tests/test_model_selection.py +34 -0
- sklearnex/neighbors/__init__.py +27 -0
- sklearnex/neighbors/_lof.py +236 -0
- sklearnex/neighbors/common.py +310 -0
- sklearnex/neighbors/knn_classification.py +231 -0
- sklearnex/neighbors/knn_regression.py +207 -0
- sklearnex/neighbors/knn_unsupervised.py +178 -0
- sklearnex/neighbors/tests/test_neighbors.py +82 -0
- sklearnex/preview/__init__.py +17 -0
- sklearnex/preview/covariance/__init__.py +19 -0
- sklearnex/preview/covariance/covariance.py +142 -0
- sklearnex/preview/covariance/tests/test_covariance.py +66 -0
- sklearnex/preview/decomposition/__init__.py +19 -0
- sklearnex/preview/decomposition/incremental_pca.py +244 -0
- sklearnex/preview/decomposition/tests/test_incremental_pca.py +336 -0
- sklearnex/spmd/__init__.py +25 -0
- sklearnex/spmd/basic_statistics/__init__.py +20 -0
- sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
- sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
- sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
- sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +306 -0
- sklearnex/spmd/cluster/__init__.py +30 -0
- sklearnex/spmd/cluster/dbscan.py +50 -0
- sklearnex/spmd/cluster/kmeans.py +21 -0
- sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
- sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +173 -0
- sklearnex/spmd/covariance/__init__.py +20 -0
- sklearnex/spmd/covariance/covariance.py +21 -0
- sklearnex/spmd/covariance/incremental_covariance.py +37 -0
- sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
- sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
- sklearnex/spmd/decomposition/__init__.py +20 -0
- sklearnex/spmd/decomposition/incremental_pca.py +30 -0
- sklearnex/spmd/decomposition/pca.py +21 -0
- sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
- sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
- sklearnex/spmd/ensemble/__init__.py +19 -0
- sklearnex/spmd/ensemble/forest.py +71 -0
- sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
- sklearnex/spmd/linear_model/__init__.py +21 -0
- sklearnex/spmd/linear_model/incremental_linear_model.py +35 -0
- sklearnex/spmd/linear_model/linear_model.py +21 -0
- sklearnex/spmd/linear_model/logistic_regression.py +21 -0
- sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +331 -0
- sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
- sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
- sklearnex/spmd/neighbors/__init__.py +19 -0
- sklearnex/spmd/neighbors/neighbors.py +25 -0
- sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
- sklearnex/svm/__init__.py +29 -0
- sklearnex/svm/_common.py +339 -0
- sklearnex/svm/nusvc.py +371 -0
- sklearnex/svm/nusvr.py +170 -0
- sklearnex/svm/svc.py +399 -0
- sklearnex/svm/svr.py +167 -0
- sklearnex/svm/tests/test_svm.py +93 -0
- sklearnex/tests/test_common.py +491 -0
- sklearnex/tests/test_config.py +123 -0
- sklearnex/tests/test_hyperparameters.py +43 -0
- sklearnex/tests/test_memory_usage.py +347 -0
- sklearnex/tests/test_monkeypatch.py +269 -0
- sklearnex/tests/test_n_jobs_support.py +108 -0
- sklearnex/tests/test_parallel.py +48 -0
- sklearnex/tests/test_patching.py +377 -0
- sklearnex/tests/test_run_to_run_stability.py +326 -0
- sklearnex/tests/utils/__init__.py +48 -0
- sklearnex/tests/utils/base.py +436 -0
- sklearnex/tests/utils/spmd.py +198 -0
- sklearnex/utils/__init__.py +19 -0
- sklearnex/utils/_array_api.py +82 -0
- sklearnex/utils/parallel.py +59 -0
- sklearnex/utils/tests/test_validation.py +238 -0
- sklearnex/utils/validation.py +208 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Any, Dict, Tuple
|
|
19
|
+
from warnings import warn
|
|
20
|
+
|
|
21
|
+
from daal4py.sklearn._utils import daal_check_version
|
|
22
|
+
from onedal import _backend
|
|
23
|
+
|
|
24
|
+
if not daal_check_version((2024, "P", 0)):
    warn("Hyperparameters are supported in oneDAL starting from 2024.0.0 version.")
    hyperparameters_map = {}
else:
    # Names owned by the wrapper itself: reads/writes of these go straight to
    # the instance instead of being forwarded to the backend getters/setters.
    _hparams_reserved_words = [
        "algorithm",
        "op",
        "setters",
        "getters",
        "backend",
        "is_default",
        "to_dict",
    ]

    class HyperParameters:
        """Class for simplified interaction with oneDAL hyperparameters.

        Overrides `__getattribute__` and `__setattr__` to utilize getters and setters
        of hyperparameter class from onedal backend.

        Parameters
        ----------
        algorithm : str
            Algorithm name (first element of the ``hyperparameters_map`` key).
        op : str
            Operation name (second element of the ``hyperparameters_map`` key).
        setters : dict
            Hyperparameter name -> backend ``set_<name>`` bound method.
        getters : dict
            Hyperparameter name -> backend ``get_<name>`` bound method.
        backend : object
            Backend hyperparameters object the getters/setters are bound to.

        Attributes
        ----------
        is_default : bool
            True until a hyperparameter is assigned through this wrapper.
        """

        def __init__(self, algorithm, op, setters, getters, backend):
            self.algorithm = algorithm
            self.op = op
            self.setters = setters
            self.getters = getters
            self.backend = backend
            self.is_default = True

        def __getattribute__(self, __name):
            if __name in _hparams_reserved_words:
                if __name == "backend":
                    # `backend` attribute accessed only for oneDAL kernel calls
                    logger = logging.getLogger("sklearnex")
                    # to_dict() calls every backend getter; only pay for it
                    # when the debug record will actually be emitted.
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug(
                            "Using next hyperparameters for "
                            f"'{self.algorithm}.{self.op}': {self.to_dict()}"
                        )
                return super().__getattribute__(__name)
            elif __name in self.getters:
                return self.getters[__name]()
            try:
                # try to return attribute from base class
                # required to read builtin attributes like __class__, __doc__, etc.
                # which are used in debuggers
                return super().__getattribute__(__name)
            except AttributeError:
                # raise an AttributeError with a hyperparameter-specific message
                # for easier debugging
                raise AttributeError(
                    f"Unknown attribute '{__name}' in "
                    f"'{self.algorithm}.{self.op}' hyperparameters"
                )

        def __setattr__(self, __name, __value):
            if __name in _hparams_reserved_words:
                super().__setattr__(__name, __value)
            elif __name in self.setters:
                # Flag the change before forwarding the value to the backend.
                self.is_default = False
                self.setters[__name](__value)
            else:
                raise ValueError(
                    f"Unknown attribute '{__name}' in "
                    f"'{self.algorithm}.{self.op}' hyperparameters"
                )

        def to_dict(self):
            """Return ``{name: current value}`` by calling every backend getter."""
            return {name: getter() for name, getter in self.getters.items()}

    def get_methods_with_prefix(obj, prefix):
        """Map ``name`` -> bound method for each attribute of ``obj`` named
        ``<prefix><name>``.

        Only the leading prefix is stripped. ``str.replace`` would also delete
        later occurrences of the prefix inside the hyperparameter name
        (e.g. ``set_min_set_size`` -> ``min_size``).
        """
        return {
            method[len(prefix) :]: getattr(obj, method)
            for method in dir(obj)
            if method.startswith(prefix)
        }

    # Backend hyperparameter objects keyed by (algorithm, operation).
    hyperparameters_backend: Dict[Tuple[str, str], Any] = {
        (
            "linear_regression",
            "train",
        ): _backend.linear_model.regression.train_hyperparameters(),
        ("covariance", "compute"): _backend.covariance.compute_hyperparameters(),
    }
    if daal_check_version((2024, "P", 300)):
        df_infer_hp = _backend.decision_forest.infer_hyperparameters
        hyperparameters_backend[("decision_forest", "infer")] = df_infer_hp()
    hyperparameters_map = {}

    for (algorithm, op), hyperparameters in hyperparameters_backend.items():
        setters = get_methods_with_prefix(hyperparameters, "set_")
        getters = get_methods_with_prefix(hyperparameters, "get_")

        # Every hyperparameter must be both readable and writable.
        if set(setters.keys()) != set(getters.keys()):
            raise ValueError(
                f"Setters and getters in '{algorithm}.{op}' "
                "hyperparameters wrapper do not correspond."
            )

        hyperparameters_map[(algorithm, op)] = HyperParameters(
            algorithm, op, setters, getters, hyperparameters
        )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_hyperparameters(algorithm, op):
    """Look up the hyperparameters wrapper registered for ``(algorithm, op)``.

    Parameters
    ----------
    algorithm : str
        Algorithm name, e.g. "linear_regression".
    op : str
        Operation name, e.g. "train".

    Returns
    -------
    object or None
        The registered wrapper, or None when nothing is registered for the pair.
    """
    key = (algorithm, op)
    return hyperparameters_map.get(key)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2021 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
|
|
20
|
+
from onedal.common._policy import _get_policy
|
|
21
|
+
from onedal.tests.utils._device_selection import (
|
|
22
|
+
device_type_to_str,
|
|
23
|
+
get_memory_usm,
|
|
24
|
+
get_queues,
|
|
25
|
+
is_dpctl_device_available,
|
|
26
|
+
)
|
|
27
|
+
from onedal.utils._dpep_helpers import dpctl_available
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@pytest.mark.parametrize("queue", get_queues())
def test_queue_passed_directly(queue):
    """A policy built from an explicit queue reports that queue's device."""
    expected = device_type_to_str(queue)
    actual = _get_policy(queue).get_device_name()
    assert actual == expected
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@pytest.mark.parametrize("queue", get_queues())
def test_with_numpy_data(queue):
    """With host (numpy) inputs, the policy still reports the queue's device."""
    features, target = np.zeros((5, 3)), np.zeros(3)

    expected = device_type_to_str(queue)
    policy = _get_policy(queue, features, target)
    assert policy.get_device_name() == expected
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@pytest.mark.skipif(not dpctl_available, reason="depends on dpctl")
@pytest.mark.parametrize("queue", get_queues("cpu,gpu"))
@pytest.mark.parametrize("memtype", get_memory_usm())
def test_with_usm_ndarray_data(queue, memtype):
    """Without an explicit queue, the policy follows the device of USM inputs."""
    if queue is None:
        pytest.skip(
            "dpctl Memory object with queue=None uses cached default (gpu if available)"
        )

    from dpctl.tensor import usm_ndarray

    expected = device_type_to_str(queue)

    # Buffers are sized at 8 bytes per element.
    x_buffer = memtype(5 * 3 * 8, queue=queue)
    X = usm_ndarray((5, 3), buffer=x_buffer)
    y_buffer = memtype(3 * 8, queue=queue)
    y = usm_ndarray((3,), buffer=y_buffer)

    policy = _get_policy(None, X, y)
    assert policy.get_device_name() == expected
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.mark.skipif(
    not is_dpctl_device_available(["cpu", "gpu"]), reason="test uses multiple devices"
)
@pytest.mark.parametrize("memtype", get_memory_usm())
def test_queue_parameter_with_usm_ndarray(memtype):
    """An explicit queue wins over the queue the USM data was allocated on."""
    from dpctl import SyclQueue
    from dpctl.tensor import usm_ndarray

    cpu_queue = SyclQueue("cpu")
    gpu_queue = SyclQueue("gpu")

    # Data lives on the CPU queue, but the GPU queue is passed explicitly.
    X = usm_ndarray((5, 3), buffer=memtype(5 * 3 * 8, queue=cpu_queue))
    actual = _get_policy(gpu_queue, X).get_device_name()
    assert actual == device_type_to_str(gpu_queue)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# ==============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ==============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
|
|
20
|
+
from onedal import _backend, _is_dpc_backend
|
|
21
|
+
from onedal.tests.utils._device_selection import get_queues
|
|
22
|
+
from onedal.utils._dpep_helpers import dpctl_available
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.mark.skipif(
    not _is_dpc_backend or not dpctl_available, reason="requires dpc backend and dpctl"
)
@pytest.mark.parametrize("device_type", ["cpu", "gpu"])
@pytest.mark.parametrize("device_number", [None, 0, 1, 2, 3])
def test_sycl_queue_string_creation(device_type, device_number):
    """dpctl and the oneDAL backend accept/reject the same device filter
    strings, and accepted ones resolve to the same device."""
    # create devices from strings
    from dpctl import SyclQueue
    from dpctl._sycl_queue import SyclQueueCreationError

    onedal_SyclQueue = _backend.SyclQueue

    if device_number is None:
        device = device_type
    else:
        device = f"{device_type}:{device_number}"

    try:
        dpctl_string = SyclQueue(device).sycl_device.filter_string
    except SyclQueueCreationError:
        dpctl_failed = True
    else:
        dpctl_failed = False

    try:
        onedal_string = onedal_SyclQueue(device).sycl_device.filter_string
    except RuntimeError:
        backend_failed = True
    else:
        backend_failed = False

    assert dpctl_failed == backend_failed
    if backend_failed:
        return

    # dpctl filter string converts simple sycl filter_strings
    # i.e. "gpu:1" -> "opencl:gpu:0", use SyclQueue to convert
    # for matching, as oneDAL sycl queue only returns simple
    # strings as these are operationally sufficient
    assert SyclQueue(onedal_string).sycl_device.filter_string == dpctl_string
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@pytest.mark.skipif(
    not _is_dpc_backend or not dpctl_available, reason="requires dpc backend and dpctl"
)
@pytest.mark.parametrize("queue", get_queues())
def test_sycl_queue_conversion(queue):
    """Repeated dpctl <-> oneDAL queue conversion keeps the original device."""
    if queue is None:
        pytest.skip("Not a dpctl queue")
    DpctlQueue = type(queue)

    converted = _backend.SyclQueue(queue)

    # convert back and forth to test `_get_capsule` attribute
    for _ in range(10):
        round_trip = DpctlQueue(converted.sycl_device.filter_string)
        converted = _backend.SyclQueue(round_trip)

    assert converted.sycl_device.filter_string in queue.sycl_device.filter_string
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@pytest.mark.skipif(
    not _is_dpc_backend or not dpctl_available, reason="requires dpc backend and dpctl"
)
@pytest.mark.parametrize("queue", get_queues())
def test_sycl_device_attributes(queue):
    """Device attributes exposed by the oneDAL SyclQueue match dpctl's.

    The previously imported ``dpctl.SyclQueue`` was never used and has been
    removed; the skipif marker already guarantees dpctl is available.
    """
    if queue is None:
        pytest.skip("Not a dpctl queue")

    onedal_device = _backend.SyclQueue(queue).sycl_device
    dpctl_device = queue.sycl_device

    # check fp64 support
    assert onedal_device.has_aspect_fp64 == dpctl_device.has_aspect_fp64
    # check fp16 support
    assert onedal_device.has_aspect_fp16 == dpctl_device.has_aspect_fp16
    # check is_cpu
    assert onedal_device.is_cpu == dpctl_device.is_cpu
    # check is_gpu
    assert onedal_device.is_gpu == dpctl_device.is_gpu
    # check device number: oneDAL's filter string is contained in dpctl's
    assert onedal_device.filter_string in dpctl_device.filter_string
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@pytest.mark.skipif(not _is_dpc_backend, reason="requires dpc backend")
def test_backend_queue():
    """A oneDAL CPU queue and its copies all describe the same CPU device."""
    try:
        q = _backend.SyclQueue("cpu")
    except RuntimeError:
        pytest.skip("OpenCL CPU runtime not installed")

    copies = [
        q,
        # verify copying via a py capsule object is functional
        _backend.SyclQueue(q._get_capsule()),
        # verify copying via the _get_capsule attribute
        _backend.SyclQueue(q),
    ]

    for candidate in copies:
        device = candidate.sycl_device
        assert device.has_aspect_fp64
        assert device.has_aspect_fp16
        assert device.is_cpu
        assert not device.is_gpu
        assert "cpu" in device.filter_string
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
from .covariance import EmpiricalCovariance
|
|
18
|
+
from .incremental_covariance import IncrementalEmpiricalCovariance
|
|
19
|
+
|
|
20
|
+
# Public names of this package (also what ``from ... import *`` exposes).
__all__ = ["EmpiricalCovariance", "IncrementalEmpiricalCovariance"]
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
from abc import ABCMeta
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from daal4py.sklearn._utils import daal_check_version, get_dtype
|
|
21
|
+
from onedal.utils import _check_array
|
|
22
|
+
|
|
23
|
+
from ..common._base import BaseEstimator
|
|
24
|
+
from ..common.hyperparameters import get_hyperparameters
|
|
25
|
+
from ..datatypes import from_table, to_table
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BaseEmpiricalCovariance(BaseEstimator, metaclass=ABCMeta):
    """Shared parameter handling for oneDAL covariance estimators.

    Parameters
    ----------
    method : str, default="dense"
        Computation method name forwarded to oneDAL.
    bias : bool, default=False
        Forwarded as the ``bias`` parameter when the installed oneDAL
        version supports it.
    assume_centered : bool, default=False
        Forwarded as ``assumeCentered`` when the installed oneDAL
        version supports it.
    """

    def __init__(self, method="dense", bias=False, assume_centered=False):
        self.method = method
        self.bias = bias
        self.assume_centered = assume_centered

    def _get_onedal_params(self, dtype=np.float32):
        """Build the parameter dict passed to the oneDAL backend.

        ``fptype`` and ``method`` are always present. The remaining entries
        are added only when ``daal_check_version`` accepts the version listed
        next to them, and each attribute is read only in that case.
        """
        onedal_params = dict(fptype=dtype, method=self.method)
        version_gated = (
            ((2024, "P", 1), "bias", "bias"),
            ((2024, "P", 400), "assumeCentered", "assume_centered"),
        )
        for required_version, param_name, attr_name in version_gated:
            if daal_check_version(required_version):
                onedal_params[param_name] = getattr(self, attr_name)
        return onedal_params
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class EmpiricalCovariance(BaseEmpiricalCovariance):
    """Covariance estimator.

    Computes the sample covariance matrix and per-feature mean of the data.

    Parameters
    ----------
    method : string, default="dense"
        Specifies computation method. Available methods: "dense".

    bias : bool, default=False
        If True, the biased estimate of the covariance is computed, which
        equals the unbiased one multiplied by (n_samples - 1) / n_samples.

    assume_centered : bool, default=False
        If True, data are not centered before computation.
        Useful when working with data whose mean is almost, but not exactly
        zero.
        If False (default), data are centered before computation.

    Attributes
    ----------
    location_ : ndarray of shape (n_features,)
        Estimated location, i.e., the estimated mean.

    covariance_ : ndarray of shape (n_features, n_features)
        Estimated covariance matrix.
    """

    def fit(self, X, y=None, queue=None):
        """Fit the sample covariance matrix of X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples, and
            `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        queue : dpctl.SyclQueue
            If not None, use this queue for computations.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        policy = self._get_policy(queue, X)
        X = _check_array(X, dtype=[np.float64, np.float32])
        X = to_table(X, queue=queue)
        params = self._get_onedal_params(X.dtype)

        # Positional arguments for the backend call; non-default
        # hyperparameters, when present, go right before the data table.
        backend_args = ["covariance", None, "compute", policy, params]
        hparams = get_hyperparameters("covariance", "compute")
        if hparams is not None and not hparams.is_default:
            backend_args.append(hparams.backend)
        backend_args.append(X)
        result = self._get_backend(*backend_args)

        covariance = from_table(result.cov_matrix)
        if self.bias and not daal_check_version((2024, "P", 1)):
            # The "bias" option is not passed to the backend on older oneDAL
            # versions, so the unbiased result is rescaled here instead.
            n_samples = X.shape[0]
            covariance = covariance * (n_samples - 1) / n_samples
        self.covariance_ = covariance

        self.location_ = from_table(result.means).ravel()

        return self
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2024 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from daal4py.sklearn._utils import daal_check_version, get_dtype
|
|
19
|
+
|
|
20
|
+
from ..datatypes import from_table, to_table
|
|
21
|
+
from ..utils import _check_array
|
|
22
|
+
from .covariance import BaseEmpiricalCovariance
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class IncrementalEmpiricalCovariance(BaseEmpiricalCovariance):
    """
    Covariance estimator based on oneDAL implementation.

    Computes sample covariance matrix. Data are accumulated batch by batch
    with `partial_fit`; `finalize_fit` turns the accumulated partial result
    into the `covariance_` and `location_` attributes.

    Parameters
    ----------
    method : string, default="dense"
        Specifies computation method. Available methods: "dense".

    bias : bool, default=False
        If True, the biased estimate of the covariance is computed, which
        equals the unbiased one multiplied by (n_samples - 1) / n_samples.

    assume_centered : bool, default=False
        If True, data are not centered before computation.
        Useful when working with data whose mean is almost, but not exactly
        zero.
        If False (default), data are centered before computation.

    Attributes
    ----------
    location_ : ndarray of shape (n_features,)
        Estimated location, i.e., the estimated mean.

    covariance_ : ndarray of shape (n_features, n_features)
        Estimated covariance matrix.
    """

    def __init__(self, method="dense", bias=False, assume_centered=False):
        super().__init__(method, bias, assume_centered)
        self._reset()

    def _reset(self):
        """Start from an empty backend partial result with nothing to finalize."""
        self._need_to_finalize = False
        self._partial_result = self._get_backend(
            "covariance", None, "partial_compute_result"
        )

    def __getstate__(self):
        # Since finalize_fit can't be dispatched without directly provided queue
        # and the dispatching policy can't be serialized, the computation is finalized
        # here and the policy is not saved in serialized data.

        self.finalize_fit()
        data = self.__dict__.copy()
        data.pop("_queue", None)

        return data

    def partial_fit(self, X, y=None, queue=None):
        """
        Computes partial data for the covariance matrix
        from data batch X and saves it to `_partial_result`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data batch, where `n_samples` is the number of samples
            in the batch, and `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        queue : dpctl.SyclQueue
            If not None, use this queue for computations.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        X = _check_array(X, dtype=[np.float64, np.float32], ensure_2d=True)

        # Remembered so that finalize_fit can reuse it when called without a queue.
        self._queue = queue

        policy = self._get_policy(queue, X)

        X_table = to_table(X, queue=queue)

        # The dtype of the first batch fixes the dtype used for all later
        # partial and final computations.
        if not hasattr(self, "_dtype"):
            self._dtype = X_table.dtype

        params = self._get_onedal_params(self._dtype)
        self._partial_result = self._get_backend(
            "covariance",
            None,
            "partial_compute",
            policy,
            params,
            self._partial_result,
            X_table,
        )
        self._need_to_finalize = True

        # The docstring promises the instance back (enables chaining).
        return self

    def finalize_fit(self, queue=None):
        """
        Finalizes covariance matrix and obtains `covariance_` and `location_`
        attributes from the current `_partial_result`.

        Does nothing if no `partial_fit` call happened since the last
        finalization.

        Parameters
        ----------
        queue : dpctl.SyclQueue
            If not None, use this queue for computations. Otherwise the queue
            passed to the last `partial_fit` call is used.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        if self._need_to_finalize:
            params = self._get_onedal_params(self._dtype)
            if queue is not None:
                policy = self._get_policy(queue)
            else:
                policy = self._get_policy(self._queue)

            result = self._get_backend(
                "covariance",
                None,
                "finalize_compute",
                policy,
                params,
                self._partial_result,
            )
            if daal_check_version((2024, "P", 1)) or (not self.bias):
                self.covariance_ = from_table(result.cov_matrix)
            else:
                # "bias" is not sent to older oneDAL, so rescale the unbiased
                # result here. `partial_n_rows` is a backend table (like
                # `result.cov_matrix`), not a Python number, so convert it to a
                # scalar before doing arithmetic with it.
                n_rows = np.asarray(
                    from_table(self._partial_result.partial_n_rows)
                ).item()
                self.covariance_ = (
                    from_table(result.cov_matrix) * (n_rows - 1) / n_rows
                )

            self.location_ = from_table(result.means).ravel()

            self._need_to_finalize = False

        return self
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# ===============================================================================
|
|
2
|
+
# Copyright 2023 Intel Corporation
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
# ===============================================================================
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import pytest
|
|
19
|
+
from numpy.testing import assert_allclose
|
|
20
|
+
|
|
21
|
+
from onedal.tests.utils._device_selection import get_queues
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.mark.parametrize("queue", get_queues())
def test_onedal_import_covariance(queue):
    """Compare EmpiricalCovariance output with hand-computed results."""
    from onedal.covariance import EmpiricalCovariance

    # Each case: (input rows, constructor kwargs, expected covariance, expected means)
    cases = [
        ([[0, 1], [0, 1]], {}, [[0, 0], [0, 0]], [0, 1]),
        ([[1, 2], [3, 6]], {}, [[2, 4], [4, 8]], [2, 4]),
        # Biased estimate = unbiased * (n - 1) / n, i.e. halved for n = 2.
        ([[1, 2], [3, 6]], {"bias": True}, [[1, 2], [2, 4]], [2, 4]),
    ]

    for rows, kwargs, cov, means in cases:
        fitted = EmpiricalCovariance(**kwargs).fit(np.array(rows), queue=queue)
        assert_allclose(np.array(cov), fitted.covariance_)
        assert_allclose(np.array(means), fitted.location_)
|