scikit-learn-intelex 2024.2.0__py312-none-win_amd64.whl → 2024.4.0__py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/__init__.py +9 -7
  2. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +31 -4
  3. {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex}/basic_statistics/__init__.py +2 -1
  4. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  5. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
  6. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -1
  7. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/conftest.py +63 -0
  8. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +335 -0
  9. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +22 -8
  10. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +74 -43
  11. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +78 -89
  12. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +15 -19
  13. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +316 -0
  14. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +63 -11
  15. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +40 -5
  16. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -2
  17. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +74 -20
  18. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +4 -1
  19. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +44 -131
  20. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +198 -221
  21. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +146 -0
  22. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -5
  23. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  24. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +5 -73
  25. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +6 -5
  26. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
  27. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  28. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +4 -7
  29. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +70 -50
  30. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +6 -52
  31. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +70 -51
  32. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +3 -49
  33. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +164 -0
  34. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +8 -3
  35. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +268 -0
  36. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +8 -2
  37. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  38. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +371 -0
  39. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +2 -1
  40. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py +97 -0
  41. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
  42. scikit_learn_intelex-2024.4.0.dist-info/RECORD +101 -0
  43. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  44. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  45. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -381
  46. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -308
  47. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -19
  48. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -374
  49. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -170
  50. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -240
  51. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -136
  52. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -118
  53. scikit_learn_intelex-2024.2.0.dist-info/RECORD +0 -101
  54. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  55. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  56. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  57. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  58. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  59. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  60. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  61. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  62. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  63. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -0
  64. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -0
  65. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  66. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  67. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  69. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  70. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  71. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  72. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  73. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  74. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  75. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  76. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  77. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  78. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  79. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  80. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  81. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  83. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  84. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  86. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  87. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  88. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd}/basic_statistics/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  91. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  92. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  93. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  94. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  96. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  97. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  98. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  99. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  100. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  101. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  102. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  104. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  105. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  106. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  107. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  108. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  109. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  110. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
  111. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
  112. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,268 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import sklearnex
