scikit-learn-intelex 2024.0.1__py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (90)
  1. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/__init__.py +61 -0
  2. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/__main__.py +59 -0
  3. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_config.py +110 -0
  4. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_device_offload.py +223 -0
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  7. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +17 -0
  8. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +21 -0
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +18 -0
  11. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +37 -0
  12. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +31 -0
  13. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +20 -0
  14. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +18 -0
  15. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +28 -0
  16. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/dispatcher.py +329 -0
  17. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +424 -0
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +30 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  20. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  21. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/glob/__main__.py +73 -0
  22. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +88 -0
  23. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +30 -0
  24. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +18 -0
  25. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +373 -0
  26. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +18 -0
  27. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +18 -0
  28. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +77 -0
  29. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +29 -0
  30. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +20 -0
  31. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +18 -0
  32. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +27 -0
  33. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +24 -0
  34. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +18 -0
  35. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +18 -0
  36. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +40 -0
  37. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +22 -0
  38. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/split.py +18 -0
  39. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +35 -0
  40. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +28 -0
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/common.py +264 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  43. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  44. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +220 -0
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +437 -0
  46. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  47. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +18 -0
  48. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +20 -0
  49. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +84 -0
  50. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +370 -0
  51. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +20 -0
  52. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +376 -0
  53. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +38 -0
  54. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +24 -0
  55. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +19 -0
  56. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +21 -0
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
  58. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  59. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
  60. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +19 -0
  61. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +21 -0
  62. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
  63. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +79 -0
  64. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +19 -0
  65. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +21 -0
  66. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +19 -0
  67. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +25 -0
  68. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/__init__.py +30 -0
  69. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/_common.py +188 -0
  70. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +272 -0
  71. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +163 -0
  72. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/svc.py +301 -0
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/svr.py +164 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  75. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  76. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_config.py +39 -0
  77. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +225 -0
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +210 -0
  79. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +122 -0
  81. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  82. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +118 -0
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
  84. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  85. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/validation.py +18 -0
  86. scikit_learn_intelex-2024.0.1.dist-info/LICENSE.txt +202 -0
  87. scikit_learn_intelex-2024.0.1.dist-info/METADATA +230 -0
  88. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  89. scikit_learn_intelex-2024.0.1.dist-info/WHEEL +5 -0
  90. scikit_learn_intelex-2024.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python
2
+ # ==============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ==============================================================================
17
+
18
+ from . import utils
19
+ from ._config import config_context, get_config, set_config
20
+ from .dispatcher import (
21
+ get_patch_map,
22
+ get_patch_names,
23
+ is_patched_instance,
24
+ patch_sklearn,
25
+ sklearn_is_patched,
26
+ unpatch_sklearn,
27
+ )
28
+
29
# Public names re-exported by ``from sklearnex import *``. Entries may be
# attributes bound above or subpackage names; every entry must resolve, or the
# star-import raises AttributeError. (A bogus concatenated entry
# "sklearn_is_patchedget_patch_map" was removed for that reason.)
__all__ = [
    "basic_statistics",
    "cluster",
    "config_context",
    "decomposition",
    "ensemble",
    "get_config",
    "get_patch_map",
    "get_patch_names",
    "is_patched_instance",
    "linear_model",
    "manifold",
    "metrics",
    "neighbors",
    "patch_sklearn",
    "set_config",
    "sklearn_is_patched",
    "svm",
    "unpatch_sklearn",
    "utils",
]
51
+
52
+
53
from onedal import _is_dpc_backend

# "spmd" is only advertised in __all__ when onedal reports that its DPC
# backend is in use (``_is_dpc_backend`` truthy).
if _is_dpc_backend:
    __all__ += ["spmd"]


from ._utils import set_sklearn_ex_verbose

