scikit-learn-intelex 2024.1.0__py311-none-win_amd64.whl → 2024.3.0__py311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__init__.py +9 -7
  2. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +6 -4
  3. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/conftest.py +63 -0
  4. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
  5. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +130 -0
  6. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  7. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +338 -0
  8. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +22 -8
  9. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +91 -59
  10. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +15 -24
  11. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +15 -19
  12. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +1 -2
  13. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +3 -10
  14. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex}/linear_model/logistic_regression.py +32 -40
  15. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +91 -0
  16. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +1 -1
  17. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +204 -0
  18. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +13 -18
  19. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +12 -17
  20. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +10 -15
  21. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +12 -16
  22. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  23. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +3 -8
  24. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +46 -12
  25. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +1 -0
  26. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +19 -0
  27. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +21 -0
  28. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  29. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
  30. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  31. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +9 -6
  32. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +6 -7
  33. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +9 -6
  34. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +3 -4
  35. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +155 -0
  36. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +9 -7
  37. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +268 -0
  38. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +93 -0
  39. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  40. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +361 -0
  41. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/METADATA +2 -2
  42. scikit_learn_intelex-2024.3.0.dist-info/RECORD +98 -0
  43. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  44. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  45. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -28
  46. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -436
  47. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -19
  48. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -376
  49. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +0 -19
  50. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  51. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -170
  52. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -227
  53. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -31
  54. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -122
  55. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -118
  56. scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
  57. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  58. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  59. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  60. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  61. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  62. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  63. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  64. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  65. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  66. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  67. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  69. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  70. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  71. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  72. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  73. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  74. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  75. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -0
  76. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  77. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  78. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  79. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  80. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  81. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  82. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  83. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  84. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  85. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  86. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  87. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  88. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  89. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  91. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  92. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  93. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  94. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  95. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  96. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  97. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  98. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  99. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  100. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  101. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  102. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
  104. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  105. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  106. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  107. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  108. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  109. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  110. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/LICENSE.txt +0 -0
  111. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/WHEEL +0 -0
  112. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/top_level.txt +0 -0
@@ -14,32 +14,47 @@
14
14
  # limitations under the License.
15
15
  # ===============================================================================
16
16
 
17
+ import warnings
18
+
19
+ import numpy as np
17
20
  from scipy import sparse as sp
18
21
  from sklearn.covariance import EmpiricalCovariance as sklearn_EmpiricalCovariance
19
22
  from sklearn.utils import check_array
20
23
 
21
- from daal4py.sklearn._utils import control_n_jobs, run_with_n_jobs, sklearn_check_version
24
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
25
+ from daal4py.sklearn._utils import sklearn_check_version
22
26
  from onedal.common.hyperparameters import get_hyperparameters
23
27
  from onedal.covariance import EmpiricalCovariance as onedal_EmpiricalCovariance
28
+ from sklearnex import config_context
29
+ from sklearnex.metrics import pairwise_distances
24
30
 
25
- from ..._device_offload import dispatch
31
+ from ..._device_offload import dispatch, wrap_output_data
26
32
  from ..._utils import PatchingConditionsChain, register_hyperparameters
27
33
 
28
34
 
29
35
  @register_hyperparameters({"fit": get_hyperparameters("covariance", "compute")})
30
- @control_n_jobs
36
+ @control_n_jobs(decorated_methods=["fit", "mahalanobis"])
31
37
  class EmpiricalCovariance(sklearn_EmpiricalCovariance):
32
38
  __doc__ = sklearn_EmpiricalCovariance.__doc__
33
39
 
40
+ if sklearn_check_version("1.2"):
41
+ _parameter_constraints: dict = {
42
+ **sklearn_EmpiricalCovariance._parameter_constraints,
43
+ }
44
+
34
45
  def _save_attributes(self):
35
46
  assert hasattr(self, "_onedal_estimator")