18
+ from daal4py.sklearn._utils import daal_check_version
19
+
20
+ # General use of patch_sklearn and unpatch_sklearn in pytest is not recommended.
21
+ # It changes global state and can impact the operation of other tests. This file
22
+ # specifically tests patch_sklearn and unpatch_sklearn and is exempt from this.
23
+ # If sklearnex patching is necessary in testing, use the 'with_sklearnex' pytest
24
+ # fixture.
25
+
26
+
27
+ def test_monkey_patching():
28
+ _tokens = sklearnex.get_patch_names()
29
+ _values = sklearnex.get_patch_map().values()
30
+ _classes = list()
31
+
32
+ for v in _values:
33
+ for c in v:
34
+ _classes.append(c[0])
35
+
36
+ try:
37
+ sklearnex.patch_sklearn()
38
+
39
+ for i, _ in enumerate(_tokens):
40
+ t = _tokens[i]
41
+ p = _classes[i][0]
42
+ n = _classes[i][1]
43
+
44
+ class_module = getattr(p, n).__module__
45
+ assert class_module.startswith("daal4py") or class_module.startswith(
46
+ "sklearnex"
47
+ ), "Patching has completed with error."
48
+
49
+ for i, _ in enumerate(_tokens):
50
+ t = _tokens[i]
51
+ p = _classes[i][0]
52
+ n = _classes[i][1]
53
+
54
+ sklearnex.unpatch_sklearn(t)
55
+ sklearn_class = getattr(p, n, None)
56
+ if sklearn_class is not None:
57
+ sklearn_class = sklearn_class.__module__
58
+ assert sklearn_class is None or sklearn_class.startswith(
59
+ "sklearn"
60
+ ), "Unpatching has completed with error."
61
+
62
+ finally:
63
+ sklearnex.unpatch_sklearn()
64
+
65
+ try:
66
+ for i, _ in enumerate(_tokens):
67
+ t = _tokens[i]
68
+ p = _classes[i][0]
69
+ n = _classes[i][1]
70
+
71
+ sklearn_class = getattr(p, n, None)
72
+ if sklearn_class is not None:
73
+ sklearn_class = sklearn_class.__module__
74
+ assert sklearn_class is None or sklearn_class.startswith(
75
+ "sklearn"
76
+ ), "Unpatching has completed with error."
77
+
78
+ finally:
79
+ sklearnex.unpatch_sklearn()
80
+
81
+ try:
82
+ for i, _ in enumerate(_tokens):
83
+ t = _tokens[i]
84
+ p = _classes[i][0]
85
+ n = _classes[i][1]
86
+
87
+ sklearnex.patch_sklearn(t)
88
+
89
+ class_module = getattr(p, n).__module__
90
+ assert class_module.startswith("daal4py") or class_module.startswith(
91
+ "sklearnex"
92
+ ), "Patching has completed with error."
93
+ finally:
94
+ sklearnex.unpatch_sklearn()
95
+
96
+
97
+ def test_patch_by_list_simple():
98
+ try:
99
+ sklearnex.patch_sklearn(["LogisticRegression"])
100
+
101
+ from sklearn.ensemble import RandomForestRegressor
102
+ from sklearn.linear_model import LogisticRegression
103
+ from sklearn.neighbors import KNeighborsRegressor
104
+ from sklearn.svm import SVC
105
+
106
+ assert RandomForestRegressor.__module__.startswith("sklearn")
107
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
108
+ if daal_check_version((2024, "P", 1)):
109
+ assert LogisticRegression.__module__.startswith("sklearnex")
110
+ else:
111
+ assert LogisticRegression.__module__.startswith("daal4py")
112
+ assert SVC.__module__.startswith("sklearn")
113
+ finally:
114
+ sklearnex.unpatch_sklearn()
115
+
116
+
117
+ def test_patch_by_list_many_estimators():
118
+ try:
119
+ sklearnex.patch_sklearn(["LogisticRegression", "SVC"])
120
+
121
+ from sklearn.ensemble import RandomForestRegressor
122
+ from sklearn.linear_model import LogisticRegression
123
+ from sklearn.neighbors import KNeighborsRegressor
124
+ from sklearn.svm import SVC
125
+
126
+ assert RandomForestRegressor.__module__.startswith("sklearn")
127
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
128
+ if daal_check_version((2024, "P", 1)):
129
+ assert LogisticRegression.__module__.startswith("sklearnex")
130
+ else:
131
+ assert LogisticRegression.__module__.startswith("daal4py")
132
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
133
+ "sklearnex"
134
+ )
135
+
136
+ finally:
137
+ sklearnex.unpatch_sklearn()
138
+
139
+
140
+ def test_unpatch_by_list_many_estimators():
141
+ try:
142
+ sklearnex.patch_sklearn()
143
+
144
+ from sklearn.ensemble import RandomForestRegressor
145
+ from sklearn.linear_model import LogisticRegression
146
+ from sklearn.neighbors import KNeighborsRegressor
147
+ from sklearn.svm import SVC
148
+
149
+ assert RandomForestRegressor.__module__.startswith("sklearnex")
150
+ assert KNeighborsRegressor.__module__.startswith(
151
+ "daal4py"
152
+ ) or KNeighborsRegressor.__module__.startswith("sklearnex")
153
+ if daal_check_version((2024, "P", 1)):
154
+ assert LogisticRegression.__module__.startswith("sklearnex")
155
+ else:
156
+ assert LogisticRegression.__module__.startswith("daal4py")
157
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
158
+ "sklearnex"
159
+ )
160
+
161
+ sklearnex.unpatch_sklearn(["KNeighborsRegressor", "RandomForestRegressor"])
162
+
163
+ from sklearn.ensemble import RandomForestRegressor
164
+ from sklearn.linear_model import LogisticRegression
165
+ from sklearn.neighbors import KNeighborsRegressor
166
+ from sklearn.svm import SVC
167
+
168
+ assert RandomForestRegressor.__module__.startswith("sklearn")
169
+ assert KNeighborsRegressor.__module__.startswith("sklearn")
170
+ if daal_check_version((2024, "P", 1)):
171
+ assert LogisticRegression.__module__.startswith("sklearnex")
172
+ else:
173
+ assert LogisticRegression.__module__.startswith("daal4py")
174
+
175
+ assert SVC.__module__.startswith("daal4py") or SVC.__module__.startswith(
176
+ "sklearnex"
177
+ )
178
+ finally:
179
+ sklearnex.unpatch_sklearn()
180
+
181
+
182
+ def test_patching_checker():
183
+ for name in [None, "SVC", "PCA"]:
184
+ try:
185
+ sklearnex.patch_sklearn(name=name)
186
+ assert sklearnex.sklearn_is_patched(name=name)
187
+
188
+ finally:
189
+ sklearnex.unpatch_sklearn(name=name)
190
+ assert not sklearnex.sklearn_is_patched(name=name)
191
+ try:
192
+ sklearnex.patch_sklearn()
193
+ patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
194
+ assert len(patching_status_map) == len(sklearnex.get_patch_names())
195
+ for status in patching_status_map.values():
196
+ assert status
197
+ finally:
198
+ sklearnex.unpatch_sklearn()
199
+
200
+ patching_status_map = sklearnex.sklearn_is_patched(return_map=True)
201
+ assert len(patching_status_map) == len(sklearnex.get_patch_names())
202
+ for status in patching_status_map.values():
203
+ assert not status
204
+
205
+
206
+ def test_preview_namespace():
207
+ def get_estimators():
208
+ from sklearn.cluster import DBSCAN
209
+ from sklearn.decomposition import PCA
210
+ from sklearn.ensemble import RandomForestClassifier
211
+ from sklearn.linear_model import LinearRegression
212
+ from sklearn.svm import SVC
213
+
214
+ return (
215
+ LinearRegression(),
216
+ PCA(),
217
+ DBSCAN(),
218
+ SVC(),
219
+ RandomForestClassifier(),
220
+ )
221
+
222
+ from sklearnex.dispatcher import _is_preview_enabled
223
+
224
+ try:
225
+ sklearnex.patch_sklearn(preview=True)
226
+
227
+ assert _is_preview_enabled()
228
+
229
+ lr, pca, dbscan, svc, rfc = get_estimators()
230
+ assert "sklearnex" in rfc.__module__
231
+
232
+ if daal_check_version((2023, "P", 100)):
233
+ assert "sklearnex" in lr.__module__
234
+ else:
235
+ assert "daal4py" in lr.__module__
236
+
237
+ assert "sklearnex" in pca.__module__
238
+ assert "sklearnex" in dbscan.__module__
239
+ assert "sklearnex" in svc.__module__
240
+
241
+ finally:
242
+ sklearnex.unpatch_sklearn()
243
+
244
+ # no patching behavior
245
+ lr, pca, dbscan, svc, rfc = get_estimators()
246
+ assert "sklearn." in lr.__module__ and "daal4py" not in lr.__module__
247
+ assert "sklearn." in pca.__module__ and "daal4py" not in pca.__module__
248
+ assert "sklearn." in dbscan.__module__ and "daal4py" not in dbscan.__module__
249
+ assert "sklearn." in svc.__module__ and "daal4py" not in svc.__module__
250
+ assert "sklearn." in rfc.__module__ and "daal4py" not in rfc.__module__
251
+
252
+ # default patching behavior
253
+ try:
254
+ sklearnex.patch_sklearn()
255
+ assert not _is_preview_enabled()
256
+
257
+ lr, pca, dbscan, svc, rfc = get_estimators()
258
+ if daal_check_version((2023, "P", 100)):
259
+ assert "sklearnex" in lr.__module__
260
+ else:
261
+ assert "daal4py" in lr.__module__
262
+
263
+ assert "sklearnex" in pca.__module__
264
+ assert "sklearnex" in rfc.__module__
265
+ assert "sklearnex" in dbscan.__module__
266
+ assert "sklearnex" in svc.__module__
267
+ finally:
268
+ sklearnex.unpatch_sklearn()
@@ -84,10 +84,16 @@ def test_n_jobs_support(caplog, estimator_class, n_jobs):
84
84
  if method_name == "fit":