# Configure the "sklearnex" logger once, at package import time.
set_sklearn_ex_verbose()
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python
2
+ # ==============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ==============================================================================
17
+
18
+ import sys
19
+
20
+ from sklearnex import patch_sklearn
21
+
22
+
23
def _main():
    """Command-line entry point behind ``python -m sklearnex``.

    Parses ``[-m] name [args...]`` from ``sys.argv``, tries to patch
    scikit-learn (printing a notice if it cannot be imported), rewrites
    ``sys.argv`` to ``[name, *args]`` and then runs ``name`` as ``__main__``:
    as a module when ``-m`` is given, otherwise as a script path.

    If this module defines a global called ``"_" + name``, that callable is
    invoked with the remaining arguments instead and its result returned.
    """
    import argparse
    import runpy

    parser = argparse.ArgumentParser(
        prog="python -m sklearnex",
        description="""
    Run your Python script with Intel(R) Extension for
    scikit-learn, optimizing solvers of
    scikit-learn with Intel(R) oneAPI Data Analytics Library.
    """,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-m", action="store_true", dest="module", help="Executes following as a module"
    )
    parser.add_argument("name", help="Script or module name")
    parser.add_argument("args", nargs=argparse.REMAINDER, help="Command line arguments")
    cli = parser.parse_args()

    try:
        # Imported only to detect whether scikit-learn is installed.
        import sklearn  # noqa: F401

        patch_sklearn()
    except ImportError:
        print("Scikit-learn could not be imported. Nothing to patch")

    # The target program sees itself as argv[0], followed by its own arguments.
    sys.argv = [cli.name, *cli.args]

    namespace = globals()
    hook_name = f"_{cli.name}"
    if hook_name in namespace:
        return namespace[hook_name](*cli.args)

    launch = runpy.run_module if cli.module else runpy.run_path
    launch(cli.name, run_name="__main__")
57
+
58
+
59
# Only launch the CLI when executed via ``python -m sklearnex``. Without this
# guard, merely importing ``sklearnex.__main__`` (documentation generators,
# package walkers, test collectors) would parse the host program's sys.argv
# and raise SystemExit.
if __name__ == "__main__":
    sys.exit(_main())
@@ -0,0 +1,110 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import threading
18
+ from contextlib import contextmanager
19
+
20
+ from sklearn import get_config as skl_get_config
21
+ from sklearn import set_config as skl_set_config
22
+
23
+ _default_global_config = {
24
+ "target_offload": "auto",
25
+ "allow_fallback_to_host": False,
26
+ }
27
+
28
+ _threadlocal = threading.local()
29
+
30
+
31
+ def _get_sklearnex_threadlocal_config():
32
+ if not hasattr(_threadlocal, "global_config"):
33
+ _threadlocal.global_config = _default_global_config.copy()
34
+ return _threadlocal.global_config
35
+
36
+
37
def get_config():
    """Retrieve current values for configuration set by :func:`set_config`

    The result merges scikit-learn's configuration with this thread's
    sklearnex settings into a new dict; on a key clash the sklearnex value
    wins.

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.

    See Also
    --------
    config_context : Context manager for global configuration.
    set_config : Set global configuration.
    """
    combined = dict(skl_get_config())
    combined.update(_get_sklearnex_threadlocal_config())
    return combined
51
+
52
+
53
def set_config(target_offload=None, allow_fallback_to_host=None, **sklearn_configs):
    """Set global configuration

    Parameters
    ----------
    target_offload : string or dpctl.SyclQueue, default=None
        The device primarily used to perform computations.
        If string, expected to be "auto" (the execution context
        is deduced from input data location),
        or SYCL* filter selector string. Global default: "auto".
    allow_fallback_to_host : bool, default=None
        If True, allows to fallback computation to host device
        in case particular estimator does not support the selected one.
        Global default: False.
    **sklearn_configs
        Forwarded unchanged to scikit-learn's ``set_config`` (called first).

    Arguments left as ``None`` keep their current sklearnex value.

    See Also
    --------
    config_context : Context manager for global configuration.
    get_config : Retrieve current values of the global configuration.
    """
    skl_set_config(**sklearn_configs)

    requested = {
        "target_offload": target_offload,
        "allow_fallback_to_host": allow_fallback_to_host,
    }
    _get_sklearnex_threadlocal_config().update(
        {key: value for key, value in requested.items() if value is not None}
    )
