scikit-learn-intelex 2023.2.1__py310-none-win_amd64.whl → 2024.0.1__py310-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (109)
  1. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +2 -2
  2. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +16 -12
  3. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +2 -2
  4. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +90 -56
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +3 -3
  7. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +2 -2
  8. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +4 -4
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +2 -2
  11. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +12 -6
  12. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +5 -5
  13. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +3 -3
  14. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +2 -2
  15. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +5 -4
  16. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +102 -72
  17. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +12 -4
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  20. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +31 -16
  21. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +21 -14
  22. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +10 -10
  23. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +2 -2
  24. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +173 -83
  25. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +3 -3
  26. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +2 -2
  27. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +23 -7
  28. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +4 -3
  29. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +3 -3
  30. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +2 -2
  31. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +4 -3
  32. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +5 -5
  33. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
  34. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
  35. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +8 -6
  36. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +3 -3
  37. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +2 -2
  38. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +6 -3
  39. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +9 -5
  40. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +100 -77
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  43. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +116 -58
  44. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +118 -56
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  46. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +18 -20
  47. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +3 -3
  48. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +7 -7
  49. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +104 -73
  50. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +4 -1
  51. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +128 -100
  52. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +18 -16
  53. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd}/__init__.py +24 -22
  54. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +3 -3
  55. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +2 -2
  56. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +11 -5
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  58. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +2 -2
  59. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +3 -3
  60. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +2 -2
  61. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +3 -3
  62. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +16 -14
  63. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -3
  64. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +2 -2
  65. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +3 -3
  66. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +3 -3
  67. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +11 -8
  68. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +56 -56
  69. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +110 -55
  70. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +65 -31
  71. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +136 -78
  72. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +65 -31
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  75. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +9 -8
  76. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +63 -69
  77. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +55 -53
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  79. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +8 -7
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  81. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +39 -39
  82. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -3
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  84. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +2 -2
  85. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
  86. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  87. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/_utils.py +0 -82
  88. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -18
  89. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
  90. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
  91. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -46
  92. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -228
  93. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -213
  94. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -57
  95. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -18
  96. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -28
  97. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py +0 -1261
  98. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1155
  99. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py +0 -67
  100. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
  101. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -23
  102. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -63
  103. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -159
  104. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -383
  105. scikit_learn_intelex-2023.2.1.dist-info/RECORD +0 -95
  106. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  107. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
  108. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
  109. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
@@ -14,35 +14,78 @@
14
14
  # limitations under the License.
15
15
  # ===============================================================================
16
16
 
17
- from daal4py.sklearn._utils import daal_check_version
18
17
  import logging
18
+ from abc import ABC
19
19
 
20
- if daal_check_version((2023, 'P', 100)):
21
- import numpy as np
20
+ from daal4py.sklearn._utils import daal_check_version
21
+
22
+
23
+ def get_coef(self):
24
+ return self._coef_
25
+
26
+
27
+ def set_coef(self, value):
28
+ self._coef_ = value
29
+ if hasattr(self, "_onedal_estimator"):
30
+ self._onedal_estimator.coef_ = value
31
+ if not self._is_in_fit:
32
+ del self._onedal_estimator._onedal_model
33
+
34
+
35
+ def get_intercept(self):
36
+ return self._intercept_
22
37
 
23
- from ._common import BaseLinearRegression
24
- from ..._device_offload import dispatch, wrap_output_data
25
38
 
26
- from ...utils.validation import _assert_all_finite
27
- from daal4py.sklearn._utils import (
28
- get_dtype, make2d, sklearn_check_version, PatchingConditionsChain)
39
+ def set_intercept(self, value):
40
+ self._intercept_ = value
41
+ if hasattr(self, "_onedal_estimator"):
42
+ self._onedal_estimator.intercept_ = value
43
+ if not self._is_in_fit:
44
+ del self._onedal_estimator._onedal_model
45
+
46
+
47
+ class BaseLinearRegression(ABC):
48
+ def _save_attributes(self):
49
+ self.n_features_in_ = self._onedal_estimator.n_features_in_
50
+ self.fit_status_ = 0
51
+ self._coef_ = self._onedal_estimator.coef_
52
+ self._intercept_ = self._onedal_estimator.intercept_
53
+ self._sparse = False
54
+
55
+ self.coef_ = property(get_coef, set_coef)
56
+ self.intercept_ = property(get_intercept, set_intercept)
57
+
58
+ self._is_in_fit = True
59
+ self.coef_ = self._coef_
60
+ self.intercept_ = self._intercept_
61
+ self._is_in_fit = False
62
+
63
+
64
+ if daal_check_version((2023, "P", 100)):
65
+ import numpy as np
29
66
  from sklearn.linear_model import LinearRegression as sklearn_LinearRegression