85
85
  continue
86
86
  method = getattr(estimator_instance, method_name)
87
- if len(inspect.signature(method).parameters) == 0:
87
+ argdict = inspect.signature(method).parameters
88
+ argnum = len(
89
+ [i for i in argdict if argdict[i].default == inspect.Parameter.empty]
90
+ )
91
+ if argnum == 0:
88
92
  check_method(method=method, caplog=caplog)
89
- else:
93
+ elif argnum == 1:
90
94
  check_method(X, method=method, caplog=caplog)
95
+ else:
96
+ check_method(X, Y, method=method, caplog=caplog)
91
97
  # check if correct methods were decorated
92
98
  check_methods_decoration(estimator_class)
93
99
  check_methods_decoration(estimator_instance)
@@ -15,13 +15,7 @@
15
15
  # ==============================================================================
16
16
  import pytest
17
17
 
18
- from sklearnex import config_context, patch_sklearn
19
-
20
- patch_sklearn()
21
-
22
- from sklearn.datasets import make_classification
23
- from sklearn.ensemble import BaggingClassifier
24
- from sklearn.svm import SVC
18
+ from sklearnex import config_context
25
19
 
26
20
  try:
27
21
  import dpctl
@@ -38,7 +32,11 @@ except (ImportError, ModuleNotFoundError):
38
32
  "to see raised 'SyclQueueCreationError'. "