79
+
80
+
81
@contextmanager
def config_context(**new_config):
    """Temporarily apply scikit-learn / sklearnex configuration.

    Parameters
    ----------
    target_offload : string or dpctl.SyclQueue, default=None
        The device primarily used to perform computations.
        If string, expected to be "auto" (the execution context
        is deduced from input data location),
        or SYCL* filter selector string. Global default: "auto".
    allow_fallback_to_host : bool, default=None
        If True, allows to fallback computation to host device
        in case particular estimator does not support the selected one.
        Global default: False.

    Notes
    -----
    A snapshot of the complete configuration (from :func:`get_config`) is
    taken on entry and re-applied on exit, even if the body raises, so all
    settings return to their previous values.

    See Also
    --------
    set_config : Set global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.
    """
    snapshot = get_config()
    set_config(**new_config)
    try:
        yield
    finally:
        set_config(**snapshot)
@@ -0,0 +1,223 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import logging
18
+ import sys
19
+ from functools import wraps
20
+
21
+ import numpy as np
22
+
23
+ try:
24
+ from dpctl import SyclQueue
25
+ from dpctl.memory import MemoryUSMDevice, as_usm_memory
26
+ from dpctl.tensor import usm_ndarray
27
+
28
+ dpctl_available = True
29
+ except ImportError:
30
+ dpctl_available = False
31
+
32
+ try:
33
+ import dpnp
34
+
35
+ dpnp_available = True
36
+ except ImportError:
37
+ dpnp_available = False
38
+
39
+ from ._config import get_config
40
+ from ._utils import get_patch_message
41
+
42
+ oneapi_is_available = "daal4py.oneapi" in sys.modules
43
+ if oneapi_is_available:
44
+ from daal4py.oneapi import _get_device_name_sycl_ctxt, _get_sycl_ctxt_params
45
+
46
+
47
class DummySyclQueue:
    """Lightweight substitute for ``dpctl.SyclQueue``.

    Lets device dispatching work when dpctl is not installed. It exposes
    only a ``sycl_device`` attribute built from a filter string.
    """

    class DummySyclDevice:
        """Substitute for a dpctl device, described by a filter string."""

        def __init__(self, filter_string):
            self._filter_string = filter_string
            # Device kind is inferred purely by substring match.
            self.is_cpu = "cpu" in filter_string
            self.is_gpu = "gpu" in filter_string
            # TODO: check for possibility of fp64 support
            # on other devices in this dummy class
            self.has_aspect_fp64 = self.is_cpu

            if not self.is_cpu:
                logging.warning(
                    "Device support is limited. "
                    "Please install dpctl for full experience"
                )

        def get_filter_string(self):
            """Return the filter string this device was created from."""
            return self._filter_string

    def __init__(self, filter_string):
        self.sycl_device = DummySyclQueue.DummySyclDevice(filter_string)
71
+
72
+
73
def _get_device_info_from_daal4py():
    """Return ``(device_name, context_params)`` from daal4py's sycl context.

    When ``daal4py.oneapi`` was not already in ``sys.modules`` at the time this
    module was imported, returns ``(None, {})``.
    """
    if not oneapi_is_available:
        return None, {}
    return _get_device_name_sycl_ctxt(), _get_sycl_ctxt_params()