36
- self.covariance_ = self._onedal_estimator.covariance_
47
+ self._set_covariance(self._onedal_estimator.covariance_)
37
48
  self.location_ = self._onedal_estimator.location_
38
49
 
39
50
  _onedal_covariance = staticmethod(onedal_EmpiricalCovariance)
40
51
 
41
- @run_with_n_jobs
42
52
  def _onedal_fit(self, X, queue=None):
53
+ if X.shape[0] == 1:
54
+ warnings.warn(
55
+ "Only one sample available. You may want to reshape your data array"
56
+ )
57
+
43
58
  onedal_params = {
44
59
  "method": "dense",
45
60
  "bias": True,
@@ -54,7 +69,7 @@ class EmpiricalCovariance(sklearn_EmpiricalCovariance):
54
69
  patching_status = PatchingConditionsChain(
55
70
  f"sklearn.covariance.{class_name}.{method_name}"
56
71
  )
57
- if method_name == "fit":
72
+ if method_name in ["fit", "mahalanobis"]:
58
73
  (X,) = data
59
74
  patching_status.and_conditions(
60
75
  [
@@ -62,10 +77,6 @@ class EmpiricalCovariance(sklearn_EmpiricalCovariance):
62
77
  self.assume_centered == False,
63
78
  "assume_centered parameter is not supported on oneDAL side",
64
79
  ),
65
- (
66
- self.store_precision == False,
67
- "precision matrix calculation is not supported on oneDAL side",
68
- ),
69
80
  (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
70
81
  ]
71
82
  )
@@ -79,9 +90,9 @@ class EmpiricalCovariance(sklearn_EmpiricalCovariance):
79
90
  if sklearn_check_version("1.2"):
80
91
  self._validate_params()
81
92
  if sklearn_check_version("0.23"):
82
- self._validate_data(X)
93
+ X = self._validate_data(X, force_all_finite=False)
83
94
  else:
84
- check_array(X)
95
+ X = check_array(X, force_all_finite=False)
85
96
 
86
97
  dispatch(
87
98
  self,
@@ -95,4 +106,27 @@ class EmpiricalCovariance(sklearn_EmpiricalCovariance):
95
106
 
96
107
  return self
97
108
 
109
+ # expose sklearnex pairwise_distances if mahalanobis distance eventually supported
110
+ @wrap_output_data
111
+ def mahalanobis(self, X):
112
+ if sklearn_check_version("1.0"):
113
+ X = self._validate_data(X, reset=False)
114
+ else:
115
+ X = check_array(X)
116
+
117
+ precision = self.get_precision()
118
+ with config_context(assume_finite=True):
119
+ # compute mahalanobis distances
120
+ dist = pairwise_distances(
121
+ X, self.location_[np.newaxis, :], metric="mahalanobis", VI=precision
122
+ )
123
+
124
+ return np.reshape(dist, (len(X),)) ** 2
125
+
126
+ error_norm = wrap_output_data(sklearn_EmpiricalCovariance.error_norm)
127
+ score = wrap_output_data(sklearn_EmpiricalCovariance.score)
128
+
98
129
  fit.__doc__ = sklearn_EmpiricalCovariance.fit.__doc__
130
+ mahalanobis.__doc__ = sklearn_EmpiricalCovariance.mahalanobis
131
+ error_norm.__doc__ = sklearn_EmpiricalCovariance.error_norm.__doc__
132
+ score.__doc__ = sklearn_EmpiricalCovariance.score.__doc__
@@ -17,6 +17,7 @@
17
17
  __all__ = [
18
18
  "basic_statistics",
19
19
  "cluster",
20
+ "covariance",
20
21
  "decomposition",
21
22
  "ensemble",
22
23
  "linear_model",
@@ -0,0 +1,19 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from .covariance import EmpiricalCovariance
18
+
19
+ __all__ = ["EmpiricalCovariance"]
@@ -0,0 +1,21 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from onedal.spmd.covariance import EmpiricalCovariance
18
+
19
+ # TODO:
20
+ # Currently it uses `onedal` module interface.
21
+ # Add sklearnex dispatching.
@@ -14,8 +14,6 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
- from abc import ABC
18
-
19
17
  from onedal.spmd.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
20
18
  from onedal.spmd.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
21
19
 
@@ -23,16 +21,9 @@ from ...ensemble import RandomForestClassifier as RandomForestClassifier_Batch
23
21
  from ...ensemble import RandomForestRegressor as RandomForestRegressor_Batch
24
22
 
25
23
 
26
- class BaseForestSPMD(ABC):
27
- def _onedal_classifier(self, **onedal_params):
28
- return onedal_RandomForestClassifier(**onedal_params)
29
-
30
- def _onedal_regressor(self, **onedal_params):
31
- return onedal_RandomForestRegressor(**onedal_params)
32
-
33
-
34
- class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch):
24
+ class RandomForestClassifier(RandomForestClassifier_Batch):
35
25
  __doc__ = RandomForestClassifier_Batch.__doc__
26
+ _onedal_factory = onedal_RandomForestClassifier
36
27
 
37
28
  def _onedal_cpu_supported(self, method_name, *data):
38
29
  # TODO:
@@ -55,8 +46,9 @@ class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch):
55
46
  return ready
56
47
 
57
48
 
58
- class RandomForestRegressor(BaseForestSPMD, RandomForestRegressor_Batch):
49
+ class RandomForestRegressor(RandomForestRegressor_Batch):
59
50
  __doc__ = RandomForestRegressor_Batch.__doc__
51
+ _onedal_factory = onedal_RandomForestRegressor
60
52
 
61
53
  def _onedal_cpu_supported(self, method_name, *data):
62
54
  # TODO:
@@ -15,5 +15,6 @@
15
15
  # ==============================================================================
16
16
 
17
17
  from .linear_model import LinearRegression
18
+ from .logistic_regression import LogisticRegression
18
19
 
19
- __all__ = ["LinearRegression"]
20
+ __all__ = ["LinearRegression", "LogisticRegression"]
@@ -0,0 +1,21 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from onedal.spmd.linear_model import LogisticRegression
18
+
19
+ # TODO:
20
+ # Currently it uses `onedal` module interface.
21
+ # Add sklearnex dispatching.
@@ -18,7 +18,8 @@ from sklearn.exceptions import NotFittedError
18
18
  from sklearn.svm import NuSVC as sklearn_NuSVC
19
19
  from sklearn.utils.validation import _deprecate_positional_args
20
20
 
21
- from daal4py.sklearn._utils import control_n_jobs, run_with_n_jobs, sklearn_check_version
21
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
22
+ from daal4py.sklearn._utils import sklearn_check_version
22
23
 
23
24
  from .._device_offload import dispatch, wrap_output_data
24
25
  from ._common import BaseSVC
@@ -29,7 +30,9 @@ if sklearn_check_version("1.0"):
29
30
  from onedal.svm import NuSVC as onedal_NuSVC
30
31
 
31
32
 
32
- @control_n_jobs
33
+ @control_n_jobs(
34
+ decorated_methods=["fit", "predict", "_predict_proba", "decision_function"]
35
+ )
33
36
  class NuSVC(sklearn_NuSVC, BaseSVC):
34
37
  __doc__ = sklearn_NuSVC.__doc__
35
38
 
@@ -195,6 +198,8 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
195
198
  self._check_proba()
196
199
  return self._predict_proba
197
200
 
201
+ predict_proba.__doc__ = sklearn_NuSVC.predict_proba.__doc__
202
+
198
203
  @wrap_output_data
199
204
  def _predict_proba(self, X):
200
205
  if sklearn_check_version("1.0"):
@@ -229,7 +234,8 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
229
234
  X,
230
235
  )