39
33
  "'dpctl' module is required for test.",
40
34
  )
41
- def test_config_context_in_parallel():
35
+ def test_config_context_in_parallel(with_sklearnex):
36
+ from sklearn.datasets import make_classification
37
+ from sklearn.ensemble import BaggingClassifier
38
+ from sklearn.svm import SVC
39
+
42
40
  x, y = make_classification(random_state=42)
43
41
  try:
44
42
  with config_context(target_offload="gpu", allow_fallback_to_host=False):
@@ -0,0 +1,371 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+
18
+ import importlib
19
+ import inspect
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ from inspect import signature
25
+
26
+ import numpy as np
27
+ import numpy.random as nprnd
28
+ import pytest
29
+ from sklearn.base import BaseEstimator
30
+
31
+ from daal4py.sklearn._utils import sklearn_check_version
32
+ from onedal.tests.utils._dataframes_support import (
33
+ _convert_to_dataframe,
34
+ get_dataframes_and_queues,
35
+ )
36
+ from sklearnex import is_patched_instance
37
+ from sklearnex.dispatcher import _is_preview_enabled
38
+ from sklearnex.metrics import pairwise_distances, roc_auc_score
39
+ from sklearnex.tests._utils import (
40
+ DTYPES,
41
+ PATCHED_FUNCTIONS,
42
+ PATCHED_MODELS,
43
+ SPECIAL_INSTANCES,
44
+ UNPATCHED_FUNCTIONS,
45
+ UNPATCHED_MODELS,
46
+ gen_dataset,
47
+ gen_models_info,
48
+ )
49
+
50
+
51
+ @pytest.mark.parametrize("dtype", DTYPES)
52
+ @pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
53
+ @pytest.mark.parametrize("metric", ["cosine", "correlation"])
54
+ def test_pairwise_distances_patching(caplog, dataframe, queue, dtype, metric):
55
+ with caplog.at_level(logging.WARNING, logger="sklearnex"):
56
+ if dtype == np.float16 and queue and not queue.sycl_device.has_aspect_fp16:
57
+ pytest.skip("Hardware does not support fp16 SYCL testing")
58
+ elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
59
+ pytest.skip("Hardware does not support fp64 SYCL testing")
60
+ elif queue and queue.sycl_device.is_gpu:
61
+ pytest.skip("pairwise_distances does not support GPU queues")
62
+
63
+ rng = nprnd.default_rng()
64
+ X = _convert_to_dataframe(
65
+ rng.random(size=1000).reshape(1, -1),
66
+ sycl_queue=queue,
67
+ target_df=dataframe,
68
+ dtype=dtype,
69
+ )
70
+
71
+ _ = pairwise_distances(X, metric=metric)
72
+ assert all(
73
+ [
74
+ "running accelerated version" in i.message
75
+ or "fallback to original Scikit-learn" in i.message
76
+ for i in caplog.records
77
+ ]
78
+ ), f"sklearnex patching issue in pairwise_distances with log: \n{caplog.text}"
79
+
80
+
81
+ @pytest.mark.parametrize(
82
+ "dtype", [i for i in DTYPES if "32" in i.__name__ or "64" in i.__name__]
83
+ )
84
+ @pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
85
+ def test_roc_auc_score_patching(caplog, dataframe, queue, dtype):
86
+ if dtype in [np.uint32, np.uint64] and sys.platform == "win32":
87
+ pytest.skip("Windows issue with unsigned ints")
88
+ elif dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
89
+ pytest.skip("Hardware does not support fp64 SYCL testing")
90
+
91
+ with caplog.at_level(logging.WARNING, logger="sklearnex"):
92
+ rng = nprnd.default_rng()
93
+ X = _convert_to_dataframe(
94
+ rng.integers(2, size=1000),
95
+ sycl_queue=queue,
96
+ target_df=dataframe,
97
+ dtype=dtype,
98
+ )
99
+ y = _convert_to_dataframe(
100
+ rng.integers(2, size=1000),
101
+ sycl_queue=queue,
102
+ target_df=dataframe,
103
+ dtype=dtype,
104
+ )
105
+
106
+ _ = roc_auc_score(X, y)
107
+ assert all(
108
+ [
109
+ "running accelerated version" in i.message
110
+ or "fallback to original Scikit-learn" in i.message
111
+ for i in caplog.records
112
+ ]
113
+ ), f"sklearnex patching issue in roc_auc_score with log: \n{caplog.text}"
114
+
115
+
116
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
@pytest.mark.parametrize("estimator, method", gen_models_info(PATCHED_MODELS))
def test_standard_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
    """Fit a patched estimator, call one method, and vet the captured log.

    A fresh instance of ``PATCHED_MODELS[estimator]`` is fitted on data from
    ``gen_dataset``; when ``method`` is set it is then invoked (``score``
    receives ``X, y``, every other method receives ``X``). Every record
    captured from the "sklearnex" logger must contain either the
    accelerated-run text or the fallback text.
    """
    accepted_markers = (
        "running accelerated version",
        "fallback to original Scikit-learn",
    )

    with caplog.at_level(logging.WARNING, logger="sklearnex"):
        model = PATCHED_MODELS[estimator]()

        # Hardware-related skips only apply when a SYCL queue is in use.
        if queue:
            device = queue.sycl_device
            if dtype == np.float16 and not device.has_aspect_fp16:
                pytest.skip("Hardware does not support fp16 SYCL testing")
            if dtype == np.float64 and not device.has_aspect_fp64:
                pytest.skip("Hardware does not support fp64 SYCL testing")
            if device.is_gpu and estimator in ("KMeans", "ElasticNet", "Lasso", "Ridge"):
                pytest.skip(f"{estimator} does not support GPU queues")

        # Estimator/method specific skips.
        if estimator == "TSNE" and method == "fit_transform":
            pytest.skip("TSNE.fit_transform is too slow for common testing")
        ridge_unsigned_on_windows = (
            estimator == "Ridge"
            and method in ["predict", "score"]
            and sys.platform == "win32"
            and dtype in [np.uint32, np.uint64]
        )
        if ridge_unsigned_on_windows:
            pytest.skip("Windows segmentation fault for Ridge.predict for unsigned ints")
        if method and not hasattr(model, method):
            pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")

        X, y = gen_dataset(model, queue=queue, target_df=dataframe, dtype=dtype)
        model.fit(X, y)

        if method == "score":
            model.score(X, y)
        elif method:
            getattr(model, method)(X)

    only_expected_records = all(
        any(marker in record.message for marker in accepted_markers)
        for record in caplog.records
    )
    assert (
        only_expected_records
    ), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