77
+
78
+
79
def _get_global_queue():
    """Resolve the queue implied by ``target_offload`` and daal4py's context.

    Returns
    -------
    queue or None
        * ``target_offload`` other than ``"auto"``: the target itself if it is
          already a queue instance, otherwise a queue constructed from it.
        * ``"auto"``: a queue for the active daal4py context device, or
          ``None`` if no such context is active.

    Raises
    ------
    RuntimeError
        If ``target_offload`` conflicts with an active
        ``daal4py.oneapi.sycl_context`` device.
    """
    target = get_config()["target_offload"]
    d4p_target, _ = _get_device_info_from_daal4py()
    if d4p_target == "host":
        # daal4py calls the host device "host"; queues call it "cpu".
        d4p_target = "cpu"

    queue_cls = SyclQueue if dpctl_available else DummySyclQueue

    if target != "auto":
        clashes = d4p_target is not None and d4p_target != target
        if clashes and (
            isinstance(target, str)
            or d4p_target not in target.sycl_device.get_filter_string()
        ):
            raise RuntimeError(
                "Cannot use target offload option "
                "inside daal4py.oneapi.sycl_context"
            )
        return target if isinstance(target, queue_cls) else queue_cls(target)

    return None if d4p_target is None else queue_cls(d4p_target)
106
+
107
+
108
+ def _transfer_to_host(queue, *data):
109
+ has_usm_data, has_host_data = False, False
110
+
111
+ host_data = []
112
+ for item in data:
113
+ usm_iface = getattr(item, "__sycl_usm_array_interface__", None)
114
+ if usm_iface is not None:
115
+ if not dpctl_available:
116
+ raise RuntimeError(
117
+ "dpctl need to be installed to work "
118
+ "with __sycl_usm_array_interface__"
119
+ )
120
+ if queue is not None:
121
+ if queue.sycl_device != usm_iface["syclobj"].sycl_device:
122
+ raise RuntimeError(
123
+ "Input data shall be located " "on single target device"
124
+ )
125
+ else:
126
+ queue = usm_iface["syclobj"]
127
+
128
+ buffer = as_usm_memory(item).copy_to_host()
129
+ item = np.ndarray(
130
+ shape=usm_iface["shape"], dtype=usm_iface["typestr"], buffer=buffer
131
+ )
132
+ has_usm_data = True
133
+ else:
134
+ has_host_data = True
135
+
136
+ mismatch_host_item = usm_iface is None and item is not None and has_usm_data
137
+ mismatch_usm_item = usm_iface is not None and has_host_data
138
+
139
+ if mismatch_host_item or mismatch_usm_item:
140
+ raise RuntimeError("Input data shall be located on single target device")
141
+
142
+ host_data.append(item)
143
+ return queue, host_data
144
+
145
+
146
def _get_backend(obj, queue, method_name, *data):
    """Choose the implementation (oneDAL or scikit-learn) for a method call.

    Returns
    -------
    tuple
        ``(backend, queue, patching_status)`` where ``backend`` is
        ``"onedal"`` or ``"sklearn"``. The returned queue is ``None`` whenever
        the computation will not run on the given non-CPU queue.

    Raises
    ------
    RuntimeError
        If the queue's device is neither a CPU nor a GPU.
    """
    on_cpu = queue is None or queue.sycl_device.is_cpu
    on_gpu = queue is not None and queue.sycl_device.is_gpu

    if on_cpu:
        status = obj._onedal_cpu_supported(method_name, *data)
        if status.get_status():
            return "onedal", queue, status
        return "sklearn", None, status

    _, d4p_options = _get_device_info_from_daal4py()
    allow_fallback_to_host = get_config()["allow_fallback_to_host"] or d4p_options.get(
        "host_offload_on_fail", False
    )

    if not on_gpu:
        raise RuntimeError("Device support is not implemented")

    status = obj._onedal_gpu_supported(method_name, *data)
    if status.get_status():
        return "onedal", queue, status
    if not allow_fallback_to_host:
        return "sklearn", None, status

    # GPU path unsupported: retry the check for the host (CPU) oneDAL path.
    status = obj._onedal_cpu_supported(method_name, *data)
    chosen = "onedal" if status.get_status() else "sklearn"
    return chosen, None, status