30
67
 
31
- if sklearn_check_version('1.0') and not sklearn_check_version('1.2'):
68
+ from daal4py.sklearn._utils import get_dtype, make2d, sklearn_check_version
69
+
70
+ from .._device_offload import dispatch, wrap_output_data
71
+ from .._utils import PatchingConditionsChain, get_patch_message
72
+ from ..utils.validation import _assert_all_finite
73
+
74
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
32
75
  from sklearn.linear_model._base import _deprecate_normalize
33
76
 
34
- from sklearn.utils.validation import _deprecate_positional_args, check_X_y
35
- from sklearn.exceptions import NotFittedError
36
77
  from scipy.sparse import issparse
78
+ from sklearn.exceptions import NotFittedError
79
+ from sklearn.utils.validation import _deprecate_positional_args, check_X_y
37
80
 
38
81
  from onedal.linear_model import LinearRegression as onedal_LinearRegression
39
- from onedal.datatypes import (_num_samples, _get_2d_shape)
82
+ from onedal.utils import _num_features, _num_samples
40
83
 
41
84
  class LinearRegression(sklearn_LinearRegression, BaseLinearRegression):
42
85
  __doc__ = sklearn_LinearRegression.__doc__
43
86
  intercept_, coef_ = None, None
44
87
 
45
- if sklearn_check_version('1.2'):
88
+ if sklearn_check_version("1.2"):
46
89
  _parameter_constraints: dict = {
47
90
  **sklearn_LinearRegression._parameter_constraints
48
91
  }
@@ -60,11 +103,13 @@ if daal_check_version((2023, 'P', 100)):
60
103
  n_jobs=n_jobs,
61
104
  positive=positive,
62
105
  )
63
- elif sklearn_check_version('0.24'):
106
+
107
+ elif sklearn_check_version("0.24"):
108
+
64
109
  def __init__(
65
110
  self,
66
111
  fit_intercept=True,
67
- normalize='deprecated' if sklearn_check_version('1.0') else False,
112
+ normalize="deprecated" if sklearn_check_version("1.0") else False,
68
113
  copy_X=True,
69
114
  n_jobs=None,
70
115
  positive=False,
@@ -76,7 +121,9 @@ if daal_check_version((2023, 'P', 100)):
76
121
  n_jobs=n_jobs,
77
122
  positive=positive,
78
123
  )
124
+
79
125
  else:
126
+
80
127
  def __init__(
81
128
  self,
82
129
  fit_intercept=True,
@@ -88,7 +135,7 @@ if daal_check_version((2023, 'P', 100)):
88
135
  fit_intercept=fit_intercept,
89
136
  normalize=normalize,
90
137
  copy_X=copy_X,
91
- n_jobs=n_jobs
138
+ n_jobs=n_jobs,
92
139
  )
93
140
 
94
141
  def fit(self, X, y, sample_weight=None):
@@ -109,15 +156,22 @@ if daal_check_version((2023, 'P', 100)):
109
156
  self : object
110
157
  Fitted Estimator.