163
+
164
+
165
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("dataframe, queue", get_dataframes_and_queues())
@pytest.mark.parametrize("estimator, method", gen_models_info(SPECIAL_INSTANCES))
def test_special_estimator_patching(caplog, dataframe, queue, dtype, estimator, method):
    """Fit a pre-built special instance, call one method, and vet the log.

    The object stored in ``SPECIAL_INSTANCES[estimator]`` is used directly
    (it is not re-instantiated here). After fitting, the test is skipped if
    the object does not expose ``method``; otherwise the method is invoked
    (``score`` receives ``X, y``, every other method receives ``X``). Every
    record captured from the "sklearnex" logger must contain either the
    accelerated-run text or the fallback text.
    """
    accepted_markers = (
        "running accelerated version",
        "fallback to original Scikit-learn",
    )

    # Capture sklearnex log output for the whole fit/method sequence.
    with caplog.at_level(logging.WARNING, logger="sklearnex"):
        model = SPECIAL_INSTANCES[estimator]

        # It's not possible to get the dpnp/dpctl arrays to be in the proper
        # dtype when the device lacks the corresponding aspect.
        if dtype == np.float16 and queue and not queue.sycl_device.has_aspect_fp16:
            pytest.skip("Hardware does not support fp16 SYCL testing")
        if dtype == np.float64 and queue and not queue.sycl_device.has_aspect_fp64:
            pytest.skip("Hardware does not support fp64 SYCL testing")

        X, y = gen_dataset(model, queue=queue, target_df=dataframe, dtype=dtype)
        model.fit(X, y)

        # Availability of the method is only checked once the model is fitted.
        if method and not hasattr(model, method):
            pytest.skip(f"sklearn available_if prevents testing {estimator}.{method}")

        if method == "score":
            model.score(X, y)
        elif method:
            getattr(model, method)(X)

    only_expected_records = all(
        any(marker in record.message for marker in accepted_markers)
        for record in caplog.records
    )
    assert (
        only_expected_records
    ), f"sklearnex patching issue in {estimator}.{method} with log: \n{caplog.text}"