177
+
178
+
179
def dispatch(obj, method_name, branches, *args, **kwargs):
    """Run ``method_name`` through the oneDAL or scikit-learn branch.

    Inputs are first moved to host memory; ``_get_backend`` then picks a
    branch, the patching status is logged, and ``branches[backend]`` is called
    with ``obj`` plus the host-side arguments. The oneDAL branch also receives
    ``queue=``.

    Raises
    ------
    RuntimeError
        If the selected backend name is neither ``"onedal"`` nor ``"sklearn"``.
    """
    queue = _get_global_queue()
    queue, host_args = _transfer_to_host(queue, *args)
    queue, host_values = _transfer_to_host(queue, *kwargs.values())
    host_kwargs = dict(zip(kwargs, host_values))

    backend, queue, patching_status = _get_backend(obj, queue, method_name, *host_args)

    if backend == "onedal":
        patching_status.write_log(queue=queue)
        return branches[backend](obj, *host_args, **host_kwargs, queue=queue)
    if backend == "sklearn":
        patching_status.write_log()
        return branches[backend](obj, *host_args, **host_kwargs)
    raise RuntimeError(
        f"Undefined backend {backend} in {obj.__class__.__name__}.{method_name}"
    )
196
+
197
+
198
def _copy_to_usm(queue, array):
    """Copy host ``array`` into a new ``usm_ndarray`` bound to ``queue``.

    The bytes of ``array.tobytes()`` are written into a ``MemoryUSMDevice``
    allocation of ``array.nbytes`` bytes, which then backs a ``usm_ndarray``
    with the same shape and dtype.

    Raises
    ------
    RuntimeError
        If dpctl is not installed.
    """
    if not dpctl_available:
        raise RuntimeError(
            "dpctl need to be installed to work with __sycl_usm_array_interface__"
        )
    usm_buffer = MemoryUSMDevice(array.nbytes, queue=queue)
    usm_buffer.copy_from_host(array.tobytes())
    return usm_ndarray(array.shape, array.dtype, buffer=usm_buffer)
