scikit-learn-intelex 2025.0.1__py311-none-manylinux_2_28_x86_64.whl → 2025.2.0__py311-none-manylinux_2_28_x86_64.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Files changed (140)
  1. daal4py/_daal4py.cpython-311-x86_64-linux-gnu.so +0 -0
  2. daal4py/mpi_transceiver.cpython-311-x86_64-linux-gnu.so +0 -0
  3. daal4py/sklearn/_n_jobs_support.py +21 -15
  4. daal4py/sklearn/_utils.py +11 -7
  5. daal4py/sklearn/ensemble/AdaBoostClassifier.py +9 -5
  6. daal4py/sklearn/ensemble/GBTDAAL.py +35 -16
  7. daal4py/sklearn/linear_model/tests/test_linear.py +12 -0
  8. daal4py/sklearn/metrics/_pairwise.py +91 -10
  9. daal4py/sklearn/monkeypatch/tests/test_patching.py +4 -1
  10. daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +3 -4
  11. daal4py/sklearn/utils/validation.py +6 -3
  12. onedal/_config.py +1 -0
  13. onedal/_device_offload.py +15 -40
  14. onedal/_onedal_py_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
  15. onedal/_onedal_py_host.cpython-311-x86_64-linux-gnu.so +0 -0
  16. onedal/_onedal_py_spmd_dpc.cpython-311-x86_64-linux-gnu.so +0 -0
  17. onedal/basic_statistics/basic_statistics.py +5 -5
  18. onedal/basic_statistics/incremental_basic_statistics.py +34 -19
  19. onedal/basic_statistics/tests/test_basic_statistics.py +16 -72
  20. onedal/basic_statistics/tests/test_incremental_basic_statistics.py +100 -17
  21. onedal/basic_statistics/tests/utils.py +50 -0
  22. onedal/cluster/dbscan.py +5 -10
  23. onedal/cluster/kmeans.py +16 -19
  24. onedal/cluster/kmeans_init.py +7 -10
  25. onedal/common/_policy.py +0 -4
  26. onedal/common/hyperparameters.py +22 -13
  27. onedal/common/tests/test_policy.py +4 -3
  28. onedal/common/tests/test_sycl.py +128 -0
  29. onedal/covariance/covariance.py +6 -9
  30. onedal/covariance/incremental_covariance.py +41 -26
  31. onedal/covariance/tests/test_incremental_covariance.py +69 -1
  32. onedal/datatypes/__init__.py +2 -2
  33. onedal/datatypes/_data_conversion.py +76 -50
  34. onedal/datatypes/tests/common.py +126 -0
  35. onedal/datatypes/tests/test_data.py +314 -74
  36. onedal/decomposition/incremental_pca.py +42 -32
  37. onedal/decomposition/pca.py +7 -7
  38. onedal/decomposition/tests/test_incremental_pca.py +87 -0
  39. onedal/ensemble/forest.py +30 -14
  40. onedal/linear_model/incremental_linear_model.py +86 -52
  41. onedal/linear_model/linear_model.py +19 -23
  42. onedal/linear_model/logistic_regression.py +9 -11
  43. onedal/linear_model/tests/test_incremental_linear_regression.py +72 -27
  44. onedal/linear_model/tests/test_incremental_ridge_regression.py +64 -0
  45. onedal/linear_model/tests/test_linear_regression.py +110 -0
  46. onedal/neighbors/neighbors.py +55 -70
  47. onedal/primitives/kernel_functions.py +3 -4
  48. onedal/spmd/basic_statistics/incremental_basic_statistics.py +7 -5
  49. onedal/spmd/covariance/incremental_covariance.py +6 -5
  50. onedal/spmd/decomposition/incremental_pca.py +14 -7
  51. onedal/spmd/linear_model/incremental_linear_model.py +12 -8
  52. onedal/svm/svm.py +10 -10
  53. onedal/svm/tests/test_svc.py +8 -0
  54. onedal/tests/test_common.py +25 -9
  55. onedal/tests/utils/_dataframes_support.py +4 -10
  56. onedal/tests/utils/_device_selection.py +19 -24
  57. onedal/utils/_array_api.py +12 -22
  58. onedal/utils/_dpep_helpers.py +56 -0
  59. onedal/utils/tests/test_validation.py +142 -0
  60. onedal/utils/validation.py +52 -20
  61. {scikit_learn_intelex-2025.0.1.dist-info → scikit_learn_intelex-2025.2.0.dist-info}/METADATA +2 -2
  62. {scikit_learn_intelex-2025.0.1.dist-info → scikit_learn_intelex-2025.2.0.dist-info}/RECORD +136 -132
  63. sklearnex/__init__.py +1 -0
  64. sklearnex/_config.py +19 -1
  65. sklearnex/_device_offload.py +17 -12
  66. sklearnex/_utils.py +45 -11
  67. sklearnex/basic_statistics/basic_statistics.py +123 -27
  68. sklearnex/basic_statistics/incremental_basic_statistics.py +65 -34
  69. sklearnex/basic_statistics/tests/test_basic_statistics.py +190 -36
  70. sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +99 -28
  71. sklearnex/cluster/dbscan.py +12 -7
  72. sklearnex/cluster/k_means.py +56 -42
  73. sklearnex/cluster/tests/test_kmeans.py +15 -11
  74. sklearnex/conftest.py +9 -0
  75. sklearnex/covariance/incremental_covariance.py +32 -13
  76. sklearnex/covariance/tests/test_incremental_covariance.py +61 -0
  77. sklearnex/decomposition/pca.py +30 -19
  78. sklearnex/dispatcher.py +1 -10
  79. sklearnex/ensemble/_forest.py +72 -59
  80. sklearnex/ensemble/tests/test_forest.py +40 -20
  81. sklearnex/linear_model/incremental_linear.py +52 -40
  82. sklearnex/linear_model/incremental_ridge.py +18 -4
  83. sklearnex/linear_model/linear.py +114 -75
  84. sklearnex/linear_model/logistic_regression.py +49 -39
  85. sklearnex/linear_model/ridge.py +374 -8
  86. sklearnex/linear_model/tests/test_incremental_linear.py +70 -6
  87. sklearnex/linear_model/tests/test_incremental_ridge.py +61 -0
  88. sklearnex/linear_model/tests/test_linear.py +41 -41
  89. sklearnex/linear_model/tests/test_ridge.py +256 -0
  90. sklearnex/manifold/tests/test_tsne.py +226 -2
  91. sklearnex/neighbors/_lof.py +16 -11
  92. sklearnex/neighbors/common.py +4 -4
  93. sklearnex/neighbors/knn_classification.py +20 -15
  94. sklearnex/neighbors/knn_regression.py +18 -14
  95. sklearnex/neighbors/knn_unsupervised.py +22 -14
  96. sklearnex/neighbors/tests/test_neighbors.py +4 -2
  97. sklearnex/preview/__init__.py +1 -1
  98. sklearnex/preview/covariance/covariance.py +18 -13
  99. sklearnex/preview/covariance/tests/test_covariance.py +1 -1
  100. sklearnex/preview/decomposition/incremental_pca.py +30 -14
  101. sklearnex/preview/decomposition/tests/test_incremental_pca.py +70 -0
  102. sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +4 -4
  103. sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +7 -8
  104. sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +1 -1
  105. sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +4 -3
  106. sklearnex/spmd/covariance/tests/test_covariance_spmd.py +1 -1
  107. sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +1 -1
  108. sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +1 -1
  109. sklearnex/spmd/decomposition/tests/test_pca_spmd.py +1 -1
  110. sklearnex/spmd/ensemble/tests/test_forest_spmd.py +1 -1
  111. sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +17 -15
  112. sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +1 -1
  113. sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +1 -5
  114. sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +1 -1
  115. sklearnex/svm/_common.py +58 -47
  116. sklearnex/svm/nusvc.py +68 -29
  117. sklearnex/svm/nusvr.py +40 -18
  118. sklearnex/svm/svc.py +66 -27
  119. sklearnex/svm/svr.py +36 -18
  120. sklearnex/tests/test_common.py +451 -14
  121. sklearnex/tests/test_config.py +87 -7
  122. sklearnex/tests/test_hyperparameters.py +43 -0
  123. sklearnex/tests/test_memory_usage.py +69 -13
  124. sklearnex/tests/test_monkeypatch.py +4 -11
  125. sklearnex/tests/test_n_jobs_support.py +75 -70
  126. sklearnex/tests/test_patching.py +1 -9
  127. sklearnex/tests/test_run_to_run_stability.py +43 -13
  128. sklearnex/{preview/linear_model → tests/utils}/__init__.py +33 -4
  129. sklearnex/tests/{_utils.py → utils/base.py} +117 -9
  130. sklearnex/utils/__init__.py +2 -2
  131. sklearnex/utils/tests/test_validation.py +238 -0
  132. sklearnex/utils/validation.py +192 -1
  133. sklearnex/linear_model/logistic_path.py +0 -17
  134. sklearnex/preview/linear_model/ridge.py +0 -419
  135. sklearnex/preview/linear_model/tests/test_ridge.py +0 -102
  136. sklearnex/utils/tests/test_finite.py +0 -89
  137. {scikit_learn_intelex-2025.0.1.dist-info → scikit_learn_intelex-2025.2.0.dist-info}/LICENSE.txt +0 -0
  138. {scikit_learn_intelex-2025.0.1.dist-info → scikit_learn_intelex-2025.2.0.dist-info}/WHEEL +0 -0
  139. {scikit_learn_intelex-2025.0.1.dist-info → scikit_learn_intelex-2025.2.0.dist-info}/top_level.txt +0 -0
  140. /sklearnex/tests/{_utils_spmd.py → utils/spmd.py} +0 -0
daal4py/sklearn/_n_jobs_support.py CHANGED
@@ -15,6 +15,7 @@
  # ==============================================================================

  import logging
+ import sys
  import threading
  from functools import wraps
  from inspect import Parameter, signature
@@ -76,7 +77,7 @@ def _run_with_n_jobs(method):
  """

  @wraps(method)
- def method_wrapper(self, *args, **kwargs):
+ def n_jobs_wrapper(self, *args, **kwargs):
  # threading parallel backend branch
  if not isinstance(threading.current_thread(), threading._MainThread):
  warn(
@@ -117,7 +118,10 @@ def _run_with_n_jobs(method):
  n_jobs = max(1, n_threads + n_jobs + 1)
  # branch with set n_jobs
  old_n_threads = get_n_threads()
- if n_jobs != old_n_threads:
+ if n_jobs == old_n_threads:
+ return method(self, *args, **kwargs)
+
+ try:
  logger = logging.getLogger("sklearnex")
  cl = self.__class__
  logger.debug(
@@ -125,12 +129,11 @@ def _run_with_n_jobs(method):
  f"setting {n_jobs} threads (previous - {old_n_threads})"
  )
  set_n_threads(n_jobs)
- result = method(self, *args, **kwargs)
- if n_jobs != old_n_threads:
+ return method(self, *args, **kwargs)
+ finally:
  set_n_threads(old_n_threads)
- return result

- return method_wrapper
+ return n_jobs_wrapper


  def control_n_jobs(decorated_methods: list = []):
@@ -149,7 +152,8 @@ def control_n_jobs(decorated_methods: list = []):

  Parameters
  ----------
- decorated_methods (list): A list of method names to be executed with 'n_jobs'.
+ decorated_methods: list
+ A list of method names to be executed with 'n_jobs'.

  Example
  -------
@@ -209,14 +213,16 @@ def control_n_jobs(decorated_methods: list = []):
  and isinstance(original_class.__doc__, str)
  and "n_jobs : int" not in original_class.__doc__
  ):
- parameters_doc_tail = "\n Attributes"
- n_jobs_doc = """
- n_jobs : int, default=None
- The number of jobs to use in parallel for the computation.
- ``None`` means using all physical cores
- unless in a :obj:`joblib.parallel_backend` context.
- ``-1`` means using all logical cores.
- See :term:`Glossary <n_jobs>` for more details.
+ # Python 3.13 removed extra tab in class doc string
+ tab = " " if sys.version_info.minor < 13 else ""
+ parameters_doc_tail = f"\n{tab}Attributes"
+ n_jobs_doc = f"""
+ {tab}n_jobs : int, default=None
+ {tab} The number of jobs to use in parallel for the computation.
+ {tab} ``None`` means using all physical cores
+ {tab} unless in a :obj:`joblib.parallel_backend` context.
+ {tab} ``-1`` means using all logical cores.
+ {tab} See :term:`Glossary <n_jobs>` for more details.
  """
  original_class.__doc__ = original_class.__doc__.replace(
  parameters_doc_tail, n_jobs_doc + parameters_doc_tail
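
Note: the change above restructures the wrapper so that restoring the previous thread count no longer depends on a repeated `if n_jobs != old_n_threads` check. A minimal standalone sketch of the same try/finally pattern (`get_n_threads`/`set_n_threads` stand in for daal4py's threading controls; the names here are illustrative, not the package API):

from functools import wraps

def run_with_n_threads(n_threads, get_n_threads, set_n_threads):
    """Return a decorator that runs a function under a temporary thread count."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            old = get_n_threads()
            if n_threads == old:
                return func(*args, **kwargs)  # nothing to change
            set_n_threads(n_threads)
            try:
                return func(*args, **kwargs)
            finally:
                set_n_threads(old)  # restored on success and on exception alike
        return wrapper
    return decorator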
daal4py/sklearn/_utils.py CHANGED
@@ -18,7 +18,7 @@ import functools
  import os
  import sys
  import warnings
- from typing import Any, Callable, Tuple
+ from typing import Any, Tuple

  import numpy as np
  from numpy.lib.recfunctions import require_fields
@@ -95,17 +95,21 @@ def daal_check_version(
  return False


- @functools.lru_cache(maxsize=256, typed=False)
- def sklearn_check_version(ver):
- if hasattr(Version(ver), "base_version"):
- base_sklearn_version = Version(sklearn_version).base_version
- res = bool(Version(base_sklearn_version) >= Version(ver))
+ def _package_check_version(version_to_check, available_version):
+ if hasattr(Version(version_to_check), "base_version"):
+ base_package_version = Version(available_version).base_version
+ res = bool(Version(base_package_version) >= Version(version_to_check))
  else:
  # packaging module not available
- res = bool(Version(sklearn_version) >= Version(ver))
+ res = bool(Version(available_version) >= Version(version_to_check))
  return res


+ @functools.lru_cache(maxsize=256, typed=False)
+ def sklearn_check_version(ver):
+ return _package_check_version(ver, sklearn_version)
+
+
  def parse_dtype(dt):
  if dt == np.double:
  return "double"
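
Note: the refactor above extracts the sklearn-specific check into a reusable `_package_check_version`, comparing against `base_version` so pre-release suffixes do not break comparisons. A condensed sketch of the comparison, assuming the `packaging` module is installed (the real helper also handles its absence):

from packaging.version import Version

def package_check_version(version_to_check, available_version):
    # comparing base_version ignores pre/post/dev suffixes, so "1.6.0rc1"
    # still satisfies a ">= 1.3" check
    base = Version(available_version).base_version
    return Version(base) >= Version(version_to_check)

assert package_check_version("1.3", "1.6.0rc1")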
daal4py/sklearn/ensemble/AdaBoostClassifier.py CHANGED
@@ -25,13 +25,19 @@ from sklearn.utils.multiclass import check_classification_targets
  from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

  import daal4py as d4p
+ from daal4py.sklearn._utils import sklearn_check_version

  from .._n_jobs_support import control_n_jobs
  from .._utils import getFPType

+ if sklearn_check_version("1.6"):
+ from sklearn.utils.validation import validate_data
+ else:
+ validate_data = BaseEstimator._validate_data
+

  @control_n_jobs(decorated_methods=["fit", "predict"])
- class AdaBoostClassifier(BaseEstimator, ClassifierMixin):
+ class AdaBoostClassifier(ClassifierMixin, BaseEstimator):
  def __init__(
  self,
  split_criterion="gini",
@@ -89,7 +95,7 @@ class AdaBoostClassifier(BaseEstimator, ClassifierMixin):
  )

  # Check that X and y have correct shape
- X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double])
+ X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32])

  check_classification_targets(y)

@@ -151,9 +157,7 @@ class AdaBoostClassifier(BaseEstimator, ClassifierMixin):
  check_is_fitted(self)

  # Input validation
- X = check_array(X, dtype=[np.single, np.double])
- if X.shape[1] != self.n_features_in_:
- raise ValueError("Shape of input is different from what was seen in `fit`")
+ X = validate_data(self, X, dtype=[np.float64, np.float32], reset=False)

  # Trivial case
  if self.n_classes_ == 1:
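
Note: this file and GBTDAAL.py below pick up the same compatibility shim. scikit-learn 1.6 exposes `validate_data` as a free function, while older versions only provide the `BaseEstimator._validate_data` method, so one name is resolved at import time:

from sklearn.base import BaseEstimator
from daal4py.sklearn._utils import sklearn_check_version

if sklearn_check_version("1.6"):
    from sklearn.utils.validation import validate_data  # free function in 1.6+
else:
    validate_data = BaseEstimator._validate_data  # unbound method pre-1.6

# callers pass the estimator explicitly either way, e.g.:
# X = validate_data(self, X, dtype=[np.float64, np.float32], reset=False)

With `reset=False`, `validate_data` also verifies the input against `n_features_in_`, which is why the manual shape check and `ValueError` above could be dropped.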
daal4py/sklearn/ensemble/GBTDAAL.py CHANGED
@@ -26,10 +26,16 @@ from sklearn.utils.multiclass import check_classification_targets
  from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

  import daal4py as d4p
+ from daal4py.sklearn._utils import sklearn_check_version

  from .._n_jobs_support import control_n_jobs
  from .._utils import getFPType

+ if sklearn_check_version("1.6"):
+ from sklearn.utils.validation import validate_data
+ else:
+ validate_data = BaseEstimator._validate_data
+

  class GBTDAALBase(BaseEstimator, d4p.mb.GBTDAALBaseModel):
  def __init__(
@@ -128,15 +134,22 @@ class GBTDAALBase(BaseEstimator, d4p.mb.GBTDAALBaseModel):
  def _more_tags(self):
  return {"allow_nan": self.allow_nan_}

+ if sklearn_check_version("1.6"):
+
+ def __sklearn_tags__(self):
+ tags = super().__sklearn_tags__()
+ tags.input_tags.allow_nan = self.allow_nan_
+ return tags
+

  @control_n_jobs(decorated_methods=["fit", "predict"])
- class GBTDAALClassifier(GBTDAALBase, ClassifierMixin):
+ class GBTDAALClassifier(ClassifierMixin, GBTDAALBase):
  def fit(self, X, y):
  # Check the algorithm parameters
  self._check_params()

  # Check that X and y have correct shape
- X, y = check_X_y(X, y, y_numeric=False, dtype=[np.single, np.double])
+ X, y = check_X_y(X, y, y_numeric=False, dtype=[np.float64, np.float32])

  check_classification_targets(y)

@@ -196,15 +209,18 @@ class GBTDAALClassifier(GBTDAALBase, ClassifierMixin):
  def _predict(
  self, X, resultsToEvaluate, pred_contribs=False, pred_interactions=False
  ):
- # Input validation
- if not self.allow_nan_:
- X = check_array(X, dtype=[np.single, np.double])
- else:
- X = check_array(X, dtype=[np.single, np.double], force_all_finite="allow-nan")
-
  # Check is fit had been called
  check_is_fitted(self, ["n_features_in_", "n_classes_"])

+ # Input validation
+ X = validate_data(
+ self,
+ X,
+ dtype=[np.float64, np.float32],
+ force_all_finite="allow-nan" if self.allow_nan_ else True,
+ reset=False,
+ )
+
  # Trivial case
  if self.n_classes_ == 1:
  return np.full(X.shape[0], self.classes_[0])
@@ -251,13 +267,13 @@ class GBTDAALClassifier(GBTDAALBase, ClassifierMixin):


  @control_n_jobs(decorated_methods=["fit", "predict"])
- class GBTDAALRegressor(GBTDAALBase, RegressorMixin):
+ class GBTDAALRegressor(RegressorMixin, GBTDAALBase):
  def fit(self, X, y):
  # Check the algorithm parameters
  self._check_params()

  # Check that X and y have correct shape
- X, y = check_X_y(X, y, y_numeric=True, dtype=[np.single, np.double])
+ X, y = check_X_y(X, y, y_numeric=True, dtype=[np.float64, np.float32])

  # Convert to 2d array
  y_ = y.reshape((-1, 1))
@@ -297,15 +313,18 @@ class GBTDAALRegressor(GBTDAALBase, RegressorMixin):
  return self

  def predict(self, X, pred_contribs=False, pred_interactions=False):
- # Input validation
- if not self.allow_nan_:
- X = check_array(X, dtype=[np.single, np.double])
- else:
- X = check_array(X, dtype=[np.single, np.double], force_all_finite="allow-nan")
-
  # Check is fit had been called
  check_is_fitted(self, ["n_features_in_"])

+ # Input validation
+ X = validate_data(
+ self,
+ X,
+ dtype=[np.float64, np.float32],
+ force_all_finite="allow-nan" if self.allow_nan_ else True,
+ reset=False,
+ )
+
  fptype = getFPType(X)
  return self._predict_regression(X, fptype, pred_contribs, pred_interactions)

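Note: `GBTDAALBase` now implements both tag protocols, since scikit-learn 1.6 replaced the `_more_tags` dict protocol with dataclass-based `__sklearn_tags__`. A condensed sketch of the pattern, assuming an estimator class whose bases ultimately provide `__sklearn_tags__`:

from daal4py.sklearn._utils import sklearn_check_version

class NaNTagsMixin:
    allow_nan_ = True

    def _more_tags(self):
        # dict-based protocol, used by scikit-learn < 1.6
        return {"allow_nan": self.allow_nan_}

    if sklearn_check_version("1.6"):

        def __sklearn_tags__(self):
            # dataclass-based protocol introduced in scikit-learn 1.6
            tags = super().__sklearn_tags__()
            tags.input_tags.allow_nan = self.allow_nan_
            return tags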
daal4py/sklearn/linear_model/tests/test_linear.py CHANGED
@@ -14,6 +14,18 @@
  # limitations under the License.
  # ==============================================================================

+
+ from os import environ
+
+ from daal4py.sklearn._utils import sklearn_check_version
+
+ # sklearn requires manual enabling of Scipy array API support
+ # if `array-api-compat` package is present in environment
+ # TODO: create generic approach to handle this for all tests
+ if sklearn_check_version("1.6"):
+ environ["SCIPY_ARRAY_API"] = "1"
+
+
  import numpy as np
  import pytest
  from sklearn.datasets import make_regression
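
Note: `SCIPY_ARRAY_API` is read by SciPy when it is first imported, which is why the assignment sits above the scientific imports in the module. A minimal reproduction of the ordering constraint (assuming SciPy is installed):

from os import environ

environ["SCIPY_ARRAY_API"] = "1"  # must be set before SciPy is imported

import scipy  # noqa: E402  (deliberately imported after the env assignment)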
daal4py/sklearn/metrics/_pairwise.py CHANGED
@@ -48,7 +48,12 @@ from daal4py.sklearn.utils.validation import _daal_check_array
  from .._utils import PatchingConditionsChain, getFPType, sklearn_check_version

  if sklearn_check_version("1.3"):
- from sklearn.utils._param_validation import Integral, StrOptions, validate_params
+ from sklearn.utils._param_validation import (
+ Hidden,
+ Integral,
+ StrOptions,
+ validate_params,
+ )


  def _daal4py_cosine_distance_dense(X):
@@ -65,7 +70,7 @@ def _daal4py_correlation_distance_dense(X):
  return res.correlationDistance


- def pairwise_distances(
+ def _pairwise_distances(
  X, Y=None, metric="euclidean", *, n_jobs=None, force_all_finite=True, **kwds
  ):
  if metric not in _VALID_METRICS and not callable(metric) and metric != "precomputed":
@@ -140,16 +145,92 @@
  return _parallel_pairwise(X, Y, func, n_jobs, **kwds)


+ # logic to deprecate `force_all_finite` from sklearn:
+ # it was renamed to `ensure_all_finite` since 1.6 and will be removed in 1.8
  if sklearn_check_version("1.3"):
+ pairwise_distances_parameters = {
+ "X": ["array-like", "sparse matrix"],
+ "Y": ["array-like", "sparse matrix", None],
+ "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
+ "n_jobs": [Integral, None],
+ "force_all_finite": [
+ "boolean",
+ StrOptions({"allow-nan"}),
+ Hidden(StrOptions({"deprecated"})),
+ ],
+ "ensure_all_finite": [
+ "boolean",
+ StrOptions({"allow-nan"}),
+ Hidden(None),
+ ],
+ }
+ if sklearn_check_version("1.6"):
+ if sklearn_check_version("1.8"):
+ del pairwise_distances_parameters["force_all_finite"]
+
+ def pairwise_distances(
+ X,
+ Y=None,
+ metric="euclidean",
+ *,
+ n_jobs=None,
+ ensure_all_finite=None,
+ **kwds,
+ ):
+ return _pairwise_distances(
+ X,
+ Y,
+ metric,
+ n_jobs=n_jobs,
+ force_all_finite=ensure_all_finite,
+ **kwds,
+ )
+
+ else:
+ from sklearn.utils.deprecation import _deprecate_force_all_finite
+
+ def pairwise_distances(
+ X,
+ Y=None,
+ metric="euclidean",
+ *,
+ n_jobs=None,
+ force_all_finite="deprecated",
+ ensure_all_finite=None,
+ **kwds,
+ ):
+ force_all_finite = _deprecate_force_all_finite(
+ force_all_finite, ensure_all_finite
+ )
+ return _pairwise_distances(
+ X, Y, metric, n_jobs=n_jobs, force_all_finite=force_all_finite, **kwds
+ )
+
+ else:
+ del pairwise_distances_parameters["ensure_all_finite"]
+
+ def pairwise_distances(
+ X,
+ Y=None,
+ metric="euclidean",
+ *,
+ n_jobs=None,
+ force_all_finite=True,
+ **kwds,
+ ):
+ return _pairwise_distances(
+ X,
+ Y,
+ metric,
+ n_jobs=n_jobs,
+ force_all_finite=force_all_finite,
+ **kwds,
+ )
+
  pairwise_distances = validate_params(
- {
- "X": ["array-like", "sparse matrix"],
- "Y": ["array-like", "sparse matrix", None],
- "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
- "n_jobs": [Integral, None],
- "force_all_finite": ["boolean", StrOptions({"allow-nan"})],
- },
+ pairwise_distances_parameters,
  prefer_skip_nested_validation=True,
  )(pairwise_distances)
-
+ else:
+ pairwise_distances = _pairwise_distances
  pairwise_distances.__doc__ = pairwise_distances_original.__doc__
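
Note: the block above keeps `pairwise_distances` signature-compatible across three sklearn ranges. The deprecation handling it defers to (`_deprecate_force_all_finite`) boils down to resolution logic along these lines; a hedged standalone sketch, not the sklearn implementation:

import warnings

def resolve_ensure_all_finite(force_all_finite="deprecated", ensure_all_finite=None):
    """Map the deprecated `force_all_finite` onto `ensure_all_finite`."""
    if force_all_finite != "deprecated":
        if ensure_all_finite is not None:
            raise ValueError(
                "cannot pass both force_all_finite and ensure_all_finite"
            )
        warnings.warn(
            "'force_all_finite' was renamed to 'ensure_all_finite' and will "
            "be removed in a future release",
            FutureWarning,
        )
        return force_all_finite
    # default when neither flag is given explicitly
    return True if ensure_all_finite is None else ensure_all_finite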
daal4py/sklearn/monkeypatch/tests/test_patching.py CHANGED
@@ -55,7 +55,10 @@ def get_result_log():
  absolute_path = str(pathlib.Path(__file__).parent.absolute())
  try:
  process = subprocess.check_output(
- [sys.executable, absolute_path + "/utils/_launch_algorithms.py"]
+ [
+ sys.executable,
+ os.sep.join([absolute_path, "utils", "_launch_algorithms.py"]),
+ ]
  )
  except subprocess.CalledProcessError as e:
  print(e)
daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py CHANGED
@@ -29,8 +29,7 @@ import sys
  from sklearn.datasets import load_diabetes, load_iris, make_regression
  from sklearn.metrics import pairwise_distances, roc_auc_score

- absolute_path = str(pathlib.Path(__file__).parent.absolute())
- sys.path.append(absolute_path + "/../")
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.absolute()))
  from _models_info import MODELS_INFO, TYPES


@@ -84,7 +83,7 @@ def run_patch(model_info, dtype):
  logging.info(i)


- def run_algotithms():
+ def run_algorithms():
  for info in MODELS_INFO:
  for t in TYPES:
  model_name = get_class_name(info["model"])
@@ -114,5 +113,5 @@ def run_utils():


  if __name__ == "__main__":
- run_algotithms()
+ run_algorithms()
  run_utils()
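
Note: both path fixes above avoid hard-coded "/" separators. The same result can also be had with pathlib alone; a small illustrative sketch (run from a script so `__file__` is defined):

import pathlib

here = pathlib.Path(__file__).parent
script = here / "utils" / "_launch_algorithms.py"  # portable join
package_root = here.parent.absolute()              # replaces "... + '/../'"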
daal4py/sklearn/utils/validation.py CHANGED
@@ -98,9 +98,12 @@ def _assert_all_finite(
  )
  _dal_ready = _patching_status.and_conditions(
  [
- (X.ndim in [1, 2], "X has not 1 or 2 dimensions."),
- (not np.any(np.equal(X.shape, 0)), "X shape contains 0."),
- (dt in [np.float32, np.float64], "X dtype is not float32 or float64."),
+ (X.ndim in [1, 2], f"Input {input_name} does not have 1 or 2 dimensions."),
+ (not np.any(np.equal(X.shape, 0)), f"Input {input_name} shape contains a 0."),
+ (
+ dt in [np.float32, np.float64],
+ f"Input {input_name} dtype is not float32 or float64.",
+ ),
  ]
  )
  _patching_status.write_log()
onedal/_config.py CHANGED
@@ -21,6 +21,7 @@ import threading
  _default_global_config = {
  "target_offload": "auto",
  "allow_fallback_to_host": False,
+ "allow_sklearn_after_onedal": True,
  }

  _threadlocal = threading.local()
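
Note: the new `allow_sklearn_after_onedal` flag joins the existing thread-local config. A sketch of how such a dict is typically read; the lazy per-thread copy mirrors the `_threadlocal` seen above, while the accessor name is illustrative rather than the package API:

import threading

_default_global_config = {
    "target_offload": "auto",
    "allow_fallback_to_host": False,
    "allow_sklearn_after_onedal": True,
}
_threadlocal = threading.local()

def get_config():
    # each thread lazily copies the defaults, so overrides made in one
    # thread never leak into another
    if not hasattr(_threadlocal, "global_config"):
        _threadlocal.global_config = _default_global_config.copy()
    return _threadlocal.global_config

assert get_config()["allow_sklearn_after_onedal"] is True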
onedal/_device_offload.py CHANGED
@@ -23,49 +23,26 @@ from sklearn import get_config

  from ._config import _get_config
  from .utils._array_api import _asarray, _is_numpy_namespace
+ from .utils._dpep_helpers import dpctl_available, dpnp_available

- try:
+ if dpctl_available:
  from dpctl import SyclQueue
  from dpctl.memory import MemoryUSMDevice, as_usm_memory
  from dpctl.tensor import usm_ndarray
+ else:
+ import onedal

- dpctl_available = True
- except ImportError:
- dpctl_available = False
+ # setting fallback to `object` will make if isinstance call
+ # in _get_global_queue always true for situations without the
+ # dpc backend when `device_offload` is used. Instead, it will
+ # fail at the policy check phase yielding a RuntimeError
+ SyclQueue = getattr(onedal._backend, "SyclQueue", object)

- try:
+ if dpnp_available:
  import dpnp

  from .utils._array_api import _convert_to_dpnp

- dpnp_available = True
- except ImportError:
- dpnp_available = False
-
-
- class DummySyclQueue:
- """This class is designed to act like dpctl.SyclQueue
- to allow device dispatching in scenarios when dpctl is not available"""
-
- class DummySyclDevice:
- def __init__(self, filter_string):
- self._filter_string = filter_string
- self.is_cpu = "cpu" in filter_string
- self.is_gpu = "gpu" in filter_string
- self.has_aspect_fp64 = self.is_cpu
-
- if not (self.is_cpu):
- logging.warning(
- "Device support is limited. "
- "Please install dpctl for full experience"
- )
-
- def get_filter_string(self):
- return self._filter_string
-
- def __init__(self, filter_string):
- self.sycl_device = self.DummySyclDevice(filter_string)
-

  def _copy_to_usm(queue, array):
  if not dpctl_available:
@@ -140,25 +117,23 @@ def _transfer_to_host(queue, *data):
  raise RuntimeError("Input data shall be located on single target device")

  host_data.append(item)
- return queue, host_data
+ return has_usm_data, queue, host_data


  def _get_global_queue():
  target = _get_config()["target_offload"]

- QueueClass = DummySyclQueue if not dpctl_available else SyclQueue
-
  if target != "auto":
- if isinstance(target, QueueClass):
+ if isinstance(target, SyclQueue):
  return target
- return QueueClass(target)
+ return SyclQueue(target)
  return None


  def _get_host_inputs(*args, **kwargs):
  q = _get_global_queue()
- q, hostargs = _transfer_to_host(q, *args)
- q, hostvalues = _transfer_to_host(q, *kwargs.values())
+ _, q, hostargs = _transfer_to_host(q, *args)
+ _, q, hostvalues = _transfer_to_host(q, *kwargs.values())
  hostkwargs = dict(zip(kwargs.keys(), hostvalues))
  return q, hostargs, hostkwargs

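Note: the `getattr(..., object)` fallback replaces the deleted `DummySyclQueue` shim. A self-contained illustration of why binding the fallback to `object` keeps the `isinstance` dispatch in `_get_global_queue` valid (`_BackendWithoutDpc` is a stand-in for `onedal._backend` built without DPC++):

class _BackendWithoutDpc:
    pass

SyclQueue = getattr(_BackendWithoutDpc, "SyclQueue", object)

# isinstance(x, object) is True for any x, so the dispatch never raises
# here; an unusable queue target instead fails later, at the policy check.
assert isinstance("gpu:0", SyclQueue)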
onedal/basic_statistics/basic_statistics.py CHANGED
@@ -20,7 +20,7 @@ from abc import ABCMeta, abstractmethod
  import numpy as np

  from ..common._base import BaseEstimator
- from ..datatypes import _convert_to_supported, from_table, to_table
+ from ..datatypes import from_table, to_table
  from ..utils import _is_csr
  from ..utils.validation import _check_array

@@ -57,7 +57,7 @@ class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
  def _get_onedal_params(self, is_csr, dtype=np.float32):
  options = self._get_result_options(self.options)
  return {
- "fptype": "float" if dtype == np.float32 else "double",
+ "fptype": dtype,
  "method": "sparse" if is_csr else self.algorithm,
  "result_option": options,
  }
@@ -81,11 +81,11 @@ class BasicStatistics(BaseBasicStatistics):
  if sample_weight is not None:
  sample_weight = _check_array(sample_weight, ensure_2d=False)

- data, sample_weight = _convert_to_supported(policy, data, sample_weight)
  is_single_dim = data.ndim == 1
- data_table, weights_table = to_table(data, sample_weight)

- dtype = data.dtype
+ data_table, weights_table = to_table(data, sample_weight, queue=queue)
+
+ dtype = data_table.dtype
  raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
  for opt, raw_value in raw_result.items():
  value = from_table(raw_value).ravel()
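
Note: two related simplifications land here: the params dict now carries the dtype object directly instead of a hand-mapped string, and `to_table(..., queue=queue)` subsumes the separate `_convert_to_supported` pre-conversion step. The removed mapping was equivalent to:

import numpy as np

def legacy_fptype(dtype):
    # pre-2025.2 logic: oneDAL parameter dicts carried a string fptype
    return "float" if dtype == np.float32 else "double"

assert legacy_fptype(np.float32) == "float"
assert legacy_fptype(np.float64) == "double"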