111
158
  """
112
- if sklearn_check_version('1.0'):
159
+ if sklearn_check_version("1.0"):
113
160
  self._check_feature_names(X, reset=True)
114
161
  if sklearn_check_version("1.2"):
115
162
  self._validate_params()
116
163
 
117
- dispatch(self, 'fit', {
118
- 'onedal': self.__class__._onedal_fit,
119
- 'sklearn': sklearn_LinearRegression.fit,
120
- }, X, y, sample_weight)
164
+ dispatch(
165
+ self,
166
+ "fit",
167
+ {
168
+ "onedal": self.__class__._onedal_fit,
169
+ "sklearn": sklearn_LinearRegression.fit,
170
+ },
171
+ X,
172
+ y,
173
+ sample_weight,
174
+ )
121
175
  return self
122
176
 
123
177
  @wrap_output_data
@@ -135,16 +189,21 @@ if daal_check_version((2023, 'P', 100)):
135
189
  """
136
190
  if sklearn_check_version("1.0"):
137
191
  self._check_feature_names(X, reset=False)
138
- return dispatch(self, 'predict', {
139
- 'onedal': self.__class__._onedal_predict,
140
- 'sklearn': sklearn_LinearRegression.predict,
141
- }, X)
192
+ return dispatch(
193
+ self,
194
+ "predict",
195
+ {
196
+ "onedal": self.__class__._onedal_predict,
197
+ "sklearn": sklearn_LinearRegression.predict,
198
+ },
199
+ X,
200
+ )
142
201
 
143
202
  def _test_type_and_finiteness(self, X_in):
144
203
  X = X_in if isinstance(X_in, np.ndarray) else np.asarray(X_in)
145
204
 
146
205
  dtype = X.dtype
147
- if 'complex' in str(type(dtype)):
206
+ if "complex" in str(type(dtype)):
148
207
  return False
149
208
 
150
209
  try:
@@ -154,77 +213,99 @@ if daal_check_version((2023, 'P', 100)):
154
213
  return True
155
214
 
156
215
  def _onedal_fit_supported(self, method_name, *data):
157
- assert method_name == 'fit'
216
+ assert method_name == "fit"
158
217
  assert len(data) == 3
159
218
  X, y, sample_weight = data
160
219
 
161
220
  class_name = self.__class__.__name__
162
221
  patching_status = PatchingConditionsChain(
163
- f'sklearn.linear_model.{class_name}.fit')
222
+ f"sklearn.linear_model.{class_name}.fit"
223
+ )
164
224
 
165
- normalize_is_set = hasattr(self, 'normalize') and self.normalize \
166
- and self.normalize != 'deprecated'
167
- positive_is_set = hasattr(self, 'positive') and self.positive
225
+ normalize_is_set = (
226
+ hasattr(self, "normalize")
227
+ and self.normalize
228
+ and self.normalize != "deprecated"
229
+ )
230
+ positive_is_set = hasattr(self, "positive") and self.positive
231
+
232
+ n_samples = _num_samples(X)
233
+ n_features = _num_features(X, fallback_1d=True)
168
234
 
169
- n_samples, n_features = _get_2d_shape(X, fallback_1d=True)
170
235
  # Check if equations are well defined
171
- is_good_for_onedal = n_samples > \
172
- (n_features + int(self.fit_intercept))
173
-
174
- dal_ready = patching_status.and_conditions([
175
- (sample_weight is None, 'Sample weight is not supported.'),
176
- (not issparse(X) and not issparse(y), 'Sparse input is not supported.'),
177
- (not normalize_is_set, 'Normalization is not supported.'),
178
- (not positive_is_set, 'Forced positive coefficients are not supported.'),
179
- (is_good_for_onedal,
180
- 'The shape of X (fitting) does not satisfy oneDAL requirements:.'
181
- 'Number of features + 1 >= number of samples.')
182
- ])
236
+ is_good_for_onedal = n_samples > (n_features + int(self.fit_intercept))
237
+
238
+ dal_ready = patching_status.and_conditions(
239
+ [
240
+ (sample_weight is None, "Sample weight is not supported."),
241
+ (
242
+ not issparse(X) and not issparse(y),
243
+ "Sparse input is not supported.",
244
+ ),
245
+ (not normalize_is_set, "Normalization is not supported."),
246
+ (
247
+ not positive_is_set,
248
+ "Forced positive coefficients are not supported.",
249
+ ),
250
+ (
251
+ is_good_for_onedal,
252
+ "The shape of X (fitting) does not satisfy oneDAL requirements:."
253
+ "Number of features + 1 >= number of samples.",
254
+ ),
255
+ ]
256
+ )
183
257
  if not dal_ready:
