scikit-learn-intelex 2024.4.0__py312-none-win_amd64.whl → 2024.6.0__py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; consult the registry's advisory for this release for more details.

Files changed (113)
  1. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +8 -1
  2. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +2 -4
  3. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -0
  4. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
  5. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/conftest.py +11 -1
  6. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +317 -0
  7. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +54 -17
  8. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +68 -13
  9. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +6 -4
  10. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +46 -1
  11. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +114 -22
  12. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +13 -3
  13. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
  14. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +5 -3
  15. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +464 -0
  16. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +27 -9
  17. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +13 -15
  18. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +200 -0
  19. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +2 -2
  20. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +24 -0
  21. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +2 -2
  22. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  23. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  24. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +228 -0
  25. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  26. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +330 -0
  27. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +40 -4
  28. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +31 -2
  29. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +40 -4
  30. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +31 -2
  31. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/_utils.py +70 -29
  32. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +54 -0
  33. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +290 -0
  34. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +4 -0
  35. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +22 -10
  36. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +283 -0
  37. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/_namespace.py +1 -1
  38. scikit_learn_intelex-2024.6.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
  39. {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/METADATA +230 -230
  40. scikit_learn_intelex-2024.6.0.dist-info/RECORD +108 -0
  41. {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/WHEEL +1 -1
  42. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -130
  43. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -185
  44. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -227
  45. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
  46. scikit_learn_intelex-2024.4.0.dist-info/RECORD +0 -101
  47. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
  48. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  49. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  50. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  51. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  52. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  53. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +0 -0
  54. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  55. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  56. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  57. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  58. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  59. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  60. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  61. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  62. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  63. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  64. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  65. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -0
  66. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  67. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  68. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  69. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  70. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  71. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  72. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  73. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  74. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  75. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  76. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  77. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +0 -0
  78. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  79. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -0
  80. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +0 -0
  81. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  83. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -0
  84. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
  86. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  87. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  88. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  90. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  91. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  92. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  93. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  94. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  95. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  96. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  97. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  98. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
  99. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  100. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  101. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  102. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  104. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  105. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  106. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  107. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -0
  108. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
  109. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  110. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  111. {scikit_learn_intelex-2024.4.0.data → scikit_learn_intelex-2024.6.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  112. {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/LICENSE.txt +0 -0
  113. {scikit_learn_intelex-2024.4.0.dist-info → scikit_learn_intelex-2024.6.0.dist-info}/top_level.txt +0 -0
@@ -65,6 +65,17 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
65
65
  def fit(self, X, y, sample_weight=None):
66
66
  if sklearn_check_version("1.2"):
67
67
  self._validate_params()
68
+ elif self.nu <= 0 or self.nu > 1:
69
+ # else if added to correct issues with
70
+ # sklearn tests:
71
+ # svm/tests/test_sparse.py::test_error
72
+ # svm/tests/test_svm.py::test_bad_input
73
+ # for sklearn versions < 1.2 (i.e. without
74
+ # validate_params parameter checking)
75
+ # Without this, a segmentation fault with
76
+ # Windows fatal exception: access violation
77
+ # occurs
78
+ raise ValueError("nu <= 0 or nu > 1")
68
79
  if sklearn_check_version("1.0"):
69
80
  self._check_feature_names(X, reset=True)
70
81
  dispatch(
@@ -76,7 +87,7 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
76
87
  },
77
88
  X,
78
89
  y,
79
- sample_weight,
90
+ sample_weight=sample_weight,
80
91
  )
81
92
  return self
82
93
 
@@ -94,13 +105,30 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
94
105
  X,
95
106
  )
96
107
 
108
@wrap_output_data
def score(self, X, y, sample_weight=None):
    # Feature-name bookkeeping only exists from sklearn 1.0 onwards; reset=False
    # checks X against the names recorded at fit time instead of overwriting them.
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)
    # Backend name -> implementation; dispatch() picks which one runs.
    backends = {
        "onedal": self.__class__._onedal_score,
        "sklearn": sklearn_NuSVR.score,
    }
    return dispatch(self, "score", backends, X, y, sample_weight=sample_weight)
123
+
97
124
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
125
+ X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight)
98
126
  onedal_params = {
99
127
  "C": self.C,
100
128
  "nu": self.nu,
101
129
  "kernel": self.kernel,
102
130
  "degree": self.degree,
103
- "gamma": self.gamma,
131
+ "gamma": self._compute_gamma_sigma(X),
104
132
  "coef0": self.coef0,
105
133
  "tol": self.tol,
106
134
  "shrinking": self.shrinking,
@@ -117,3 +145,4 @@ class NuSVR(sklearn_NuSVR, BaseSVR):
117
145
 
118
146
  fit.__doc__ = sklearn_NuSVR.fit.__doc__
119
147
  predict.__doc__ = sklearn_NuSVR.predict.__doc__
148
+ score.__doc__ = sklearn_NuSVR.score.__doc__
@@ -85,6 +85,17 @@ class SVC(sklearn_SVC, BaseSVC):
85
85
  def fit(self, X, y, sample_weight=None):
86
86
  if sklearn_check_version("1.2"):
87
87
  self._validate_params()
88
+ elif self.C <= 0:
89
+ # else if added to correct issues with
90
+ # sklearn tests:
91
+ # svm/tests/test_sparse.py::test_error
92
+ # svm/tests/test_svm.py::test_bad_input
93
+ # for sklearn versions < 1.2 (i.e. without
94
+ # validate_params parameter checking)
95
+ # Without this, a segmentation fault with
96
+ # Windows fatal exception: access violation
97
+ # occurs
98
+ raise ValueError("C <= 0")
88
99
  if sklearn_check_version("1.0"):
89
100
  self._check_feature_names(X, reset=True)
90
101
  dispatch(
@@ -96,8 +107,9 @@ class SVC(sklearn_SVC, BaseSVC):
96
107
  },
97
108
  X,
98
109
  y,
99
- sample_weight,
110
+ sample_weight=sample_weight,
100
111
  )
