scikit-learn-intelex 2023.2.1__py38-none-win_amd64.whl → 2024.0.1__py38-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of scikit-learn-intelex might be problematic. Click here for more details.
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +16 -12
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +90 -56
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +4 -4
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +12 -6
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +5 -5
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +5 -4
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +102 -72
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +12 -4
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +31 -16
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +21 -14
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +10 -10
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +2 -2
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +173 -83
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +23 -7
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +4 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +4 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +5 -5
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +8 -6
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +6 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +9 -5
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +100 -77
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +116 -58
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +118 -56
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +18 -20
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +7 -7
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +104 -73
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +4 -1
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +128 -100
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +18 -16
- {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd}/__init__.py +24 -22
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +11 -5
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +16 -14
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +2 -2
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +3 -3
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +11 -8
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +56 -56
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +110 -55
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +65 -31
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +136 -78
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +65 -31
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +9 -8
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +63 -69
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +55 -53
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +8 -7
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +39 -39
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -3
- scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +2 -2
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
- scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/_utils.py +0 -82
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -18
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -46
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -228
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -213
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -57
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -18
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -28
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py +0 -1261
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1155
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py +0 -67
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -23
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -63
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -159
- scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -383
- scikit_learn_intelex-2023.2.1.dist-info/RECORD +0 -95
- {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
- {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
#!/usr/bin/env python
|
|
2
|
-
|
|
2
|
+
# ===============================================================================
|
|
3
3
|
# Copyright 2021 Intel Corporation
|
|
4
4
|
#
|
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,85 +13,128 @@
|
|
|
13
13
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
14
|
# See the License for the specific language governing permissions and
|
|
15
15
|
# limitations under the License.
|
|
16
|
-
|
|
16
|
+
# ===============================================================================
|
|
17
17
|
|
|
18
18
|
try:
|
|
19
19
|
from packaging.version import Version
|
|
20
20
|
except ImportError:
|
|
21
21
|
from distutils.version import LooseVersion as Version
|
|
22
|
-
|
|
23
|
-
from daal4py.sklearn._utils import sklearn_check_version
|
|
22
|
+
|
|
24
23
|
import warnings
|
|
25
24
|
|
|
26
|
-
|
|
25
|
+
import numpy as np
|
|
26
|
+
from sklearn import __version__ as sklearn_version
|
|
27
27
|
from sklearn.neighbors._ball_tree import BallTree
|
|
28
|
-
from sklearn.neighbors._kd_tree import KDTree
|
|
29
28
|
from sklearn.neighbors._base import VALID_METRICS
|
|
30
|
-
from sklearn.neighbors.
|
|
31
|
-
|
|
32
|
-
|
|
29
|
+
from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
|
|
30
|
+
from sklearn.neighbors._kd_tree import KDTree
|
|
31
|
+
from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestNeighbors
|
|
33
32
|
from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
|
|
34
33
|
|
|
35
|
-
from
|
|
34
|
+
from daal4py.sklearn._utils import sklearn_check_version
|
|
36
35
|
from onedal.neighbors import NearestNeighbors as onedal_NearestNeighbors
|
|
36
|
+
from onedal.utils import _check_array, _num_features, _num_samples
|
|
37
37
|
|
|
38
|
-
from .common import KNeighborsDispatchingBase
|
|
39
38
|
from .._device_offload import dispatch, wrap_output_data
|
|
40
|
-
|
|
39
|
+
from .common import KNeighborsDispatchingBase
|
|
41
40
|
|
|
41
|
+
if sklearn_check_version("0.22") and Version(sklearn_version) < Version("0.23"):
|
|
42
42
|
|
|
43
|
-
if sklearn_check_version("0.22") and \
|
|
44
|
-
Version(sklearn_version) < Version("0.23"):
|
|
45
43
|
class NearestNeighbors_(sklearn_NearestNeighbors):
|
|
46
|
-
def __init__(
|
|
47
|
-
|
|
48
|
-
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
n_neighbors=5,
|
|
47
|
+
radius=1.0,
|
|
48
|
+
algorithm="auto",
|
|
49
|
+
leaf_size=30,
|
|
50
|
+
metric="minkowski",
|
|
51
|
+
p=2,
|
|
52
|
+
metric_params=None,
|
|
53
|
+
n_jobs=None,
|
|
54
|
+
):
|
|
49
55
|
super().__init__(
|
|
50
56
|
n_neighbors=n_neighbors,
|
|
51
57
|
radius=radius,
|
|
52
58
|
algorithm=algorithm,
|
|
53
|
-
leaf_size=leaf_size,
|
|
54
|
-
|
|
59
|
+
leaf_size=leaf_size,
|
|
60
|
+
metric=metric,
|
|
61
|
+
p=p,
|
|
62
|
+
metric_params=metric_params,
|
|
63
|
+
n_jobs=n_jobs,
|
|
64
|
+
)
|
|
65
|
+
|
|
55
66
|
else:
|
|
67
|
+
|
|
56
68
|
class NearestNeighbors_(sklearn_NearestNeighbors):
|
|
57
|
-
if sklearn_check_version(
|
|
69
|
+
if sklearn_check_version("1.2"):
|
|
58
70
|
_parameter_constraints: dict = {
|
|
59
|
-
**sklearn_NearestNeighbors._parameter_constraints
|
|
71
|
+
**sklearn_NearestNeighbors._parameter_constraints
|
|
72
|
+
}
|
|
60
73
|
|
|
61
74
|
@_deprecate_positional_args
|
|
62
|
-
def __init__(
|
|
63
|
-
|
|
64
|
-
|
|
75
|
+
def __init__(
|
|
76
|
+
self,
|
|
77
|
+
*,
|
|
78
|
+
n_neighbors=5,
|
|
79
|
+
radius=1.0,
|
|
80
|
+
algorithm="auto",
|
|
81
|
+
leaf_size=30,
|
|
82
|
+
metric="minkowski",
|
|
83
|
+
p=2,
|
|
84
|
+
metric_params=None,
|
|
85
|
+
n_jobs=None,
|
|
86
|
+
):
|
|
65
87
|
super().__init__(
|
|
66
88
|
n_neighbors=n_neighbors,
|
|
67
89
|
radius=radius,
|
|
68
90
|
algorithm=algorithm,
|
|
69
|
-
leaf_size=leaf_size,
|
|
70
|
-
|
|
91
|
+
leaf_size=leaf_size,
|
|
92
|
+
metric=metric,
|
|
93
|
+
p=p,
|
|
94
|
+
metric_params=metric_params,
|
|
95
|
+
n_jobs=n_jobs,
|
|
96
|
+
)
|
|
71
97
|
|
|
72
98
|
|
|
73
99
|
class NearestNeighbors(NearestNeighbors_, KNeighborsDispatchingBase):
|
|
74
|
-
if sklearn_check_version(
|
|
75
|
-
_parameter_constraints: dict = {
|
|
76
|
-
**NearestNeighbors_._parameter_constraints}
|
|
100
|
+
if sklearn_check_version("1.2"):
|
|
101
|
+
_parameter_constraints: dict = {**NearestNeighbors_._parameter_constraints}
|
|
77
102
|
|
|
78
103
|
@_deprecate_positional_args
|
|
79
|
-
def __init__(
|
|
80
|
-
|
|
81
|
-
|
|
104
|
+
def __init__(
|
|
105
|
+
self,
|
|
106
|
+
n_neighbors=5,
|
|
107
|
+
radius=1.0,
|
|
108
|
+
algorithm="auto",
|
|
109
|
+
leaf_size=30,
|
|
110
|
+
metric="minkowski",
|
|
111
|
+
p=2,
|
|
112
|
+
metric_params=None,
|
|
113
|
+
n_jobs=None,
|
|
114
|
+
):
|
|
82
115
|
super().__init__(
|
|
83
116
|
n_neighbors=n_neighbors,
|
|
84
117
|
radius=radius,
|
|
85
118
|
algorithm=algorithm,
|
|
86
|
-
leaf_size=leaf_size,
|
|
87
|
-
|
|
119
|
+
leaf_size=leaf_size,
|
|
120
|
+
metric=metric,
|
|
121
|
+
p=p,
|
|
122
|
+
metric_params=metric_params,
|
|
123
|
+
n_jobs=n_jobs,
|
|
124
|
+
)
|
|
88
125
|
|
|
89
126
|
def fit(self, X, y=None):
|
|
90
127
|
self._fit_validation(X, y)
|
|
91
|
-
dispatch(
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
128
|
+
dispatch(
|
|
129
|
+
self,
|
|
130
|
+
"fit",
|
|
131
|
+
{
|
|
132
|
+
"onedal": self.__class__._onedal_fit,
|
|
133
|
+
"sklearn": sklearn_NearestNeighbors.fit,
|
|
134
|
+
},
|
|
135
|
+
X,
|
|
136
|
+
None,
|
|
137
|
+
)
|
|
95
138
|
return self
|
|
96
139
|
|
|
97
140
|
@wrap_output_data
|
|
@@ -99,37 +142,50 @@ class NearestNeighbors(NearestNeighbors_, KNeighborsDispatchingBase):
|
|
|
99
142
|
check_is_fitted(self)
|
|
100
143
|
if sklearn_check_version("1.0") and X is not None:
|
|
101
144
|
self._check_feature_names(X, reset=False)
|
|
102
|
-
return dispatch(
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
145
|
+
return dispatch(
|
|
146
|
+
self,
|
|
147
|
+
"kneighbors",
|
|
148
|
+
{
|
|
149
|
+
"onedal": self.__class__._onedal_kneighbors,
|
|
150
|
+
"sklearn": sklearn_NearestNeighbors.kneighbors,
|
|
151
|
+
},
|
|
152
|
+
X,
|
|
153
|
+
n_neighbors,
|
|
154
|
+
return_distance,
|
|
155
|
+
)
|
|
106
156
|
|
|
107
157
|
@wrap_output_data
|
|
108
|
-
def radius_neighbors(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
158
|
+
def radius_neighbors(
|
|
159
|
+
self, X=None, radius=None, return_distance=True, sort_results=False
|
|
160
|
+
):
|
|
161
|
+
_onedal_estimator = getattr(self, "_onedal_estimator", None)
|
|
162
|
+
|
|
163
|
+
if (
|
|
164
|
+
_onedal_estimator is not None
|
|
165
|
+
or getattr(self, "_tree", 0) is None
|
|
166
|
+
and self._fit_method == "kd_tree"
|
|
167
|
+
):
|
|
114
168
|
if sklearn_check_version("0.24"):
|
|
115
|
-
sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self,
|
|
169
|
+
sklearn_NearestNeighbors.fit(self, self._fit_X, getattr(self, "_y", None))
|
|
116
170
|
else:
|
|
117
171
|
sklearn_NearestNeighbors.fit(self, self._fit_X)
|
|
118
172
|
if sklearn_check_version("0.22"):
|
|
119
173
|
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
120
|
-
self, X, radius, return_distance, sort_results
|
|
174
|
+
self, X, radius, return_distance, sort_results
|
|
175
|
+
)
|
|
121
176
|
else:
|
|
122
177
|
result = sklearn_NearestNeighbors.radius_neighbors(
|
|
123
|
-
self, X, radius, return_distance
|
|
178
|
+
self, X, radius, return_distance
|
|
179
|
+
)
|
|
124
180
|
|
|
125
181
|
return result
|
|
126
182
|
|
|
127
183
|
def _onedal_fit(self, X, y=None, queue=None):
|
|
128
184
|
onedal_params = {
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
185
|
+
"n_neighbors": self.n_neighbors,
|
|
186
|
+
"algorithm": self.algorithm,
|
|
187
|
+
"metric": self.effective_metric_,
|
|
188
|
+
"p": self.effective_metric_params_["p"],
|
|
133
189
|
}
|
|
134
190
|
|
|
135
191
|
try:
|
|
@@ -148,10 +204,12 @@ class NearestNeighbors(NearestNeighbors_, KNeighborsDispatchingBase):
|
|
|
148
204
|
def _onedal_predict(self, X, queue=None):
|
|
149
205
|
return self._onedal_estimator.predict(X, queue=queue)
|
|
150
206
|
|
|
151
|
-
def _onedal_kneighbors(
|
|
152
|
-
|
|
207
|
+
def _onedal_kneighbors(
|
|
208
|
+
self, X=None, n_neighbors=None, return_distance=True, queue=None
|
|
209
|
+
):
|
|
153
210
|
return self._onedal_estimator.kneighbors(
|
|
154
|
-
X, n_neighbors, return_distance, queue=queue
|
|
211
|
+
X, n_neighbors, return_distance, queue=queue
|
|
212
|
+
)
|
|
155
213
|
|
|
156
214
|
def _save_attributes(self):
|
|
157
215
|
self.classes_ = self._onedal_estimator.classes_
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
#!/usr/bin/env python
|
|
2
|
-
|
|
2
|
+
# ===============================================================================
|
|
3
3
|
# Copyright 2023 Intel Corporation
|
|
4
4
|
#
|
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,13 +13,13 @@
|
|
|
13
13
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
14
|
# See the License for the specific language governing permissions and
|
|
15
15
|
# limitations under the License.
|
|
16
|
-
|
|
16
|
+
# ===============================================================================
|
|
17
17
|
|
|
18
|
-
import numpy as np
|
|
19
18
|
import warnings
|
|
20
19
|
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
import numpy as np
|
|
21
|
+
from sklearn.neighbors._lof import LocalOutlierFactor as sklearn_LocalOutlierFactor
|
|
22
|
+
|
|
23
23
|
from .knn_unsupervised import NearestNeighbors
|
|
24
24
|
|
|
25
25
|
try:
|
|
@@ -27,18 +27,22 @@ try:
|
|
|
27
27
|
except ImportError:
|
|
28
28
|
pass
|
|
29
29
|
|
|
30
|
-
from sklearn.utils.validation import check_is_fitted
|
|
31
30
|
from sklearn.utils import check_array
|
|
31
|
+
from sklearn.utils.validation import check_is_fitted
|
|
32
32
|
|
|
33
33
|
from daal4py.sklearn._utils import sklearn_check_version
|
|
34
|
-
|
|
34
|
+
|
|
35
35
|
from .._config import config_context
|
|
36
|
+
from .._device_offload import dispatch, wrap_output_data
|
|
37
|
+
from .._utils import PatchingConditionsChain
|
|
36
38
|
|
|
37
39
|
if sklearn_check_version("1.0"):
|
|
40
|
+
|
|
38
41
|
class LocalOutlierFactor(sklearn_LocalOutlierFactor):
|
|
39
|
-
if sklearn_check_version(
|
|
42
|
+
if sklearn_check_version("1.2"):
|
|
40
43
|
_parameter_constraints: dict = {
|
|
41
|
-
**sklearn_LocalOutlierFactor._parameter_constraints
|
|
44
|
+
**sklearn_LocalOutlierFactor._parameter_constraints
|
|
45
|
+
}
|
|
42
46
|
|
|
43
47
|
def __init__(
|
|
44
48
|
self,
|
|
@@ -62,7 +66,7 @@ if sklearn_check_version("1.0"):
|
|
|
62
66
|
metric_params=metric_params,
|
|
63
67
|
n_jobs=n_jobs,
|
|
64
68
|
contamination=contamination,
|
|
65
|
-
novelty=novelty
|
|
69
|
+
novelty=novelty,
|
|
66
70
|
)
|
|
67
71
|
|
|
68
72
|
def _fit(self, X, y, queue=None):
|
|
@@ -76,7 +80,7 @@ if sklearn_check_version("1.0"):
|
|
|
76
80
|
metric=self.metric,
|
|
77
81
|
p=self.p,
|
|
78
82
|
metric_params=self.metric_params,
|
|
79
|
-
n_jobs=self.n_jobs
|
|
83
|
+
n_jobs=self.n_jobs,
|
|
80
84
|
)
|
|
81
85
|
self._knn.fit(X)
|
|
82
86
|
|
|
@@ -98,8 +102,9 @@ if sklearn_check_version("1.0"):
|
|
|
98
102
|
)
|
|
99
103
|
self.n_neighbors_ = max(1, min(self.n_neighbors, n_samples - 1))
|
|
100
104
|
|
|
101
|
-
self._distances_fit_X_, _neighbors_indices_fit_X_
|
|
102
|
-
|
|
105
|
+
self._distances_fit_X_, _neighbors_indices_fit_X_ = self._knn.kneighbors(
|
|
106
|
+
n_neighbors=self.n_neighbors_
|
|
107
|
+
)
|
|
103
108
|
|
|
104
109
|
self._lrd = self._local_reachability_density(
|
|
105
110
|
self._distances_fit_X_, _neighbors_indices_fit_X_
|
|
@@ -127,10 +132,16 @@ if sklearn_check_version("1.0"):
|
|
|
127
132
|
return self
|
|
128
133
|
|
|
129
134
|
def fit(self, X, y=None):
|
|
130
|
-
return dispatch(
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
135
|
+
return dispatch(
|
|
136
|
+
self,
|
|
137
|
+
"neighbors.LocalOutlierFactor.fit",
|
|
138
|
+
{
|
|
139
|
+
"onedal": self.__class__._fit,
|
|
140
|
+
"sklearn": None,
|
|
141
|
+
},
|
|
142
|
+
X,
|
|
143
|
+
y,
|
|
144
|
+
)
|
|
134
145
|
|
|
135
146
|
def _onedal_predict(self, X, queue=None):
|
|
136
147
|
with config_context(target_offload=queue):
|
|
@@ -148,10 +159,15 @@ if sklearn_check_version("1.0"):
|
|
|
148
159
|
|
|
149
160
|
@wrap_output_data
|
|
150
161
|
def _predict(self, X=None):
|
|
151
|
-
return dispatch(
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
162
|
+
return dispatch(
|
|
163
|
+
self,
|
|
164
|
+
"neighbors.LocalOutlierFactor.predict",
|
|
165
|
+
{
|
|
166
|
+
"onedal": self.__class__._onedal_predict,
|
|
167
|
+
"sklearn": None,
|
|
168
|
+
},
|
|
169
|
+
X,
|
|
170
|
+
)
|
|
155
171
|
|
|
156
172
|
def _score_samples(self, X, queue=None):
|
|
157
173
|
with config_context(target_offload=queue):
|
|
@@ -183,10 +199,15 @@ if sklearn_check_version("1.0"):
|
|
|
183
199
|
@available_if(_check_novelty_score_samples)
|
|
184
200
|
@wrap_output_data
|
|
185
201
|
def score_samples(self, X):
|
|
186
|
-
return dispatch(
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
202
|
+
return dispatch(
|
|
203
|
+
self,
|
|
204
|
+
"neighbors.LocalOutlierFactor.score_samples",
|
|
205
|
+
{
|
|
206
|
+
"onedal": self.__class__._score_samples,
|
|
207
|
+
"sklearn": None,
|
|
208
|
+
},
|
|
209
|
+
X,
|
|
210
|
+
)
|
|
190
211
|
|
|
191
212
|
def _check_novelty_fit_predict(self):
|
|
192
213
|
if self.novelty:
|
|
@@ -204,17 +225,33 @@ if sklearn_check_version("1.0"):
|
|
|
204
225
|
@available_if(_check_novelty_fit_predict)
|
|
205
226
|
@wrap_output_data
|
|
206
227
|
def fit_predict(self, X, y=None):
|
|
207
|
-
return dispatch(
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
228
|
+
return dispatch(
|
|
229
|
+
self,
|
|
230
|
+
"neighbors.LocalOutlierFactor.fit_predict",
|
|
231
|
+
{
|
|
232
|
+
"onedal": self.__class__._fit_predict,
|
|
233
|
+
"sklearn": None,
|
|
234
|
+
},
|
|
235
|
+
X,
|
|
236
|
+
y,
|
|
237
|
+
)
|
|
211
238
|
|
|
212
239
|
def _onedal_gpu_supported(self, method_name, *data):
|
|
213
|
-
|
|
240
|
+
class_name = self.__class__.__name__
|
|
241
|
+
patching_status = PatchingConditionsChain(
|
|
242
|
+
f"sklearn.neighbors.{class_name}.{method_name}"
|
|
243
|
+
)
|
|
244
|
+
return patching_status
|
|
214
245
|
|
|
215
246
|
def _onedal_cpu_supported(self, method_name, *data):
|
|
216
|
-
|
|
247
|
+
class_name = self.__class__.__name__
|
|
248
|
+
patching_status = PatchingConditionsChain(
|
|
249
|
+
f"sklearn.neighbors.{class_name}.{method_name}"
|
|
250
|
+
)
|
|
251
|
+
return patching_status
|
|
252
|
+
|
|
217
253
|
else:
|
|
254
|
+
|
|
218
255
|
class LocalOutlierFactor(sklearn_LocalOutlierFactor):
|
|
219
256
|
def __init__(
|
|
220
257
|
self,
|
|
@@ -238,7 +275,7 @@ else:
|
|
|
238
275
|
metric_params=metric_params,
|
|
239
276
|
n_jobs=n_jobs,
|
|
240
277
|
contamination=contamination,
|
|
241
|
-
novelty=novelty
|
|
278
|
+
novelty=novelty,
|
|
242
279
|
)
|
|
243
280
|
|
|
244
281
|
def _fit(self, X, y=None, queue=None):
|
|
@@ -250,7 +287,7 @@ else:
|
|
|
250
287
|
metric=self.metric,
|
|
251
288
|
p=self.p,
|
|
252
289
|
metric_params=self.metric_params,
|
|
253
|
-
n_jobs=self.n_jobs
|
|
290
|
+
n_jobs=self.n_jobs,
|
|
254
291
|
)
|
|
255
292
|
self._knn.fit(X)
|
|
256
293
|
|
|
@@ -272,8 +309,9 @@ else:
|
|
|
272
309
|
)
|
|
273
310
|
self.n_neighbors_ = max(1, min(self.n_neighbors, n_samples - 1))
|
|
274
311
|
|
|
275
|
-
self._distances_fit_X_, _neighbors_indices_fit_X_
|
|
276
|
-
|
|
312
|
+
self._distances_fit_X_, _neighbors_indices_fit_X_ = self._knn.kneighbors(
|
|
313
|
+
n_neighbors=self.n_neighbors_
|
|
314
|
+
)
|
|
277
315
|
|
|
278
316
|
self._lrd = self._local_reachability_density(
|
|
279
317
|
self._distances_fit_X_, _neighbors_indices_fit_X_
|
|
@@ -301,10 +339,16 @@ else:
|
|
|
301
339
|
return self
|
|
302
340
|
|
|
303
341
|
def fit(self, X, y=None):
|
|
304
|
-
return dispatch(
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
342
|
+
return dispatch(
|
|
343
|
+
self,
|
|
344
|
+
"neighbors.LocalOutlierFactor.fit",
|
|
345
|
+
{
|
|
346
|
+
"onedal": self.__class__._fit,
|
|
347
|
+
"sklearn": None,
|
|
348
|
+
},
|
|
349
|
+
X,
|
|
350
|
+
y,
|
|
351
|
+
)
|
|
308
352
|
|
|
309
353
|
def _onedal_predict(self, X, queue=None):
|
|
310
354
|
with config_context(target_offload=queue):
|
|
@@ -322,10 +366,15 @@ else:
|
|
|
322
366
|
|
|
323
367
|
@wrap_output_data
|
|
324
368
|
def _predict(self, X=None):
|
|
325
|
-
return dispatch(
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
369
|
+
return dispatch(
|
|
370
|
+
self,
|
|
371
|
+
"neighbors.LocalOutlierFactor.predict",
|
|
372
|
+
{
|
|
373
|
+
"onedal": self.__class__._onedal_predict,
|
|
374
|
+
"sklearn": None,
|
|
375
|
+
},
|
|
376
|
+
X,
|
|
377
|
+
)
|
|
329
378
|
|
|
330
379
|
def _onedal_score_samples(self, X, queue=None):
|
|
331
380
|
with config_context(target_offload=queue):
|
|
@@ -345,17 +394,24 @@ else:
|
|
|
345
394
|
@wrap_output_data
def _score_samples(self, X):
    """Route a score_samples call for ``X`` through ``dispatch``.

    Raises
    ------
    AttributeError
        If ``self.novelty`` is falsy; the message points the caller at
        ``negative_outlier_factor_`` and at ``novelty=True``.
    """
    # Scoring new data is refused unless the estimator is in novelty mode.
    if not self.novelty:
        raise AttributeError(
            "score_samples is not available when novelty=False. The "
            "scores of the training samples are always available "
            "through the negative_outlier_factor_ attribute. Use "
            "novelty=True if you want to use LOF for novelty detection "
            "and compute score_samples for new unseen data."
        )

    branches = {
        "onedal": self.__class__._onedal_score_samples,
        "sklearn": None,
    }
    return dispatch(
        self, "neighbors.LocalOutlierFactor.score_samples", branches, X
    )
|
|
359
415
|
|
|
360
416
|
def _onedal_fit_predict(self, X, y, queue=None):
|
|
361
417
|
with config_context(target_offload=queue):
|
|
@@ -363,10 +419,16 @@ else:
|
|
|
363
419
|
|
|
364
420
|
@wrap_output_data
def _fit_predict(self, X, y=None):
    """Route a fit_predict call for ``X`` and ``y`` through ``dispatch``.

    The class-level ``_onedal_fit_predict`` is offered as the ``"onedal"``
    branch; the ``"sklearn"`` branch is ``None``.

    Returns whatever ``dispatch`` returns.
    """
    # NOTE(review): this label ends in "_onedal_fit_predict", whereas the
    # labels used by fit/_predict/_score_samples end in the public method
    # name - confirm whether the difference is intentional.
    label = "neighbors.LocalOutlierFactor._onedal_fit_predict"
    branches = {
        "onedal": self.__class__._onedal_fit_predict,
        "sklearn": None,
    }
    return dispatch(self, label, branches, X, y)
|
|
370
432
|
|
|
371
433
|
def _onedal_gpu_supported(self, method_name, *data):
|
|
372
434
|
return True
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# ===============================================================================
|
|
3
|
+
# Copyright 2021 Intel Corporation
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
# ===============================================================================
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pytest
|
|
20
|
+
from numpy.testing import assert_allclose
|
|
21
|
+
|
|
22
|
+
from onedal.tests.utils._dataframes_support import (
|
|
23
|
+
_as_numpy,
|
|
24
|
+
_convert_to_dataframe,
|
|
25
|
+
get_dataframes_and_queues,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_knn_classifier(dataframe, queue):
    """Smoke-test ``sklearnex.neighbors.KNeighborsClassifier`` on a 1-D toy set.

    Four training points labelled ``[0, 0, 1, 1]`` are fitted with
    ``n_neighbors=3``; the single query point ``1.1`` must be classified as
    0, and the fitted estimator must come from a ``sklearnex`` module.
    """
    from sklearnex.neighbors import KNeighborsClassifier

    # Every input goes through the same dataframe/queue conversion.
    conversion = {"sycl_queue": queue, "target_df": dataframe}

    train_features = _convert_to_dataframe([[0], [1], [2], [3]], **conversion)
    train_labels = _convert_to_dataframe([0, 0, 1, 1], **conversion)

    classifier = KNeighborsClassifier(n_neighbors=3)
    fitted = classifier.fit(train_features, train_labels)

    query = _convert_to_dataframe([[1.1]], **conversion)
    predicted = _as_numpy(fitted.predict(query))

    assert "sklearnex" in fitted.__module__
    assert_allclose(predicted, [0])
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_knn_regression(dataframe, queue):
    """Smoke-test ``sklearnex.neighbors.KNeighborsRegressor`` on a 1-D toy set.

    Four training points with targets ``[0, 0, 1, 1]`` are fitted with
    ``n_neighbors=2``; the query point ``1.5`` must be predicted as 0.5,
    and the fitted estimator must come from a ``sklearnex`` module.
    """
    from sklearnex.neighbors import KNeighborsRegressor

    # Every input goes through the same dataframe/queue conversion.
    conversion = {"sycl_queue": queue, "target_df": dataframe}

    train_features = _convert_to_dataframe([[0], [1], [2], [3]], **conversion)
    train_targets = _convert_to_dataframe([0, 0, 1, 1], **conversion)

    regressor = KNeighborsRegressor(n_neighbors=2)
    fitted = regressor.fit(train_features, train_targets)

    query = _convert_to_dataframe([[1.5]], **conversion)
    predicted = _as_numpy(fitted.predict(query))

    assert "sklearnex" in fitted.__module__
    assert_allclose(predicted, [0.5])
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# TODO: investigate failure for `dpnp.ndarrays` and `dpctl.tensors`;
# until then the parametrization below is filtered to "numpy" only.
@pytest.mark.parametrize(
    "dataframe,queue", get_dataframes_and_queues(dataframe_filter_="numpy")
)
def test_sklearnex_import_nn(dataframe, queue):
    """Smoke-test ``sklearnex.neighbors.NearestNeighbors`` on three 3-D points.

    With ``n_neighbors=2`` the query ``[0, 0, 1.3]`` must return the
    training indices ``[[2, 0]]``, and the fitted estimator must come from
    a ``sklearnex`` module.
    """
    from sklearnex.neighbors import NearestNeighbors

    # Every input goes through the same dataframe/queue conversion.
    conversion = {"sycl_queue": queue, "target_df": dataframe}

    train_points = _convert_to_dataframe(
        [[0, 0, 2], [1, 0, 0], [0, 0, 1]], **conversion
    )
    query = _convert_to_dataframe([[0, 0, 1.3]], **conversion)

    searcher = NearestNeighbors(n_neighbors=2)
    fitted = searcher.fit(train_points)

    # Only the neighbour indices are requested, not the distances.
    neighbor_indices = _as_numpy(fitted.kneighbors(query, 2, return_distance=False))

    assert "sklearnex" in fitted.__module__
    assert_allclose(neighbor_indices, [[2, 0]])
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import_lof(dataframe, queue):
    """Smoke-test ``sklearnex.neighbors.LocalOutlierFactor.fit_predict``.

    Of the four 3-D samples only the first, ``[7, 7, 7]``, is expected to be
    labelled -1; the other three are expected to be labelled 1. The
    estimator and its ``_knn`` attribute must both come from ``sklearnex``
    modules.
    """
    from sklearnex.neighbors import LocalOutlierFactor

    samples = _convert_to_dataframe(
        [[7, 7, 7], [1, 0, 0], [0, 0, 1], [0, 0, 1]],
        sycl_queue=queue,
        target_df=dataframe,
    )

    detector = LocalOutlierFactor(n_neighbors=2)
    labels = _as_numpy(detector.fit_predict(samples))

    # `_knn` must exist after fit_predict before its module is inspected.
    assert hasattr(detector, "_knn")
    assert "sklearnex" in detector.__module__
    assert "sklearnex" in detector._knn.__module__
    assert_allclose(labels, [-1, 1, 1, 1])
|