199
+
200
+
201
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_standard_estimator_signatures(estimator):
    """Compare public method signatures of patched and unpatched estimators.

    Every callable public attribute of the unpatched estimator instance
    (name neither starting nor ending with an underscore) must have the
    same signature string on the patched instance.
    """
    patched_obj = PATCHED_MODELS[estimator]()
    stock_obj = UNPATCHED_MODELS[estimator]()

    # Module structure differs between the packages, so module-qualified
    # references to the estimator are collapsed to the bare class name.
    qualified_name = re.compile(rf"(?:sklearn|daal4py)\S*{estimator}")

    # all public sklearn methods should have signature matches in sklearnex
    public_names = [
        name
        for name in dir(stock_obj)
        if not (name.startswith("_") or name.endswith("_")) and hasattr(stock_obj, name)
    ]

    for method in public_names:
        patched_member = getattr(patched_obj, method)
        stock_member = getattr(stock_obj, method)
        if not callable(stock_member):
            continue
        patched_sig = qualified_name.sub(estimator, str(signature(patched_member)))
        unpatched_sig = qualified_name.sub(estimator, str(signature(stock_member)))
        assert (
            patched_sig == unpatched_sig
        ), f"Signature of {estimator}.{method} does not match sklearn"
223
+
224
+
225
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_standard_estimator_init_signatures(estimator):
    """Compare ``__init__`` signatures after dropping known deviations.

    Before comparison, ``n_jobs=None`` and the bare ``*`` marker are removed
    from both signature strings, and for the random-forest / extra-trees
    estimators the sklearnex-only ``min_bin_size=1`` and ``max_bins=256``
    entries are removed from the patched signature.
    """

    def drop_entries(sig_text, entries):
        # Remove each ", <entry>" fragment from the signature string.
        for entry in entries:
            sig_text = sig_text.replace(f", {entry}", "")
        return sig_text

    # Sklearnex allows for positional kwargs and n_jobs, when sklearn doesn't
    shared_deviations = ("n_jobs=None", "*")
    # Special sklearnex-specific kwargs of the forest-type estimators
    forest_only_kwargs = ("min_bin_size=1", "max_bins=256")
    forest_estimators = (
        "RandomForestRegressor",
        "RandomForestClassifier",
        "ExtraTreesRegressor",
        "ExtraTreesClassifier",
    )

    patched_sig = str(signature(PATCHED_MODELS[estimator].__init__))
    unpatched_sig = str(signature(UNPATCHED_MODELS[estimator].__init__))

    patched_sig = drop_entries(patched_sig, shared_deviations)
    unpatched_sig = drop_entries(unpatched_sig, shared_deviations)

    if estimator in forest_estimators:
        patched_sig = drop_entries(patched_sig, forest_only_kwargs)

    assert (
        patched_sig == unpatched_sig
    ), f"Signature of {estimator}.__init__ does not match sklearn"