184
- return patching_status.get_status(logs=True)
258
+ return patching_status
185
259
 
186
260
  if not patching_status.and_condition(
187
- self._test_type_and_finiteness(X), 'Input X is not supported.'
261
+ self._test_type_and_finiteness(X), "Input X is not supported."
188
262
  ):
189
- return patching_status.get_status(logs=True)
263
+ return patching_status
190
264
 
191
265
  patching_status.and_condition(
192
- self._test_type_and_finiteness(y), 'Input y is not supported.')
266
+ self._test_type_and_finiteness(y), "Input y is not supported."
267
+ )
193
268
 
194
- return patching_status.get_status(logs=True)
269
+ return patching_status
195
270
 
196
271
  def _onedal_predict_supported(self, method_name, *data):
197
- assert method_name == 'predict'
272
+ assert method_name == "predict"
198
273
  assert len(data) == 1
199
274
 
200
275
  class_name = self.__class__.__name__
201
276
  patching_status = PatchingConditionsChain(
202
- f'sklearn.linear_model.{class_name}.predict')
277
+ f"sklearn.linear_model.{class_name}.predict"
278
+ )
203
279
 
204
280
  n_samples = _num_samples(*data)
205
- model_is_sparse = issparse(self.coef_) or \
206
- (self.fit_intercept and issparse(self.intercept_))
207
- dal_ready = patching_status.and_conditions([
208
- (n_samples > 0, 'Number of samples is less than 1.'),
209
- (not issparse(*data), 'Sparse input is not supported.'),
210
- (not model_is_sparse, 'Sparse coefficients are not supported.'),
211
- (hasattr(self, '_onedal_estimator'), 'oneDAL model was not trained.')
212
- ])
281
+ model_is_sparse = issparse(self.coef_) or (
282
+ self.fit_intercept and issparse(self.intercept_)
283
+ )
284
+ dal_ready = patching_status.and_conditions(
285
+ [
286
+ (n_samples > 0, "Number of samples is less than 1."),
287
+ (not issparse(*data), "Sparse input is not supported."),
288
+ (not model_is_sparse, "Sparse coefficients are not supported."),
289
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
290
+ ]
291
+ )
213
292
  if not dal_ready:
214
- return patching_status.get_status(logs=True)
293
+ return patching_status
215
294
 
216
295
  patching_status.and_condition(
217
- self._test_type_and_finiteness(*data), 'Input X is not supported.')
296
+ self._test_type_and_finiteness(*data), "Input X is not supported."
297
+ )
218
298
 
219
- return patching_status.get_status(logs=True)
299
+ return patching_status
220
300
 
221
301
  def _onedal_supported(self, method_name, *data):
222
- if method_name == 'fit':
302
+ if method_name == "fit":
223
303
  return self._onedal_fit_supported(method_name, *data)
224
- if method_name == 'predict':
304
+ if method_name == "predict":
225
305
  return self._onedal_predict_supported(method_name, *data)
226
306
  raise RuntimeError(
227
- f'Unknown method {method_name} in {self.__class__.__name__}')
307
+ f"Unknown method {method_name} in {self.__class__.__name__}"
308
+ )
228
309
 
229
310
  def _onedal_gpu_supported(self, method_name, *data):
230
311
  return self._onedal_supported(method_name, *data)
@@ -233,30 +314,27 @@ if daal_check_version((2023, 'P', 100)):
233
314
  return self._onedal_supported(method_name, *data)
234
315
 
235
316
  def _initialize_onedal_estimator(self):
236
- onedal_params = {
237
- 'fit_intercept': self.fit_intercept,
238
- 'copy_X': self.copy_X}
317
+ onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
239
318
  self._onedal_estimator = onedal_LinearRegression(**onedal_params)
240
319
 
241
320
  def _onedal_fit(self, X, y, sample_weight, queue=None):
242
321
  assert sample_weight is None
243
322
 