112
+
101
113
  return self
102
114
 
103
115
  @wrap_output_data
@@ -270,12 +282,30 @@ class SVC(sklearn_SVC, BaseSVC):
270
282
  return patching_status
271
283
  raise RuntimeError(f"Unknown method {method_name} in {class_name}")
272
284
 
285
def _get_sample_weight(self, X, y, sample_weight=None):
    # Start from the parent's processed weights (None means "unweighted").
    weights = super()._get_sample_weight(X, y, sample_weight)
    if weights is None:
        return weights

    # Only when some samples are zero/negatively weighted can classes vanish;
    # the samples that keep a positive weight must still cover every class.
    if np.any(weights <= 0):
        surviving_labels = np.unique(y[weights > 0])
        if len(surviving_labels) != len(self.classes_):
            # Message wording mirrors the installed sklearn version.
            if sklearn_check_version("1.2"):
                message = (
                    "Invalid input - all samples with positive weights "
                    "belong to the same class"
                )
            else:
                message = (
                    "Invalid input - all samples with positive weights "
                    "have the same label."
                )
            raise ValueError(message)
    return weights
301
+
273
302
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
303
+ X, _, weights = self._onedal_fit_checks(X, y, sample_weight)
274
304
  onedal_params = {
275
305
  "C": self.C,
276
306
  "kernel": self.kernel,
277
307
  "degree": self.degree,
278
- "gamma": self.gamma,
308
+ "gamma": self._compute_gamma_sigma(X),
279
309
  "coef0": self.coef0,
280
310
  "tol": self.tol,
281
311
  "shrinking": self.shrinking,
@@ -287,10 +317,16 @@ class SVC(sklearn_SVC, BaseSVC):
287
317
  }
288
318
 
289
319
  self._onedal_estimator = onedal_SVC(**onedal_params)
290
- self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
320
+ self._onedal_estimator.fit(X, y, weights, queue=queue)
291
321
 
292
322
  if self.probability:
293
- self._fit_proba(X, y, sample_weight, queue=queue)
323
+ self._fit_proba(
324
+ X,
325
+ y,
326
+ sample_weight=sample_weight,
327
+ queue=queue,
328
+ )
329
+
294
330
  self._save_attributes()
295
331
 
296
332
  def _onedal_predict(self, X, queue=None):
@@ -65,6 +65,17 @@ class SVR(sklearn_SVR, BaseSVR):
65
65
  def fit(self, X, y, sample_weight=None):