251
+
252
+
253
@pytest.mark.parametrize(
    "function",
    [
        name
        for name in UNPATCHED_FUNCTIONS
        if name not in ["train_test_split", "set_config", "config_context"]
    ],
)
def test_patched_function_signatures(function):
    """Check that a patched function keeps the signature of the original.

    Functions excluded in the parametrization are dropped from the test
    because they add functionality to the underlying sklearn function.
    ``_assert_all_finite`` is additionally skipped when
    ``sklearn_check_version("1.1")`` is false.
    """
    if not sklearn_check_version("1.1") and function == "_assert_all_finite":
        pytest.skip("Sklearn versioning not added to _assert_all_finite")

    func = PATCHED_FUNCTIONS[function]
    reference = UNPATCHED_FUNCTIONS[function]

    # Non-callable entries have no signature to compare.
    if not callable(reference):
        return

    patched_text = str(signature(func))
    reference_text = str(signature(reference))
    assert patched_text == reference_text, f"Signature of {func} does not match sklearn"
273
+
274
+
275
def test_patch_map_match():
    """Check that the patch map and the submodule ``__all__`` lists agree.

    This rule applies to functions and classes which are out of preview.
    Items listed in a matching submodule's ``__all__`` attribute should be
    in the patch map, with no missing or additional elements. An unexpected
    name surfaces as a ``KeyError`` from the deletions below; a leftover
    name fails the final assertion.
    """

    def list_all_attr(string):
        # Names exported via the module's __all__; a module that cannot be
        # found is represented by the single placeholder None.
        try:
            return set(importlib.import_module(string).__all__)
        except ModuleNotFoundError:
            return {None}

    if _is_preview_enabled():
        pytest.skip("preview sklearnex has been activated")

    patched = dict(PATCHED_MODELS)
    patched.update(PATCHED_FUNCTIONS)

    # Top-level names exported by both packages.
    shared_names = list_all_attr("sklearnex") & list_all_attr("sklearn")

    # _assert_all_finite patches an internal sklearn function which isn't
    # exposed via __all__ in sklearn. It is a special case where this rule
    # is not applied (e.g. it is grandfathered in).
    del patched["_assert_all_finite"]

    # remove all scikit-learn-intelex-only estimators
    patched = {
        name: obj
        for name, obj in patched.items()
        if name in UNPATCHED_MODELS or name in UNPATCHED_FUNCTIONS
    }

    for module in shared_names:
        sklearn_exports = list_all_attr("sklearn." + module)
        sklearnex_exports = list_all_attr("sklearnex." + module)

        for exported in sklearnex_exports & sklearn_exports:
            # A None placeholder means the shared name is not an importable
            # submodule, so the name itself is the patched item.
            del patched[exported if exported else module]

    assert patched == {}, f"{patched.keys()} were not properly patched"