244
323
  check_params = {
245
- 'X': X,
246
- 'y': y,
247
- 'dtype': [np.float64, np.float32],
248
- 'accept_sparse': ['csr', 'csc', 'coo'],
249
- 'y_numeric': True,
250
- 'multi_output': True,
251
- 'force_all_finite': False
324
+ "X": X,
325
+ "y": y,
326
+ "dtype": [np.float64, np.float32],
327
+ "accept_sparse": ["csr", "csc", "coo"],
328
+ "y_numeric": True,
329
+ "multi_output": True,
330
+ "force_all_finite": False,
252
331
  }
253
- if sklearn_check_version('1.2'):
332
+ if sklearn_check_version("1.2"):
254
333
  X, y = self._validate_data(**check_params)
255
334
  else:
256
335
  X, y = check_X_y(**check_params)
257
336
 
258
- if sklearn_check_version(
259
- '1.0') and not sklearn_check_version('1.2'):
337
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
260
338
  self._normalize = _deprecate_normalize(
261
339
  self.normalize,
262
340
  default=False,
@@ -264,13 +342,22 @@ if daal_check_version((2023, 'P', 100)):
264
342
  )
265
343
 
266
344
  self._initialize_onedal_estimator()
267
- self._onedal_estimator.fit(X, y, queue=queue)
345
+ try:
346
+ self._onedal_estimator.fit(X, y, queue=queue)
347
+ self._save_attributes()
348
+
349
+ except RuntimeError:
350
+ logging.getLogger("sklearnex").info(
351
+ f"{self.__class__.__name__}.fit "
352
+ + get_patch_message("sklearn_after_onedal")
353
+ )
268
354
 
269
- self._save_attributes()
355
+ del self._onedal_estimator
356
+ super().fit(X, y)
270
357
 
271
358
  def _onedal_predict(self, X, queue=None):
272
359
  X = self._validate_data(X, accept_sparse=False, reset=False)
273
- if not hasattr(self, '_onedal_estimator'):
360
+ if not hasattr(self, "_onedal_estimator"):
274
361
  self._initialize_onedal_estimator()
275
362
  self._onedal_estimator.coef_ = self.coef_
276
363
  self._onedal_estimator.intercept_ = self.intercept_
@@ -279,5 +366,8 @@ if daal_check_version((2023, 'P', 100)):
279
366
 
280
367
  else:
281
368
  from daal4py.sklearn.linear_model import LinearRegression
282
- logging.warning('Preview LinearRegression requires oneDAL version >= 2023.1 '
283
- 'but it was not found')
369
+
370
+ logging.warning(
371
+ "Sklearnex LinearRegression requires oneDAL version >= 2023.1 "
372
+ "but it was not found"
373
+ )
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,6 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
- from daal4py.sklearn.linear_model import logistic_regression_path, LogisticRegression
18
+ from daal4py.sklearn.linear_model import LogisticRegression, logistic_regression_path
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,6 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from daal4py.sklearn.linear_model import Ridge
@@ -16,46 +16,62 @@
16
16
  # ===============================================================================
17
17
 
18
18
  import numpy as np
19
+ import pytest
19
20
  from numpy.testing import assert_allclose
20
21
  from sklearn.datasets import make_regression
22
+
21
23
  from daal4py.sklearn._utils import daal_check_version
24
+ from onedal.tests.utils._dataframes_support import (
25
+ _as_numpy,
26
+ _convert_to_dataframe,
27
+ get_dataframes_and_queues,
28
+ )
22
29
 
23
30
 
24
- def test_sklearnex_import_linear():
31
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
32
+ def test_sklearnex_import_linear(dataframe, queue):
25
33
  from sklearnex.linear_model import LinearRegression
34
+
26
35
  X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
27
36
  y = np.dot(X, np.array([1, 2])) + 3
37
+ X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
38
+ y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
28
39
  linreg = LinearRegression().fit(X, y)
29
- assert 'daal4py' in linreg.__module__
40
+ if daal_check_version((2023, "P", 100)):
41
+ assert hasattr(linreg, "_onedal_estimator")
42
+ assert "sklearnex" in linreg.__module__
30
43
  assert linreg.n_features_in_ == 2