231
236
 
232
- @run_with_n_jobs
237
+ decision_function.__doc__ = sklearn_NuSVC.decision_function.__doc__
238
+
233
239
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
234
240
  onedal_params = {
235
241
  "nu": self.nu,
@@ -253,11 +259,9 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
253
259
  self._fit_proba(X, y, sample_weight, queue=queue)
254
260
  self._save_attributes()
255
261
 
256
- @run_with_n_jobs
257
262
  def _onedal_predict(self, X, queue=None):
258
263
  return self._onedal_estimator.predict(X, queue=queue)
259
264
 
260
- @run_with_n_jobs
261
265
  def _onedal_predict_proba(self, X, queue=None):
262
266
  if getattr(self, "clf_prob", None) is None:
263
267
  raise NotFittedError(
@@ -272,6 +276,5 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
272
276
  with config_context(**cfg):
273
277
  return self.clf_prob.predict_proba(X)
274
278
 
275
- @run_with_n_jobs
276
279
  def _onedal_decision_function(self, X, queue=None):
277
280
  return self._onedal_estimator.decision_function(X, queue=queue)
@@ -17,14 +17,15 @@
17
17
  from sklearn.svm import NuSVR as sklearn_NuSVR
18
18
  from sklearn.utils.validation import _deprecate_positional_args
19
19
 
20
- from daal4py.sklearn._utils import control_n_jobs, run_with_n_jobs, sklearn_check_version
20
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
21
+ from daal4py.sklearn._utils import sklearn_check_version
21
22
  from onedal.svm import NuSVR as onedal_NuSVR
22
23
 
23
24
  from .._device_offload import dispatch, wrap_output_data
24
25
  from ._common import BaseSVR
25
26
 
26
27
 
27
- @control_n_jobs
28
+ @control_n_jobs(decorated_methods=["fit", "predict"])
28
29
  class NuSVR(sklearn_NuSVR, BaseSVR):
29
30
  __doc__ = sklearn_NuSVR.__doc__
30
31
 
@@ -35,14 +36,14 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
35
36
  def __init__(
36
37
  self,
37
38
  *,
39
+ nu=0.5,
40
+ C=1.0,
38
41
  kernel="rbf",
39
42
  degree=3,
40
43
  gamma="scale",
41
44
  coef0=0.0,
42
- tol=1e-3,
43
- C=1.0,
44
- nu=0.5,
45
45
  shrinking=True,
46
+ tol=1e-3,
46
47
  cache_size=200,
47
48
  verbose=False,
48
49
  max_iter=-1,
@@ -142,7 +143,6 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
142
143
  X,
143
144
  )
144
145
 
145
- @run_with_n_jobs
146
146
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
147
147
  onedal_params = {
148
148
  "C": self.C,
@@ -161,6 +161,5 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
161
161
  self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
162
162
  self._save_attributes()
163
163
 
164
- @run_with_n_jobs
165
164
  def _onedal_predict(self, X, queue=None):
166
165
  return self._onedal_estimator.predict(X, queue=queue)
@@ -20,7 +20,8 @@ from sklearn.exceptions import NotFittedError
20
20
  from sklearn.svm import SVC as sklearn_SVC
21
21
  from sklearn.utils.validation import _deprecate_positional_args
22
22
 
23
- from daal4py.sklearn._utils import control_n_jobs, run_with_n_jobs, sklearn_check_version
23
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
24
+ from daal4py.sklearn._utils import sklearn_check_version
24
25
 
25
26
  from .._device_offload import dispatch, wrap_output_data
26
27
  from .._utils import PatchingConditionsChain
@@ -32,7 +33,9 @@ if sklearn_check_version("1.0"):
32
33
  from onedal.svm import SVC as onedal_SVC
33
34
 
34
35
 
35
- @control_n_jobs
36
+ @control_n_jobs(
37
+ decorated_methods=["fit", "predict", "_predict_proba", "decision_function"]
38
+ )
36
39
  class SVC(sklearn_SVC, BaseSVC):
37
40
  __doc__ = sklearn_SVC.__doc__
38
41
 
@@ -197,6 +200,8 @@ class SVC(sklearn_SVC, BaseSVC):
197
200
  self._check_proba()
198
201
  return self._predict_proba
199
202
 
203
+ predict_proba.__doc__ = sklearn_SVC.predict_proba.__doc__
204
+
200
205
  @wrap_output_data
201
206
  def _predict_proba(self, X):
202
207
  sklearn_pred_proba = (
@@ -229,6 +234,8 @@ class SVC(sklearn_SVC, BaseSVC):
229
234
  X,
230
235
  )
231
236
 
237
+ decision_function.__doc__ = sklearn_SVC.decision_function.__doc__
238
+
232
239
  def _onedal_gpu_supported(self, method_name, *data):
233
240
  class_name = self.__class__.__name__
234
241
  patching_status = PatchingConditionsChain(
@@ -258,7 +265,6 @@ class SVC(sklearn_SVC, BaseSVC):
258
265
  return patching_status
259
266
  raise RuntimeError(f"Unknown method {method_name} in {class_name}")
260
267
 
261
- @run_with_n_jobs
262
268
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
263
269
  onedal_params = {
264
270
  "C": self.C,
@@ -282,11 +288,9 @@ class SVC(sklearn_SVC, BaseSVC):
282
288
  self._fit_proba(X, y, sample_weight, queue=queue)
283
289
  self._save_attributes()
284
290
 
285
- @run_with_n_jobs
286
291
  def _onedal_predict(self, X, queue=None):
287
292
  return self._onedal_estimator.predict(X, queue=queue)
288
293
 
289
- @run_with_n_jobs
290
294
  def _onedal_predict_proba(self, X, queue=None):
291
295
  if getattr(self, "clf_prob", None) is None:
292
296
  raise NotFittedError(
@@ -301,6 +305,5 @@ class SVC(sklearn_SVC, BaseSVC):
301
305
  with config_context(**cfg):
302
306
  return self.clf_prob.predict_proba(X)
303
307
 
304
- @run_with_n_jobs
305
308
  def _onedal_decision_function(self, X, queue=None):
306
309
  return self._onedal_estimator.decision_function(X, queue=queue)
@@ -17,14 +17,15 @@
17
17
  from sklearn.svm import SVR as sklearn_SVR
18
18
  from sklearn.utils.validation import _deprecate_positional_args
19
19
 
20
- from daal4py.sklearn._utils import control_n_jobs, run_with_n_jobs, sklearn_check_version
20
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
21
+ from daal4py.sklearn._utils import sklearn_check_version
21
22
  from onedal.svm import SVR as onedal_SVR
22
23
 
23
24
  from .._device_offload import dispatch, wrap_output_data
24
25
  from ._common import BaseSVR
25
26
 
26
27
 
27
- @control_n_jobs
28
+ @control_n_jobs(decorated_methods=["fit", "predict"])
28
29
  class SVR(sklearn_SVR, BaseSVR):
29
30
  __doc__ = sklearn_SVR.__doc__
30
31
 
@@ -143,7 +144,6 @@ class SVR(sklearn_SVR, BaseSVR):
143
144
  X,
144
145
  )
145
146
 
146
- @run_with_n_jobs
147
147
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
148
148
  onedal_params = {
149
149
  "C": self.C,
@@ -162,6 +162,5 @@ class SVR(sklearn_SVR, BaseSVR):
162
162
  self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
163
163
  self._save_attributes()
164
164
 
165
- @run_with_n_jobs
166
165
  def _onedal_predict(self, X, queue=None):
167
166
  return self._onedal_estimator.predict(X, queue=queue)
@@ -0,0 +1,155 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from inspect import isclass
18
+
19
+ import numpy as np
20
+ from sklearn.base import (
21
+ BaseEstimator,
22
+ ClassifierMixin,
23
+ ClusterMixin,
24
+ OutlierMixin,
25
+ RegressorMixin,
26
+ TransformerMixin,
27
+ )
28
+ from sklearn.datasets import load_diabetes, load_iris
29
+ from sklearn.neighbors._base import KNeighborsMixin
30
+
31
+ from onedal.tests.utils._dataframes_support import _convert_to_dataframe
32
+ from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
33
+ from sklearnex.neighbors import (
34
+ KNeighborsClassifier,
35
+ KNeighborsRegressor,
36
+ LocalOutlierFactor,
37
+ NearestNeighbors,
38
+ )
39
+ from sklearnex.svm import SVC, NuSVC
40
+
41
+
42
+ def _load_all_models(with_sklearnex=True, estimator=True):
43
+ # insure that patch state is correct as dictated by patch_sklearn boolean
44
+ # and return it to the previous state no matter what occurs.
45
+ already_patched_map = sklearn_is_patched(return_map=True)
46
+ already_patched = any(already_patched_map.values())
47
+ try:
48
+ if with_sklearnex:
49
+ patch_sklearn()
50
+ elif already_patched:
51
+ unpatch_sklearn()
52
+
53
+ models = {}
54
+ for patch_infos in get_patch_map().values():
55
+ candidate = getattr(patch_infos[0][0][0], patch_infos[0][0][1], None)
56
+ if candidate is not None and isclass(candidate) == estimator:
57
+ if not estimator or issubclass(candidate, BaseEstimator):
58
+ models[patch_infos[0][0][1]] = candidate
59
+ finally:
60
+ if with_sklearnex:
61
+ unpatch_sklearn()
62
+ # both branches are now in an unpatched state, repatch as necessary
63
+ if already_patched:
64
+ patch_sklearn(name=[i for i in already_patched_map if already_patched_map[i]])
65
+
66
+ return models
67
+
68
+
69
+ PATCHED_MODELS = _load_all_models(with_sklearnex=True)
70
+ UNPATCHED_MODELS = _load_all_models(with_sklearnex=False)
71
+
72
+ PATCHED_FUNCTIONS = _load_all_models(with_sklearnex=True, estimator=False)
73
+ UNPATCHED_FUNCTIONS = _load_all_models(with_sklearnex=False, estimator=False)
74
+
75
+ mixin_map = [
76
+ [
77
+ ClassifierMixin,
78
+ ["decision_function", "predict", "predict_proba", "predict_log_proba", "score"],
79
+ "classification",
80
+ ],
81
+ [RegressorMixin, ["predict", "score"], "regression"],
82
+ [ClusterMixin, ["fit_predict"], "classification"],
83
+ [TransformerMixin, ["fit_transform", "transform", "score"], "classification"],
84
+ [OutlierMixin, ["fit_predict", "predict"], "classification"],
85
+ [KNeighborsMixin, ["kneighbors"], None],
86
+ ]
87
+
88
+
89
+ SPECIAL_INSTANCES = {
90
+ str(i): i
91
+ for i in [
92
+ LocalOutlierFactor(novelty=True),
93
+ SVC(probability=True),
94
+ NuSVC(probability=True),
95
+ KNeighborsClassifier(algorithm="brute"),
96
+ KNeighborsRegressor(algorithm="brute"),
97
+ NearestNeighbors(algorithm="brute"),
98
+ ]
99
+ }
100
+
101
+
102
+ def gen_models_info(algorithms):
103
+ output = []
104
+ for i in algorithms:
105
+ # split handles SPECIAL_INSTANCES or custom inputs
106
+ # custom sklearn inputs must be a dict of estimators
107
+ # with keys set by the __str__ method
108
+ est = PATCHED_MODELS[i.split("(")[0]]
109
+
110
+ methods = set()
111
+ candidates = set(
112
+ [i for i in dir(est) if not i.startswith("_") and not i.endswith("_")]
113
+ )
114
+
115
+ for mixin, method, _ in mixin_map:
116
+ if issubclass(est, mixin):
117
+ methods |= candidates & set(method)
118
+
119
+ output += [[i, j] for j in methods]
120
+ return output
121
+
122
+
123
+ def gen_dataset(estimator, queue=None, target_df=None, dtype=np.float64):
124
+ dataset = None
125
+ name = estimator.__class__.__name__
126
+ est = PATCHED_MODELS[name]
127
+ for mixin, _, data in mixin_map:
128
+ if issubclass(est, mixin) and data is not None:
129
+ dataset = data
130
+ # load data
131
+ if dataset == "classification" or dataset is None:
132
+ X, y = load_iris(return_X_y=True)
133
+ elif dataset == "regression":
134
+ X, y = load_diabetes(return_X_y=True)
135
+ else:
136
+ raise ValueError("Unknown dataset type")
137
+
138
+ X = _convert_to_dataframe(X, sycl_queue=queue, target_df=target_df, dtype=dtype)
139
+ y = _convert_to_dataframe(y, sycl_queue=queue, target_df=target_df, dtype=dtype)
140
+ return X, y
141
+
142
+
143
+ DTYPES = [
144
+ np.int8,
145
+ np.int16,
146
+ np.int32,
147
+ np.int64,
148
+ np.float16,
149
+ np.float32,
150
+ np.float64,
151
+ np.uint8,
152
+ np.uint16,
153
+ np.uint32,
154
+ np.uint64,
155
+ ]
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
+
17
18
  import gc
18
19
  import logging
19
20
  import tracemalloc
@@ -30,7 +31,6 @@ from sklearn.model_selection import KFold
30
31
  from sklearnex import get_patch_map
31
32
  from sklearnex.metrics import pairwise_distances, roc_auc_score
32
33
  from sklearnex.model_selection import train_test_split
33
- from sklearnex.preview.decomposition import PCA as PreviewPCA
34
34
  from sklearnex.utils import _assert_all_finite
35
35
 
36
36
 
@@ -75,6 +75,8 @@ class RocAucEstimator:
75
75
 
76
76
 
77
77
  # add all daal4py estimators enabled in patching (except banned)
78
+
79
+
78
80
  def get_patched_estimators(ban_list, output_list):
79
81
  patched_estimators = get_patch_map().values()
80
82
  for listing in patched_estimators:
@@ -94,12 +96,8 @@ def remove_duplicated_estimators(estimators_list):
94
96
  return estimators_map.values()
95
97
 
96
98
 
97
- BANNED_ESTIMATORS = (
98
- "LocalOutlierFactor", # fails on ndarray_c for sklearn > 1.0
99
- "TSNE", # too slow for using in testing on common data size
100
- )
99
+ BANNED_ESTIMATORS = ("TSNE",) # too slow for using in testing on common data size
101
100
  estimators = [
102
- PreviewPCA,
103
101
  TrainTestSplitEstimator,
104
102
  FiniteCheckEstimator,
105
103
  CosineDistancesEstimator,
@@ -156,6 +154,7 @@ def split_train_inference(kf, x, y, estimator):
156
154
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]
157
155
  # TODO: add parameters for all estimators to prevent
158
156
  # fallback to stock scikit-learn with default parameters
157
+
159
158
  alg = estimator()
160
159
  alg.fit(x_train, y_train)
161
160
  if hasattr(alg, "predict"):
@@ -166,7 +165,6 @@ def split_train_inference(kf, x, y, estimator):
166
165
  alg.kneighbors(x_test)
167
166
  del alg, x_train, x_test, y_train, y_test
168
167
  mem_tracks.append(tracemalloc.get_traced_memory()[0])
169
-
170
168
  return mem_tracks
171
169
 
172
170
 
@@ -218,6 +216,10 @@ def _kfold_function_template(estimator, data_transform_function, data_shape):
218
216
  )
219
217
 
220
218
 
219
+ # disable fallback check as logging impacts memory use
220
+
221
+
222
+ @pytest.mark.allow_sklearn_fallback
221
223
  @pytest.mark.parametrize("data_transform_function", data_transforms)
222
224
  @pytest.mark.parametrize("estimator", estimators)
223
225
  @pytest.mark.parametrize("data_shape", data_shapes)