66
66
  if sklearn_check_version("1.2"):
67
67
  self._validate_params()
68
+ elif self.C <= 0:
69
+ # else if added to correct issues with
70
+ # sklearn tests:
71
+ # svm/tests/test_sparse.py::test_error
72
+ # svm/tests/test_svm.py::test_bad_input
73
+ # for sklearn versions < 1.2 (i.e. without
74
+ # validate_params parameter checking)
75
+ # Without this, a segmentation fault with
76
+ # Windows fatal exception: access violation
77
+ # occurs
78
+ raise ValueError("C <= 0")
68
79
  if sklearn_check_version("1.0"):
69
80
  self._check_feature_names(X, reset=True)
70
81
  dispatch(
@@ -76,7 +87,7 @@ class SVR(sklearn_SVR, BaseSVR):
76
87
  },
77
88
  X,
78
89
  y,
79
- sample_weight,
90
+ sample_weight=sample_weight,
80
91
  )
81
92
 
82
93
  return self
@@ -95,13 +106,30 @@ class SVR(sklearn_SVR, BaseSVR):
95
106
  X,
96
107
  )
97
108
 
109
@wrap_output_data
def score(self, X, y, sample_weight=None):
    # Feature-name bookkeeping only exists from sklearn 1.0 onwards; reset=False
    # checks X against the names recorded at fit time instead of overwriting them.
    if sklearn_check_version("1.0"):
        self._check_feature_names(X, reset=False)
    # Backend name -> implementation; dispatch() picks which one runs.
    backends = {
        "onedal": self.__class__._onedal_score,
        "sklearn": sklearn_SVR.score,
    }
    return dispatch(self, "score", backends, X, y, sample_weight=sample_weight)