317
+
318
+
319
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_is_patched_instance(estimator):
    """Check ``is_patched_instance`` on the patched and unpatched entries.

    The object from ``PATCHED_MODELS`` must be reported as patched and the
    object from ``UNPATCHED_MODELS`` must not. The assertion messages
    describe the failure that was actually observed, since they are only
    shown when the corresponding check fails.
    """
    patched = PATCHED_MODELS[estimator]
    unpatched = UNPATCHED_MODELS[estimator]
    assert is_patched_instance(
        patched
    ), f"{patched} is not recognized as a patched instance"
    assert not is_patched_instance(
        unpatched
    ), f"{unpatched} is wrongly recognized as a patched instance"
325
+
326
+
327
@pytest.mark.parametrize("estimator", PATCHED_MODELS.keys())
def test_if_estimator_inherits_sklearn(estimator):
    """Check the base class of every entry in ``PATCHED_MODELS``.

    An entry that also appears in ``UNPATCHED_MODELS`` must subclass that
    counterpart; any other entry must subclass ``BaseEstimator``.
    """
    patched_cls = PATCHED_MODELS[estimator]

    if estimator not in UNPATCHED_MODELS:
        assert issubclass(patched_cls, BaseEstimator)
        return

    assert issubclass(
        patched_cls, UNPATCHED_MODELS[estimator]
    ), f"{estimator} does not inherit from the patched sklearn estimator"
336
+
337
+
338
@pytest.mark.parametrize("estimator", UNPATCHED_MODELS.keys())
def test_docstring_patching_match(estimator):
    """Check that docstring presence matches between patched and unpatched.

    Only the presence (``None`` or not) of ``__doc__`` is compared, for the
    class itself and for each public attribute of the unpatched class
    (name neither starting nor ending with an underscore).
    """

    def public_docs(cls):
        # Map each public attribute name of cls to that attribute's __doc__.
        docs = {}
        for name in dir(cls):
            if name.startswith("_") or name.endswith("_"):
                continue
            if hasattr(cls, name):
                docs[name] = getattr(cls, name).__doc__
        return docs

    patched = PATCHED_MODELS[estimator]
    unpatched = UNPATCHED_MODELS[estimator]
    patched_docstrings = public_docs(patched)
    unpatched_docstrings = public_docs(unpatched)

    # check class docstring match if a docstring is available
    assert (patched.__doc__ is None) == (unpatched.__doc__ is None)

    # check class attribute docstrings
    for name, reference_doc in unpatched_docstrings.items():
        assert (patched_docstrings[name] is None) == (reference_doc is None)
361
+
362
+
363
@pytest.mark.parametrize("member", ["_onedal_cpu_supported", "_onedal_gpu_supported"])
@pytest.mark.parametrize(
    "name",
    [
        model_name
        for model_name, model in PATCHED_MODELS.items()
        if "sklearnex" in model.__module__
    ],
)
def test_onedal_supported_member(name, member):
    """Check the signature of the oneDAL support hooks.

    For every patched model whose ``__module__`` contains "sklearnex", the
    ``_onedal_cpu_supported`` and ``_onedal_gpu_supported`` members must
    have exactly the signature ``(self, method_name, *data)``.
    """
    hook = getattr(PATCHED_MODELS[name], member)
    sig = str(inspect.signature(hook))
    assert sig == "(self, method_name, *data)"
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
  # ===============================================================================
16
16
 
17
+ from ._namespace import get_namespace
17
18
  from .validation import _assert_all_finite
18
19
 
19
- __all__ = ["_assert_all_finite"]
20
+ __all__ = ["get_namespace", "_assert_all_finite"]