31
- assert_allclose(linreg.intercept_, 3.)
32
- assert_allclose(linreg.coef_, [1., 2.])
44
+ assert_allclose(_as_numpy(linreg.intercept_), 3.0)
45
+ assert_allclose(_as_numpy(linreg.coef_), [1.0, 2.0])
33
46
 
34
47
 
35
48
  def test_sklearnex_import_ridge():
36
49
  from sklearnex.linear_model import Ridge
50
+
37
51
  X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
38
52
  y = np.dot(X, np.array([1, 2])) + 3
39
53
  ridgereg = Ridge().fit(X, y)
40
- assert 'daal4py' in ridgereg.__module__
54
+ assert "daal4py" in ridgereg.__module__
41
55
  assert_allclose(ridgereg.intercept_, 4.5)
42
56
  assert_allclose(ridgereg.coef_, [0.8, 1.4])
43
57
 
44
58
 
45
59
  def test_sklearnex_import_lasso():
46
60
  from sklearnex.linear_model import Lasso
61
+
47
62
  X = [[0, 0], [1, 1], [2, 2]]
48
63
  y = [0, 1, 2]
49
64
  lasso = Lasso(alpha=0.1).fit(X, y)
50
- assert 'daal4py' in lasso.__module__
65
+ assert "daal4py" in lasso.__module__
51
66
  assert_allclose(lasso.intercept_, 0.15)
52
67
  assert_allclose(lasso.coef_, [0.85, 0.0])
53
68
 
54
69
 
55
70
  def test_sklearnex_import_elastic():
56
71
  from sklearnex.linear_model import ElasticNet
72
+
57
73
  X, y = make_regression(n_features=2, random_state=0)
58
74
  elasticnet = ElasticNet(random_state=0).fit(X, y)
59
- assert 'daal4py' in elasticnet.__module__
75
+ assert "daal4py" in elasticnet.__module__
60
76
  assert_allclose(elasticnet.intercept_, 1.451, atol=1e-3)
61
77
  assert_allclose(elasticnet.coef_, [18.838, 64.559], atol=1e-3)
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  import numpy as np
19
19
  from numpy.testing import assert_allclose
@@ -22,7 +22,8 @@ from sklearn.datasets import load_iris
22
22
 
23
23
  def test_sklearnex_import():
24
24
  from sklearnex.linear_model import LogisticRegression
25
+
25
26
  X, y = load_iris(return_X_y=True)
26
27
  logreg = LogisticRegression(random_state=0, max_iter=200).fit(X, y)
27
- assert 'daal4py' in logreg.__module__
28
+ assert "daal4py" in logreg.__module__
28
29
  assert_allclose(logreg.score(X, y), 0.9733, atol=1e-3)
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,8 +13,8 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from .t_sne import TSNE
19
19
 
20
- __all__ = ['TSNE']
20
+ __all__ = ["TSNE"]
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,6 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from daal4py.sklearn.manifold import TSNE
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  import numpy as np
19
19
  from numpy.testing import assert_allclose
@@ -21,6 +21,7 @@ from numpy.testing import assert_allclose
21
21
 
22
22
  def test_sklearnex_import():
23
23
  from sklearnex.manifold import TSNE
24
+
24
25
  X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
25
26
  tsne = TSNE(n_components=2, perplexity=2.0).fit(X)
26
- assert 'daal4py' in tsne.__module__
27
+ assert "daal4py" in tsne.__module__
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,12 +13,12 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
- from .ranking import roc_auc_score
19
18
  from .pairwise import pairwise_distances
19
+ from .ranking import roc_auc_score
20
20
 
21
21
  __all__ = [
22
- 'roc_auc_score',
23
- 'pairwise_distances',
22
+ "roc_auc_score",
23
+ "pairwise_distances",
24
24
  ]
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,6 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from daal4py.sklearn.metrics import pairwise_distances
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- #===============================================================================
2
+ # ===============================================================================
3
3
  # Copyright 2021 Intel Corporation
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,6 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- #===============================================================================
16
+ # ===============================================================================
17
17
 
18
18
  from daal4py.sklearn.metrics import roc_auc_score