124
+
98
125
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
126
+ X, _, sample_weight = self._onedal_fit_checks(X, y, sample_weight)
99
127
  onedal_params = {
100
128
  "C": self.C,
101
129
  "epsilon": self.epsilon,
102
130
  "kernel": self.kernel,
103
131
  "degree": self.degree,
104
- "gamma": self.gamma,
132
+ "gamma": self._compute_gamma_sigma(X),
105
133
  "coef0": self.coef0,
106
134
  "tol": self.tol,
107
135
  "shrinking": self.shrinking,
@@ -118,3 +146,4 @@ class SVR(sklearn_SVR, BaseSVR):
118
146
 
119
147
  fit.__doc__ = sklearn_SVR.fit.__doc__
120
148
  predict.__doc__ = sklearn_SVR.predict.__doc__
149
+ score.__doc__ = sklearn_SVR.score.__doc__
@@ -14,9 +14,12 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
+ from functools import partial
17
18
  from inspect import isclass
18
19
 
19
20
  import numpy as np
21
+ from scipy import sparse as sp
22
+ from sklearn import clone
20
23
  from sklearn.base import (
21
24
  BaseEstimator,
22
25
  ClassifierMixin,
@@ -87,18 +90,26 @@ mixin_map = [
87
90
  ]
88
91
 
89
92
 
90
- SPECIAL_INSTANCES = {
91
- str(i): i
92
- for i in [
93
- LocalOutlierFactor(novelty=True),
94
- SVC(probability=True),
95
- NuSVC(probability=True),
96
- KNeighborsClassifier(algorithm="brute"),
97
- KNeighborsRegressor(algorithm="brute"),
98
- NearestNeighbors(algorithm="brute"),
99
- LogisticRegression(solver="newton-cg"),
100
- ]
101
- }
93
class _sklearn_clone_dict(dict):
    """Dict of estimator templates that hands out clones on lookup.

    ``d[key]`` and ``d.get(key)`` return ``sklearn.clone`` of the stored
    estimator, so a caller may fit or mutate the result without altering the
    shared template. Iteration via ``keys()``/``values()``/``items()`` is not
    overridden and yields the stored objects themselves.
    """

    def __getitem__(self, key):
        return clone(super().__getitem__(key))

    def get(self, key, default=None):
        # dict.get is implemented in C and does not call the overridden
        # __getitem__; without this override .get() would leak the shared
        # template instance instead of a clone.
        if key in self:
            return self[key]
        return default
97
+
98
+
99
# Estimators in explicit non-default configurations, keyed by their string
# representation. Wrapped in _sklearn_clone_dict so that indexing returns a
# fresh clone rather than the shared template instance.
SPECIAL_INSTANCES = _sklearn_clone_dict(
    (str(template), template)
    for template in (
        LocalOutlierFactor(novelty=True),
        SVC(probability=True),
        NuSVC(probability=True),
        KNeighborsClassifier(algorithm="brute"),
        KNeighborsRegressor(algorithm="brute"),
        NearestNeighbors(algorithm="brute"),
        LogisticRegression(solver="newton-cg"),
    )
)
102
113
 
103
114
 
104
115
  def gen_models_info(algorithms):
@@ -107,8 +118,8 @@ def gen_models_info(algorithms):
107
118
 
108
119
  if i in PATCHED_MODELS:
109
120
  est = PATCHED_MODELS[i]
110
- elif i in SPECIAL_INSTANCES:
111
- est = SPECIAL_INSTANCES[i].__class__
121
+ elif isinstance(algorithms[i], BaseEstimator):
122
+ est = algorithms[i].__class__
112
123
  else:
113
124
  raise KeyError(f"Unrecognized sklearnex estimator: {i}")
114
125
 
@@ -129,24 +140,54 @@ def gen_models_info(algorithms):
129
140
  return output
130
141
 
131
142
 
132
- def gen_dataset(estimator, queue=None, target_df=None, dtype=np.float64):
133
- dataset = None
134
- name = estimator.__class__.__name__
135
- est = PATCHED_MODELS[name]
143
+ def gen_dataset_type(est):
144
+ # est should be an estimator or estimator class
145
+ # dataset initialized to classification, but will be swapped
146
+ # for other types as necessary
147
+ dataset = "classification"
148
+ estimator = est.__class__ if isinstance(est, BaseEstimator) else est
149
+
136
150
  for mixin, _, data in mixin_map:
137
- if issubclass(est, mixin) and data is not None:
151
+ if issubclass(estimator, mixin) and data is not None:
138
152
  dataset = data
153
+ return dataset
154
+
155
+
156
+ _dataset_dict = {
157
+ "classification": [partial(load_iris, return_X_y=True)],
158
+ "regression": [partial(load_diabetes, return_X_y=True)],
159
+ }
160
+
161
+
162
+ def gen_dataset(
163
+ est,
164
+ datasets=_dataset_dict,
165
+ sparse=False,
166
+ queue=None,
167
+ target_df=None,
168
+ dtype=None,
169
+ ):
170
+ dataset_type = gen_dataset_type(est)
171
+ output = []
139
172
  # load data
140
- if dataset == "classification" or dataset is None:
141
- X, y = load_iris(return_X_y=True)
142
- elif dataset == "regression":
143
- X, y = load_diabetes(return_X_y=True)
144
- else:
145
- raise ValueError("Unknown dataset type")
146
-
147
- X = _convert_to_dataframe(X, sycl_queue=queue, target_df=target_df, dtype=dtype)
148
- y = _convert_to_dataframe(y, sycl_queue=queue, target_df=target_df, dtype=dtype)
149
- return X, y
173
+ flag = dtype is None
174
+
175
+ for func in datasets[dataset_type]:
176
+ X, y = func()
177
+ if flag:
178
+ dtype = X.dtype if hasattr(X, "dtype") else np.float64
179
+
180
+ if sparse:
181
+ X = sp.csr_matrix(X)
182
+ else:
183
+ X = _convert_to_dataframe(
184
+ X, sycl_queue=queue, target_df=target_df, dtype=dtype
185
+ )
186
+ y = _convert_to_dataframe(
187
+ y, sycl_queue=queue, target_df=target_df, dtype=dtype
188
+ )
189
+ output += [[X, y]]
190
+ return output
150
191
 
151
192
 
152
193
  DTYPES = [
@@ -0,0 +1,54 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import os
18
+ from glob import glob
19
+
20
+ import pytest
21
+
22
# Substrings of package-relative paths (prefixed with "sklearnex" + os.sep)
# whose files may legitimately reference target_offload.
ALLOWED_LOCATIONS = [
    "_config.py",
    "_device_offload.py",
    "test",
    "svc.py",
    "svm" + os.sep + "_common.py",
]


def test_target_offload_ban():
    """Check that ``target_offload`` is not referenced in sklearnex sources.

    Offloading computation to devices via ``target_offload`` should only be
    requested externally, not from within the architecture of the sklearnex
    classes. This is for clarity, traceability and maintainability. A file is
    exempt when its package-relative path contains any entry of
    ``ALLOWED_LOCATIONS``.
    """
    from sklearnex import __file__ as loc

    # Package directory with its trailing separator, so stripping it from a
    # file path leaves a package-relative path.
    path = loc.replace("__init__.py", "")
    files = [y for x in os.walk(path) for y in glob(os.path.join(x[0], "*.py"))]

    output = []
    for f in files:
        # Python sources are UTF-8 (PEP 3120): read them as such instead of
        # relying on the platform default (e.g. cp1252 on Windows), and close
        # every handle deterministically.
        with open(f, "r", encoding="utf-8") as fh:
            if "target_offload" in fh.read():
                output.append(f.replace(path, "sklearnex" + os.sep))

    # Drop files that are explicitly permitted to reference target_offload
    # (the "test" entry also exempts this test file itself).
    output = [
        i for i in output if not any(allowed in i for allowed in ALLOWED_LOCATIONS)
    ]

    output = "\n".join(output)
    assert output == "", f"target_offload is used in: \n{output}"
@@ -0,0 +1,290 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import gc
18
+ import logging
19
+ import os
20
+ import tracemalloc
21
+ import types
22
+ import warnings
23
+ from inspect import isclass
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+ import pytest
28
+ from scipy.stats import pearsonr
29
+ from sklearn.base import BaseEstimator, clone
30
+ from sklearn.datasets import make_classification
31
+ from sklearn.model_selection import KFold
32
+
33
+ from onedal import _is_dpc_backend
34
+ from onedal.tests.utils._dataframes_support import (
35
+ _convert_to_dataframe,
36
+ get_dataframes_and_queues,
37
+ )
38
+ from onedal.tests.utils._device_selection import get_queues, is_dpctl_available
39
+ from sklearnex import config_context
40
+ from sklearnex.tests._utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES
41
+ from sklearnex.utils import get_namespace
42
+
43
+ if _is_dpc_backend:
44
+ from onedal import _backend
45
+
46
+
47
# Names (keys of CPU_ESTIMATORS) excluded from the host memory-leak test.
CPU_SKIP_LIST = (
    # too slow on the data sizes used by these tests
    "TSNE",
    # configuration helpers: no allocations worth tracking
    "config_context",
    "get_config",
    "set_config",
    # known leak with Fortran-ordered numpy input; _fit_proba is the suspect
    "SVC(probability=True)",
    "NuSVC(probability=True)",
    # problems with dataframe_f input
    "IncrementalEmpiricalCovariance",
    # TODO: fix the leak seen in private CI for
    # data_shape = (1000, 100), data_transform_function = dataframe_f
    "IncrementalLinearRegression",
    "IncrementalPCA",
    # leak with Fortran-ordered input of shape (1000, 100)
    "LogisticRegression(solver='newton-cg')",
)
59
+
60
# Names (keys of GPU_ESTIMATORS) excluded from the device memory-leak test.
GPU_SKIP_LIST = (
    # too slow on the data sizes used by these tests
    "TSNE",
    "RandomForestRegressor",
    # no GPU offloading support
    "KMeans",
    # configuration helpers: no allocations worth tracking
    "config_context",
    "get_config",
    "set_config",
    # no GPU offloading support (falls back silently)
    "Ridge",
    "ElasticNet",
    "Lasso",
    "SVR",
    "NuSVR",
    "NuSVC",
    # default parameters are unsupported; see solver=newton-cg
    "LogisticRegression",
    # no GPU offloading support (falls back silently)
    "NuSVC(probability=True)",
    # potrf problem with this particular dataset
    "IncrementalLinearRegression",
    "LinearRegression",
)
78
+
79
+
80
def gen_functions(functions):
    """Adapt patched functions to the uniform ``f(x, y)`` call form.

    Non-estimator entries are invoked as ``f(x_train, y_train)`` by the
    memory tests. Three functions whose native signatures differ are
    replaced by wrappers:

    * ``roc_auc_score`` -> called with the labels ``y`` as both arguments.
    * ``pairwise_distances`` -> split into two entries,
      ``pairwise_distances(metric='cosine')`` and
      ``pairwise_distances(metric='correlation')``, each applied to ``x``.
    * ``_assert_all_finite`` -> applied to ``x`` and then to ``y``; the
      wrapper returns both results in a list.

    Other entries are kept unchanged. The wrapped entries end up after all
    other keys. The input mapping is not modified; a copy is returned.

    Raises
    ------
    KeyError
        If any of the three wrapped names is absent from ``functions``.
    """
    roc_auc_score = functions["roc_auc_score"]
    pairwise_distances = functions["pairwise_distances"]
    _assert_all_finite = functions["_assert_all_finite"]

    adapted = functions.copy()
    for wrapped_name in ("roc_auc_score", "pairwise_distances", "_assert_all_finite"):
        del adapted[wrapped_name]

    # Lambdas (rather than named defs) keep ``__name__ == "<lambda>"``.
    adapted["roc_auc_score"] = lambda x, y: roc_auc_score(y, y)
    adapted["pairwise_distances(metric='cosine')"] = lambda x, y: pairwise_distances(
        x, metric="cosine"
    )
    adapted["pairwise_distances(metric='correlation')"] = (
        lambda x, y: pairwise_distances(x, metric="correlation")
    )
    adapted["_assert_all_finite"] = lambda x, y: [
        _assert_all_finite(x),
        _assert_all_finite(y),
    ]
    return adapted
100
+
101
+
102
# Patched functions wrapped to the f(x, y) call form used by the tests.
FUNCTIONS = gen_functions(PATCHED_FUNCTIONS)

# Everything tested on the host: patched models, special instances and the
# wrapped functions, minus the CPU skip list.
CPU_ESTIMATORS = {
    k: v
    for k, v in {**PATCHED_MODELS, **SPECIAL_INSTANCES, **FUNCTIONS}.items()
    if k not in CPU_SKIP_LIST
}

# Device tests cover models and special instances only (no FUNCTIONS),
# minus the GPU skip list.
GPU_ESTIMATORS = {
    k: v
    for k, v in {**PATCHED_MODELS, **SPECIAL_INSTANCES}.items()
    if k not in GPU_SKIP_LIST
}

# (n_samples, n_features) of the generated classification data.
data_shapes = [
    pytest.param((1000, 100), id="(1000, 100)"),
    pytest.param((2000, 50), id="(2000, 50)"),
]

# Allowed extra memory, as a fraction of the input data size in bytes.
EXTRA_MEMORY_THRESHOLD = 0.15
# Number of KFold splits, i.e. fit/inference iterations per test case.
N_SPLITS = 10
# Memory-layout conversion applied to the generated features.
ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray}
124
+
125
+
126
def gen_clsf_data(n_samples, n_features):
    """Generate a reproducible two-class classification dataset.

    Returns
    -------
    tuple
        ``(data, label, nbytes)``, where ``nbytes`` is the combined size in
        bytes of the ``data`` and ``label`` array elements. It serves as the
        reference size for the memory thresholds.
    """
    features, targets = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_classes=2,
        random_state=777,
    )
    # ndarray.nbytes is size * itemsize for each array.
    total_bytes = features.nbytes + targets.nbytes
    return features, targets, total_bytes
135
+
136
+
137
def get_traced_memory(queue=None):
    """Return a memory-usage reading in bytes.

    On the DPC++ backend, with a GPU ``queue``, the reading is
    ``_backend.get_used_memory(queue)``. Otherwise it is the current size of
    memory blocks traced by :mod:`tracemalloc`.
    """
    uses_gpu = _is_dpc_backend and queue and queue.sycl_device.is_gpu
    if not uses_gpu:
        current, _peak = tracemalloc.get_traced_memory()
        return current
    return _backend.get_used_memory(queue)
142
+
143
+
144
def take(x, index, axis=0, queue=None):
    """Select entries of ``x`` along ``axis`` at positions ``index``.

    For array-API inputs (as reported by ``get_namespace``), the indices are
    first converted with ``xp.asarray(..., device=queue)`` and ``xp.take``
    is used. Any other input uses its own ``take`` method.
    """
    xp, is_array_api = get_namespace(x)
    if not is_array_api:
        return x.take(index, axis=axis)
    device_index = xp.asarray(index, device=queue)
    return xp.take(x, device_index, axis=axis)
150
+
151
+
152
def split_train_inference(kf, x, y, estimator, queue=None):
    """Fit and run inference once per split of ``kf``, sampling memory.

    For every ``(train, test)`` split produced by ``kf.split(x)``:

    * a ``BaseEstimator`` subclass is instantiated with defaults, and a
      ``BaseEstimator`` instance is cloned. The model is fitted on the
      training part, then the first of ``predict``, ``transform`` and
      ``kneighbors`` that it has is called on the test part;
    * anything else is treated as a callable ``estimator(x_train, y_train)``.

    The per-split model and data are deleted before each memory reading, so
    the reading reflects what survives the iteration.

    Returns
    -------
    list
        One ``get_traced_memory(queue)`` reading per split, in split order.
    """
    is_estimator_class = isclass(estimator) and issubclass(estimator, BaseEstimator)
    is_model = is_estimator_class or isinstance(estimator, BaseEstimator)

    readings = []
    for train_idx, test_idx in kf.split(x):
        x_train = take(x, train_idx, queue=queue)
        y_train = take(y, train_idx, queue=queue)
        x_test = take(x, test_idx, queue=queue)
        y_test = take(y, test_idx, queue=queue)

        if not is_model:
            estimator(x_train, y_train)
        else:
            model = estimator() if is_estimator_class else clone(estimator)
            model.fit(x_train, y_train)
            for method_name in ("predict", "transform", "kneighbors"):
                if hasattr(model, method_name):
                    getattr(model, method_name)(x_test)
                    break
            del model

        # Release this split's data before sampling memory.
        del x_train, x_test, y_train, y_test
        readings.append(get_traced_memory(queue))
    return readings
184
+
185
+
186
def _warn_if_memory_grows(mem_tracks):
    """Log a warning when memory readings rise steadily across iterations.

    The Pearson correlation between the readings and their iteration index
    is computed. When it exceeds 0.95, a warning reports the mean and
    standard deviation (rounded, in bytes) of the per-iteration increase.
    """
    mem_iter_diffs = np.array(mem_tracks[1:]) - np.array(mem_tracks[:-1])
    mem_incr_mean, mem_incr_std = mem_iter_diffs.mean(), mem_iter_diffs.std()
    mem_incr_mean, mem_incr_std = round(mem_incr_mean), round(mem_incr_std)
    with warnings.catch_warnings():
        # In the case that the memory usage is constant, this will raise
        # a ConstantInputWarning error in pearsonr from scipy, this can
        # be ignored.
        warnings.filterwarnings(
            "ignore",
            message="An input array is constant; the correlation coefficient is not defined",
        )
        mem_iter_corr, _ = pearsonr(mem_tracks, list(range(len(mem_tracks))))

    if mem_iter_corr > 0.95:
        logging.warning(
            "Memory usage is steadily increasing with iterations "
            "(Pearson correlation coefficient between "
            f"memory tracks and iterations is {mem_iter_corr})\n"
            "Memory usage increase per iteration: "
            f"{mem_incr_mean}±{mem_incr_std} bytes"
        )


def _extra_memory_message(stage, name, data_memory_size, mem_diff):
    """Build the report for extra memory measured ``stage`` GC ran.

    Formatted in a single pass, so braces in ``name`` (e.g. an estimator
    repr with dict parameters) are printed literally.
    """
    return (
        f"Size of extra allocated memory {stage} using garbage collector "
        f"is greater than {EXTRA_MEMORY_THRESHOLD * 100:g}% of input data"
        f"\n\tAlgorithm: {name}"
        f"\n\tInput data size: {data_memory_size} bytes"
        f"\n\tExtra allocated memory size: {mem_diff} bytes"
        f" / {round(mem_diff / data_memory_size * 100, 2)} %"
    )


def _kfold_function_template(estimator, dataframe, data_shape, queue=None, func=None):
    """Check that repeated fit/inference of ``estimator`` does not leak memory.

    The check runs these steps:

    * Generate classification data of ``data_shape`` and apply ``func`` if
      given (e.g. a memory-layout conversion).
    * Convert the data to ``dataframe`` on ``queue``.
    * Run ``split_train_inference`` over ``N_SPLITS`` KFold splits.
    * Log warnings for a steady growth trend, or for extra memory at or
      above ``EXTRA_MEMORY_THRESHOLD`` of the input size before garbage
      collection.
    * When ``queue`` is None or a CPU device, assert that the extra memory
      after ``gc.collect()`` stays below that threshold.

    tracemalloc is stopped on every exit path, so a failing estimator does
    not leave tracing enabled for later tests.
    """
    tracemalloc.start()
    try:
        n_samples, n_features = data_shape
        X, y, data_memory_size = gen_clsf_data(n_samples, n_features)
        kf = KFold(n_splits=N_SPLITS)
        if func:
            X = func(X)

        X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
        y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)

        mem_before = get_traced_memory(queue)
        mem_tracks = split_train_inference(kf, X, y, estimator, queue=queue)
        _warn_if_memory_grows(mem_tracks)

        mem_before_gc = get_traced_memory(queue)
        mem_diff = mem_before_gc - mem_before
        if isinstance(estimator, BaseEstimator):
            name = str(estimator)
        else:
            name = estimator.__name__

        if mem_diff >= EXTRA_MEMORY_THRESHOLD * data_memory_size:
            logging.warning(
                _extra_memory_message("before", name, data_memory_size, mem_diff)
            )
        gc.collect()
        mem_after = get_traced_memory(queue)
    finally:
        tracemalloc.stop()
    mem_diff = mem_after - mem_before

    # GPU offloading with SYCL contains a program/kernel cache which should
    # be controllable via a KernelProgramCache object in the SYCL context.
    # The programs and kernels are stored on the GPU, but cannot be cleared
    # as this class is not available for access in all oneDAL DPC++ runtimes.
    # Therefore, until this is implemented this test must be skipped for gpu
    # as it looks like a memory leak (at least there is no way to discern a
    # leak on the first run).
    if queue is None or queue.sycl_device.is_cpu:
        assert (
            mem_diff < EXTRA_MEMORY_THRESHOLD * data_memory_size
        ), _extra_memory_message("after", name, data_memory_size, mem_diff)