206
+
207
+
208
def wrap_output_data(func):
    """Decorate a method so its result follows the memory of its first input.

    The "first input" is the first positional argument after ``self`` or,
    if there are none, the first keyword value. When it exposes
    ``__sycl_usm_array_interface__``, the result is copied to a USM array on
    that input's queue. If the input was also a ``dpnp.ndarray`` (and dpnp is
    importable), the result is additionally wrapped as ``dpnp.array``.
    Otherwise the result is returned untouched.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        inputs = (*args, *kwargs.values())
        first = inputs[0] if inputs else None
        usm_iface = getattr(first, "__sycl_usm_array_interface__", None)

        result = func(self, *args, **kwargs)
        if usm_iface is None:
            return result

        result = _copy_to_usm(usm_iface["syclobj"], result)
        if dpnp_available and isinstance(first, dpnp.ndarray):
            result = dpnp.array(result, copy=False)
        return result

    return wrapper
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ import logging
19
+ import os
20
+ import sys
21
+ import warnings
22
+
23
+ from daal4py.sklearn._utils import (
24
+ PatchingConditionsChain as daal4py_PatchingConditionsChain,
25
+ )
26
+ from daal4py.sklearn._utils import daal_check_version
27
+
28
+
29
class PatchingConditionsChain(daal4py_PatchingConditionsChain):
    """sklearnex variant of daal4py's patching-conditions chain.

    Adds ``get_status`` and a ``write_log`` whose success message names the
    device in use (via ``get_patch_message``).
    """

    def get_status(self):
        """Return ``self.patching_is_enabled``."""
        return self.patching_is_enabled

    def write_log(self, queue=None):
        """Log the patching outcome on ``self.logger``.

        On success, one INFO line reports the accelerated run (with the device
        taken from ``queue``). On failure, DEBUG lines list every recorded
        message, followed by an INFO line about the scikit-learn fallback.
        """
        if not self.patching_is_enabled:
            self.logger.debug(
                f"{self.scope_name}: debugging for the patch is enabled to track"
                " the usage of Intel® oneAPI Data Analytics Library (oneDAL)"
            )
            for reason in self.messages:
                self.logger.debug(
                    f"{self.scope_name}: patching failed with cause - {reason}"
                )
            self.logger.info(f"{self.scope_name}: {get_patch_message('sklearn')}")
            return

        self.logger.info(
            f"{self.scope_name}: {get_patch_message('onedal', queue=queue)}"
        )
48
+
49
+
50
def set_sklearn_ex_verbose():
    """Configure the ``sklearnex`` logger from ``SKLEARNEX_VERBOSE``.

    A ``StreamHandler`` with the format ``LEVEL:name: message`` is attached to
    the ``sklearnex`` logger. Repeated calls (e.g. after a module reload) do
    not attach a second handler, so records are not printed twice.

    If ``SKLEARNEX_VERBOSE`` is set, it is used as the logger level. Names
    are case-insensitive, e.g. ``INFO`` or ``debug``. An unrecognised value
    emits a ``UserWarning`` instead of raising. If the variable is unset, the
    logger level is left unchanged.
    """
    log_level = os.environ.get("SKLEARNEX_VERBOSE")

    logger = logging.getLogger("sklearnex")
    handler_name = "sklearnex"
    if not any(getattr(h, "name", None) == handler_name for h in logger.handlers):
        logging_channel = logging.StreamHandler()
        logging_channel.name = handler_name
        logging_channel.setFormatter(
            logging.Formatter("%(levelname)s:%(name)s: %(message)s")
        )
        logger.addHandler(logging_channel)

    if log_level is None:
        return
    try:
        # Logger.setLevel raises ValueError for an unknown level name.
        logger.setLevel(log_level.upper())
    except ValueError:
        warnings.warn(
            'Unknown level "{}" for logging.\n'
            'Please, use one of "CRITICAL", "ERROR", '
            '"WARNING", "INFO", "DEBUG".'.format(log_level)
        )
68
+
69
+
70
def get_patch_message(s, queue=None):
    """Build the human-readable patching log message for outcome ``s``.

    Parameters
    ----------
    s : str
        One of ``"onedal"``, ``"sklearn"`` or ``"sklearn_after_onedal"``.
    queue : object or None
        Only used for ``"onedal"``: its ``sycl_device.is_gpu`` /
        ``sycl_device.is_cpu`` flags select the device name. ``None`` means CPU.

    Raises
    ------
    RuntimeError
        For ``"onedal"`` with a queue whose device is neither GPU nor CPU.
    ValueError
        For any other ``s``.
    """
    if s == "onedal":
        if queue is None:
            device_name = "CPU"
        elif queue.sycl_device.is_gpu:
            device_name = "GPU"
        elif queue.sycl_device.is_cpu:
            device_name = "CPU"
        else:
            raise RuntimeError("Unsupported device")
        return "running accelerated version on " + device_name
    if s == "sklearn":
        return "fallback to original Scikit-learn"
    if s == "sklearn_after_onedal":
        return "failed to run accelerated version, fallback to original Scikit-learn"
    raise ValueError(
        f"Invalid input - expected one of 'onedal','sklearn',"
        f" 'sklearn_after_onedal', got {s}"
    )
92
+
93
+
94
def get_sklearnex_version(rule):
    """Delegate to daal4py's ``daal_check_version`` with ``rule``.

    NOTE(review): despite the name, this returns whatever
    ``daal_check_version(rule)`` returns (a version check result), not a
    version string. Confirm against daal4py.
    """
    outcome = daal_check_version(rule)
    return outcome
@@ -0,0 +1,20 @@
1
+ #!/usr/bin/env python
2
+ # ==============================================================================
3
+ # Copyright 2023 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ==============================================================================
17
+
18
+ from .basic_statistics import BasicStatistics
19
+
20
# Public API of the basic_statistics subpackage.
__all__ = [
    "BasicStatistics",
]
@@ -0,0 +1,17 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from onedal.basic_statistics import BasicStatistics
@@ -0,0 +1,21 @@
1
+ #!/usr/bin/env python
2
+ # ===============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ===============================================================================
17
+
18
+ from .dbscan import DBSCAN
19
+ from .k_means import KMeans
20
+
21
# Public API of the cluster subpackage.
__all__ = [
    "KMeans",
    "DBSCAN",
]