258
+
259
+
260
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize(
    "dataframe,queue", get_dataframes_and_queues("numpy,pandas,dpctl", "cpu")
)
@pytest.mark.parametrize("estimator", CPU_ESTIMATORS.keys())
@pytest.mark.parametrize("data_shape", data_shapes)
def test_memory_leaks(estimator, dataframe, queue, order, data_shape):
    """Run the host memory-leak check for one CPU estimator or function.

    ``_assert_all_finite`` is skipped whenever a queue is supplied.
    """
    offloaded = queue is not None
    if offloaded and estimator == "_assert_all_finite":
        pytest.skip(f"{estimator} is not designed for device offloading")

    layout = ORDER_DICT[order]
    _kfold_function_template(
        CPU_ESTIMATORS[estimator], dataframe, data_shape, queue, layout
    )
274
+
275
+
276
@pytest.mark.skipif(
    os.getenv("ZES_ENABLE_SYSMAN") is None or not is_dpctl_available("gpu"),
    reason="SYCL device memory leak check requires the level zero sysman",
)
@pytest.mark.parametrize("queue", get_queues("gpu"))
@pytest.mark.parametrize("estimator", GPU_ESTIMATORS.keys())
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("data_shape", data_shapes)
def test_gpu_memory_leaks(estimator, queue, order, data_shape):
    """Run the memory-leak check for one estimator with GPU offloading.

    Offloading is enabled through ``config_context(target_offload=queue)``.
    ExtraTrees estimators are skipped for the (2000, 50) shape.
    """
    if data_shape == (2000, 50) and "ExtraTrees" in estimator:
        pytest.skip("Avoid a segmentation fault in Extra Trees algorithms")

    layout = ORDER_DICT[order]
    with config_context(target_offload=queue):
        _kfold_function_template(
            GPU_ESTIMATORS[estimator], None, data_shape, queue, layout
        )