scikit-learn-intelex 2024.5.0__py39-none-win_amd64.whl → 2024.7.0__py39-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the registry's advisory page for this release for more details.

Files changed (129)
  1. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/_config.py +3 -15
  2. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +98 -0
  3. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +143 -0
  4. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +251 -0
  5. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +1 -1
  6. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -1
  7. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +8 -0
  8. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
  9. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +15 -3
  10. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/conftest.py +11 -1
  11. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +64 -13
  12. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +35 -0
  13. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +25 -1
  14. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +4 -2
  15. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +109 -1
  16. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +121 -57
  17. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +7 -0
  18. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
  19. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
  20. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +102 -25
  21. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +25 -39
  22. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +92 -74
  23. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +7 -0
  24. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +10 -10
  25. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +30 -5
  26. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +45 -3
  27. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +21 -0
  28. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
  29. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
  30. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +3 -0
  31. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +9 -0
  32. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +45 -1
  33. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +1 -20
  34. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +25 -20
  35. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +31 -7
  36. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  37. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  38. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +228 -0
  39. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  40. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py → scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +19 -17
  41. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +419 -0
  42. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  43. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  44. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  45. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  46. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  47. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  48. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  49. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  50. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +163 -0
  51. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  52. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +328 -0
  53. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +40 -4
  54. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +31 -2
  55. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +40 -4
  56. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +31 -2
  57. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
  58. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +328 -0
  59. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/_utils_spmd.py +185 -0
  60. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +54 -0
  61. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +4 -0
  62. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +290 -0
  63. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +12 -4
  64. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +21 -25
  65. scikit_learn_intelex-2024.7.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +295 -0
  66. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/_namespace.py +1 -1
  67. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/METADATA +5 -2
  68. scikit_learn_intelex-2024.7.0.dist-info/RECORD +122 -0
  69. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/WHEEL +1 -1
  70. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -257
  71. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -17
  72. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -185
  73. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +0 -173
  74. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -231
  75. scikit_learn_intelex-2024.5.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
  76. scikit_learn_intelex-2024.5.0.dist-info/RECORD +0 -104
  77. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
  78. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  79. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  80. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  81. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +0 -0
  82. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  83. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  84. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  86. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  87. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  88. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  90. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  91. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  92. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  93. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  94. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  96. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  97. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -0
  98. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  99. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  100. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -0
  101. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  102. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
  103. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  104. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  105. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  106. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  107. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  108. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  109. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  110. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  111. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  112. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  113. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  114. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  115. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
  116. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  117. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  118. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  119. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  120. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  121. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  122. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
  123. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
  124. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  125. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  126. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +0 -0
  127. {scikit_learn_intelex-2024.5.0.data → scikit_learn_intelex-2024.7.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  128. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/LICENSE.txt +0 -0
  129. {scikit_learn_intelex-2024.5.0.dist-info → scikit_learn_intelex-2024.7.0.dist-info}/top_level.txt +0 -0
@@ -14,24 +14,12 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
- import threading
18
17
  from contextlib import contextmanager
19
18
 
20
19
  from sklearn import get_config as skl_get_config
21
20
  from sklearn import set_config as skl_set_config
22
21
 
23
- _default_global_config = {
24
- "target_offload": "auto",
25
- "allow_fallback_to_host": False,
26
- }
27
-
28
- _threadlocal = threading.local()
29
-
30
-
31
- def _get_sklearnex_threadlocal_config():
32
- if not hasattr(_threadlocal, "global_config"):
33
- _threadlocal.global_config = _default_global_config.copy()
34
- return _threadlocal.global_config
22
+ from onedal._config import _get_config as onedal_get_config
35
23
 
36
24
 
37
25
  def get_config():
@@ -46,7 +34,7 @@ def get_config():
46
34
  set_config : Set global configuration.
47
35
  """
48
36
  sklearn = skl_get_config()
49
- sklearnex = _get_sklearnex_threadlocal_config().copy()
37
+ sklearnex = onedal_get_config()
50
38
  return {**sklearn, **sklearnex}
51
39
 
52
40
 
@@ -70,7 +58,7 @@ def set_config(target_offload=None, allow_fallback_to_host=None, **sklearn_confi
70
58
  """
71
59
  skl_set_config(**sklearn_configs)
72
60
 
73
- local_config = _get_sklearnex_threadlocal_config()
61
+ local_config = onedal_get_config(copy=False)
74
62
 
75
63
  if target_offload is not None:
76
64
  local_config["target_offload"] = target_offload
@@ -0,0 +1,98 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from functools import wraps
18
+
19
+ from onedal._device_offload import (
20
+ _copy_to_usm,
21
+ _get_global_queue,
22
+ _transfer_to_host,
23
+ dpnp_available,
24
+ )
25
+
26
+ if dpnp_available:
27
+ import dpnp
28
+ from onedal._device_offload import _convert_to_dpnp
29
+
30
+
31
+ from ._config import get_config
32
+
33
+
34
def _get_backend(obj, queue, method_name, *data):
    """Choose the backend ("onedal" or "sklearn") that should run ``method_name``.

    Parameters
    ----------
    obj : estimator
        Must provide ``_onedal_cpu_supported`` and ``_onedal_gpu_supported``.
        Each returns a status object exposing ``get_status()``.
    queue : SYCL queue or None
        ``None`` or a queue whose ``sycl_device.is_cpu`` is true selects the
        CPU checks. A queue whose ``sycl_device.is_gpu`` is true selects the
        GPU checks.
    method_name : str
        Name of the estimator method being dispatched.
    *data
        Arguments forwarded unchanged to the support checks.

    Returns
    -------
    tuple
        ``(backend, queue, patching_status)``. The returned queue is ``None``
        whenever the "sklearn" backend is selected, or when oneDAL falls back
        to the host.

    Raises
    ------
    RuntimeError
        If ``queue`` targets a device that is neither a CPU nor a GPU.
    """
    runs_on_cpu = queue is None or queue.sycl_device.is_cpu
    runs_on_gpu = queue is not None and queue.sycl_device.is_gpu

    if runs_on_cpu:
        cpu_status = obj._onedal_cpu_supported(method_name, *data)
        if not cpu_status.get_status():
            return "sklearn", None, cpu_status
        return "onedal", queue, cpu_status

    # The host-fallback flag is only read for non-CPU queues.
    fallback_allowed = get_config()["allow_fallback_to_host"]

    if not runs_on_gpu:
        raise RuntimeError("Device support is not implemented")

    gpu_status = obj._onedal_gpu_supported(method_name, *data)
    if gpu_status.get_status():
        return "onedal", queue, gpu_status
    if not fallback_allowed:
        return "sklearn", None, gpu_status

    # The GPU path is unsupported but host fallback is permitted: re-run the
    # CPU check and, if it passes, run oneDAL on the host (queue=None).
    host_status = obj._onedal_cpu_supported(method_name, *data)
    if host_status.get_status():
        return "onedal", None, host_status
    return "sklearn", None, host_status
62
+
63
+
64
def dispatch(obj, method_name, branches, *args, **kwargs):
    """Route an estimator call to its oneDAL or scikit-learn implementation.

    Positional and keyword arguments are first passed through
    ``_transfer_to_host``. Presumably this moves device (USM) data to the
    host and reports the queue it came from; confirm in
    ``onedal._device_offload``. ``_get_backend`` then picks the backend, the
    patching decision is logged, and the matching entry of ``branches`` is
    called.

    Parameters
    ----------
    obj : estimator
        Passed as the first argument to the selected branch.
    method_name : str
        Name of the method being dispatched (used for support checks and errors).
    branches : dict
        Maps ``"onedal"`` and ``"sklearn"`` to callables. The oneDAL callable
        additionally receives a ``queue=`` keyword argument.
    *args, **kwargs
        Data and options forwarded (after host transfer) to the branch.

    Returns
    -------
    object
        Whatever the selected branch returns.

    Raises
    ------
    RuntimeError
        If ``_get_backend`` reports a backend other than "onedal" or "sklearn".
    """
    queue = _get_global_queue()
    queue, host_args = _transfer_to_host(queue, *args)
    queue, host_kw_values = _transfer_to_host(queue, *kwargs.values())
    host_kwargs = dict(zip(kwargs, host_kw_values))

    backend, queue, patching_status = _get_backend(
        obj, queue, method_name, *host_args
    )

    if backend == "onedal":
        patching_status.write_log(queue=queue)
        return branches["onedal"](obj, *host_args, **host_kwargs, queue=queue)
    if backend == "sklearn":
        patching_status.write_log()
        return branches["sklearn"](obj, *host_args, **host_kwargs)
    raise RuntimeError(
        f"Undefined backend {backend} in {obj.__class__.__name__}.{method_name}"
    )
81
+
82
+
83
def wrap_output_data(func):
    """Decorate an estimator method so its result follows the input's device.

    The "first input" is the first positional argument after ``self``. If
    there are no positional arguments, it is the first keyword value in call
    order. If that input exposes ``__sycl_usm_array_interface__``, the
    method's result is copied to USM via ``_copy_to_usm`` using the
    interface's ``"syclobj"`` entry. If dpnp is available and the first input
    is a ``dpnp.ndarray``, the result is additionally converted with
    ``_convert_to_dpnp``. Otherwise the result is returned unchanged.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Same element as (*args, *kwargs.values())[0], or None when no
        # arguments were given (getattr(None, ..., None) is None).
        if args:
            first_input = args[0]
        else:
            first_input = next(iter(kwargs.values()), None)
        sycl_iface = getattr(first_input, "__sycl_usm_array_interface__", None)

        output = func(self, *args, **kwargs)
        if sycl_iface is None:
            return output

        output = _copy_to_usm(sycl_iface["syclobj"], output)
        if dpnp_available and isinstance(first_input, dpnp.ndarray):
            output = _convert_to_dpnp(output)
        return output

    return wrapper
@@ -0,0 +1,143 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numpy as np
18
+ from sklearn.base import BaseEstimator
19
+ from sklearn.utils import check_array
20
+ from sklearn.utils.validation import _check_sample_weight
21
+
22
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
23
+ from daal4py.sklearn._utils import sklearn_check_version
24
+ from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
25
+
26
+ from .._device_offload import dispatch
27
+ from .._utils import PatchingConditionsChain
28
+
29
+
30
@control_n_jobs(decorated_methods=["fit"])
class BasicStatistics(BaseEstimator):
    """
    Estimator for basic statistics.

    Allows to compute basic statistics for provided data. After ``fit``, each
    requested statistic is stored as an attribute of the same name. Only the
    statistics selected through ``result_options`` are set.

    Parameters
    ----------
    result_options : str or list, default='all'
        Name of the statistic, or list of names, to compute. ``'all'``
        requests every option returned by
        ``onedal.basic_statistics.BasicStatistics.get_all_result_options()``.

    Attributes
    ----------
    min : ndarray of shape (n_features,)
        Minimum of each feature over all samples.
    max : ndarray of shape (n_features,)
        Maximum of each feature over all samples.
    sum : ndarray of shape (n_features,)
        Sum of each feature over all samples.
    mean : ndarray of shape (n_features,)
        Mean of each feature over all samples.
    variance : ndarray of shape (n_features,)
        Variance of each feature over all samples.
    variation : ndarray of shape (n_features,)
        Variation of each feature over all samples.
    sum_squares : ndarray of shape (n_features,)
        Sum of squares for each feature over all samples.
    standard_deviation : ndarray of shape (n_features,)
        Standard deviation of each feature over all samples.
    sum_squares_centered : ndarray of shape (n_features,)
        Centered sum of squares for each feature over all samples.
    second_order_raw_moment : ndarray of shape (n_features,)
        Second order moment of each feature over all samples.
    """

    def __init__(self, result_options="all"):
        # scikit-learn's get_params()/clone()/repr() read every constructor
        # argument back from an attribute with the same name. The value must
        # therefore be stored as ``result_options``; storing it only as
        # ``options`` makes get_params() raise AttributeError.
        self.result_options = result_options

    @property
    def options(self):
        """Backward-compatible alias for ``result_options``."""
        return self.result_options

    @options.setter
    def options(self, value):
        self.result_options = value

    _onedal_basic_statistics = staticmethod(onedal_BasicStatistics)

    def _save_attributes(self):
        """Copy the requested statistics from the fitted oneDAL estimator to ``self``."""
        assert hasattr(self, "_onedal_estimator")

        if self.result_options == "all":
            result_options = onedal_BasicStatistics.get_all_result_options()
        else:
            result_options = self.result_options

        if isinstance(result_options, str):
            setattr(self, result_options, getattr(self._onedal_estimator, result_options))
        elif isinstance(result_options, list):
            for option in result_options:
                setattr(self, option, getattr(self._onedal_estimator, option))

    def _onedal_supported(self, method_name, *data):
        """Return the patching-conditions chain for ``method_name``.

        No conditions are added here, so the outcome is whatever
        ``PatchingConditionsChain`` reports by default. The same check serves
        both CPU and GPU.
        """
        patching_status = PatchingConditionsChain(
            f"sklearnex.basic_statistics.{self.__class__.__name__}.{method_name}"
        )
        return patching_status

    _onedal_cpu_supported = _onedal_supported
    _onedal_gpu_supported = _onedal_supported

    def _onedal_fit(self, X, sample_weight=None, queue=None):
        """Validate inputs, fit the oneDAL backend and store the statistics.

        Parameters
        ----------
        X : array-like
            Input data. It is converted to float64 or float32.
        sample_weight : array-like of shape (n_samples,), default=None
            Optional per-sample weights, validated against ``X``.
        queue : SYCL queue or None, default=None
            Queue forwarded to the oneDAL estimator's ``fit``.
        """
        if sklearn_check_version("1.0"):
            X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_2d=False)
        else:
            X = check_array(X, dtype=[np.float64, np.float32])

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

        onedal_params = {
            "result_options": self.result_options,
        }

        # Build a fresh backend estimator on every fit. Reusing one created by
        # an earlier fit would silently ignore result_options changed through
        # set_params() (and make _save_attributes look up missing results).
        self._onedal_estimator = self._onedal_basic_statistics(**onedal_params)
        self._onedal_estimator.fit(X, sample_weight, queue)
        self._save_attributes()

    def compute(self, data, weights=None, queue=None):
        """Delegate to ``compute`` of the oneDAL estimator created by ``fit``.

        ``_onedal_estimator`` only exists after :meth:`fit` has run. Calling
        this method first raises ``AttributeError``.
        """
        return self._onedal_estimator.compute(data, weights, queue)

    def fit(self, X, y=None, *, sample_weight=None):
        """Compute the requested statistics of X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data for compute, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : Ignored
            Not used, present for API consistency by convention.

        sample_weight : array-like of shape (n_samples,), default=None
            Weights for compute weighted statistics, where `n_samples` is the number of samples.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": None,
            },
            X,
            sample_weight,
        )
        return self
@@ -0,0 +1,251 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose
20
+
21
+ from onedal.basic_statistics.tests.test_basic_statistics import (
22
+ expected_max,
23
+ expected_mean,
24
+ expected_sum,
25
+ options_and_tests,
26
+ )
27
+ from onedal.tests.utils._dataframes_support import (
28
+ _convert_to_dataframe,
29
+ get_dataframes_and_queues,
30
+ )
31
+ from sklearnex.basic_statistics import BasicStatistics
32
+
33
+
34
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
35
+ def test_sklearnex_import_basic_statistics(dataframe, queue):
36
+ X = np.array([[0, 0], [1, 1]])
37
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
38
+
39
+ weights = np.array([1, 0.5])
40
+ weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
41
+
42
+ result = BasicStatistics().fit(X_df)
43
+
44
+ expected_mean = np.array([0.5, 0.5])
45
+ expected_min = np.array([0, 0])
46
+ expected_max = np.array([1, 1])
47
+
48
+ assert_allclose(expected_mean, result.mean)
49
+ assert_allclose(expected_max, result.max)
50
+ assert_allclose(expected_min, result.min)
51
+
52
+ result = BasicStatistics().fit(X_df, sample_weight=weights_df)
53
+
54
+ expected_weighted_mean = np.array([0.25, 0.25])
55
+ expected_weighted_min = np.array([0, 0])
56
+ expected_weighted_max = np.array([0.5, 0.5])
57
+
58
+ assert_allclose(expected_weighted_mean, result.mean)
59
+ assert_allclose(expected_weighted_min, result.min)
60
+ assert_allclose(expected_weighted_max, result.max)
61
+
62
+
63
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
64
+ @pytest.mark.parametrize("weighted", [True, False])
65
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
66
+ def test_multiple_options_on_gold_data(dataframe, queue, weighted, dtype):
67
+ X = np.array([[0, 0], [1, 1]])
68
+ X = X.astype(dtype=dtype)
69
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
70
+ if weighted:
71
+ weights = np.array([1, 0.5])
72
+ weights = weights.astype(dtype=dtype)
73
+ weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
74
+ basicstat = BasicStatistics()
75
+
76
+ if weighted:
77
+ result = basicstat.fit(X_df, sample_weight=weights_df)
78
+ else:
79
+ result = basicstat.fit(X_df)
80
+
81
+ if weighted:
82
+ expected_weighted_mean = np.array([0.25, 0.25])
83
+ expected_weighted_min = np.array([0, 0])
84
+ expected_weighted_max = np.array([0.5, 0.5])
85
+ assert_allclose(expected_weighted_mean, result.mean)
86
+ assert_allclose(expected_weighted_max, result.max)
87
+ assert_allclose(expected_weighted_min, result.min)
88
+ else:
89
+ expected_mean = np.array([0.5, 0.5])
90
+ expected_min = np.array([0, 0])
91
+ expected_max = np.array([1, 1])
92
+ assert_allclose(expected_mean, result.mean)
93
+ assert_allclose(expected_max, result.max)
94
+ assert_allclose(expected_min, result.min)
95
+
96
+
97
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
98
+ @pytest.mark.parametrize("option", options_and_tests)
99
+ @pytest.mark.parametrize("row_count", [100, 1000])
100
+ @pytest.mark.parametrize("column_count", [10, 100])
101
+ @pytest.mark.parametrize("weighted", [True, False])
102
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
103
+ def test_single_option_on_random_data(
104
+ dataframe, queue, option, row_count, column_count, weighted, dtype
105
+ ):
106
+ result_option, function, tols = option
107
+ fp32tol, fp64tol = tols
108
+ seed = 77
109
+ gen = np.random.default_rng(seed)
110
+ X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
111
+ X = X.astype(dtype=dtype)
112
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
113
+ if weighted:
114
+ weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
115
+ weights = weights.astype(dtype=dtype)
116
+ weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
117
+ basicstat = BasicStatistics(result_options=result_option)
118
+
119
+ if weighted:
120
+ result = basicstat.fit(X_df, sample_weight=weights_df)
121
+ else:
122
+ result = basicstat.fit(X_df)
123
+
124
+ res = getattr(result, result_option)
125
+ if weighted:
126
+ weighted_data = np.diag(weights) @ X
127
+ gtr = function(weighted_data)
128
+ else:
129
+ gtr = function(X)
130
+
131
+ tol = fp32tol if res.dtype == np.float32 else fp64tol
132
+ assert_allclose(gtr, res, atol=tol)
133
+
134
+
135
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
136
+ @pytest.mark.parametrize("row_count", [100, 1000])
137
+ @pytest.mark.parametrize("column_count", [10, 100])
138
+ @pytest.mark.parametrize("weighted", [True, False])
139
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
140
+ def test_multiple_options_on_random_data(
141
+ dataframe, queue, row_count, column_count, weighted, dtype
142
+ ):
143
+ seed = 77
144
+ gen = np.random.default_rng(seed)
145
+ X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
146
+ X = X.astype(dtype=dtype)
147
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
148
+ if weighted:
149
+ weights = gen.uniform(low=-0.5, high=1.0, size=row_count)
150
+ weights = weights.astype(dtype=dtype)
151
+ weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
152
+ basicstat = BasicStatistics(result_options=["mean", "max", "sum"])
153
+
154
+ if weighted:
155
+ result = basicstat.fit(X_df, sample_weight=weights_df)
156
+ else:
157
+ result = basicstat.fit(X_df)
158
+
159
+ res_mean, res_max, res_sum = result.mean, result.max, result.sum
160
+ if weighted:
161
+ weighted_data = np.diag(weights) @ X
162
+ gtr_mean, gtr_max, gtr_sum = (
163
+ expected_mean(weighted_data),
164
+ expected_max(weighted_data),
165
+ expected_sum(weighted_data),
166
+ )
167
+ else:
168
+ gtr_mean, gtr_max, gtr_sum = (
169
+ expected_mean(X),
170
+ expected_max(X),
171
+ expected_sum(X),
172
+ )
173
+
174
+ tol = 5e-4 if res_mean.dtype == np.float32 else 1e-7
175
+ assert_allclose(gtr_mean, res_mean, atol=tol)
176
+ assert_allclose(gtr_max, res_max, atol=tol)
177
+ assert_allclose(gtr_sum, res_sum, atol=tol)
178
+
179
+
180
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
181
+ @pytest.mark.parametrize("row_count", [100, 1000])
182
+ @pytest.mark.parametrize("column_count", [10, 100])
183
+ @pytest.mark.parametrize("weighted", [True, False])
184
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
185
+ def test_all_option_on_random_data(
186
+ dataframe, queue, row_count, column_count, weighted, dtype
187
+ ):
188
+ seed = 77
189
+ gen = np.random.default_rng(seed)
190
+ X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
191
+ X = X.astype(dtype=dtype)
192
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
193
+ if weighted:
194
+ weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
195
+ weights = weights.astype(dtype=dtype)
196
+ weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
197
+ basicstat = BasicStatistics(result_options="all")
198
+
199
+ if weighted:
200
+ result = basicstat.fit(X_df, sample_weight=weights_df)
201
+ else:
202
+ result = basicstat.fit(X_df)
203
+
204
+ if weighted:
205
+ weighted_data = np.diag(weights) @ X
206
+
207
+ for option in options_and_tests:
208
+ result_option, function, tols = option
209
+ fp32tol, fp64tol = tols
210
+ res = getattr(result, result_option)
211
+ if weighted:
212
+ gtr = function(weighted_data)
213
+ else:
214
+ gtr = function(X)
215
+ tol = fp32tol if res.dtype == np.float32 else fp64tol
216
+ assert_allclose(gtr, res, atol=tol)
217
+
218
+
219
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
220
+ @pytest.mark.parametrize("option", options_and_tests)
221
+ @pytest.mark.parametrize("data_size", [100, 1000])
222
+ @pytest.mark.parametrize("weighted", [True, False])
223
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
224
+ def test_1d_input_on_random_data(dataframe, queue, option, data_size, weighted, dtype):
225
+ result_option, function, tols = option
226
+ fp32tol, fp64tol = tols
227
+ seed = 77
228
+ gen = np.random.default_rng(seed)
229
+ X = gen.uniform(low=-0.3, high=+0.7, size=data_size)
230
+ X = X.astype(dtype=dtype)
231
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
232
+ if weighted:
233
+ weights = gen.uniform(low=-0.5, high=1.0, size=data_size)
234
+ weights = weights.astype(dtype=dtype)
235
+ weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
236
+ basicstat = BasicStatistics(result_options=result_option)
237
+
238
+ if weighted:
239
+ result = basicstat.fit(X_df, sample_weight=weights_df)
240
+ else:
241
+ result = basicstat.fit(X_df)
242
+
243
+ res = getattr(result, result_option)
244
+ if weighted:
245
+ weighted_data = weights * X
246
+ gtr = function(weighted_data)
247
+ else:
248
+ gtr = function(X)
249
+
250
+ tol = fp32tol if res.dtype == np.float32 else fp64tol
251
+ assert_allclose(gtr, res, atol=tol)
@@ -18,7 +18,7 @@ import numpy as np
18
18
  import pytest
19
19
  from numpy.testing import assert_allclose
20
20
 
21
- from onedal.basic_statistics.tests.test_incremental_basic_statistics import (
21
+ from onedal.basic_statistics.tests.test_basic_statistics import (
22
22
  expected_max,
23
23
  expected_mean,
24
24
  expected_sum,
@@ -17,7 +17,6 @@
17
17
  import numbers
18
18
  from abc import ABC
19
19
 
20
- import numpy as np
21
20
  from scipy import sparse as sp
22
21
  from sklearn.cluster import DBSCAN as sklearn_DBSCAN
23
22
  from sklearn.utils.validation import _check_sample_weight
@@ -85,6 +84,9 @@ class DBSCAN(sklearn_DBSCAN, BaseDBSCAN):
85
84
  self.n_jobs = n_jobs
86
85
 
87
86
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
87
+ if sklearn_check_version("1.0"):
88
+ X = self._validate_data(X, force_all_finite=False)
89
+
88
90
  onedal_params = {
89
91
  "eps": self.eps,
90
92
  "min_samples": self.min_samples,
@@ -15,3 +15,11 @@
15
15
  # ===============================================================================
16
16
 
17
17
  from daal4py.sklearn.cluster import KMeans
18
+ from onedal._device_offload import support_usm_ndarray
19
+
20
# Note: `sklearnex.cluster.KMeans` only has functional SYCL GPU support;
# no computation is offloaded to a GPU device. Each public method below is
# wrapped with `support_usm_ndarray(queue_param=False)`, in this order.
for _method_name in ("fit", "fit_predict", "predict", "score"):
    setattr(
        KMeans,
        _method_name,
        support_usm_ndarray(queue_param=False)(getattr(KMeans, _method_name)),
    )
# Do not leak the loop variable into the module namespace.
del _method_name
@@ -18,16 +18,18 @@ import numpy as np
18
18
  import pytest
19
19
  from numpy.testing import assert_allclose
20
20
 
21
+ from onedal.tests.utils._dataframes_support import (
22
+ _convert_to_dataframe,
23
+ get_dataframes_and_queues,
24
+ )
21
25
 
22
- # TODO:
23
- # adding this parameterized testing
24
- # somehow breaks other test with preview module patch:
25
- # sklearnex/tests/test_monkeypatch.py::test_preview_namespace.
26
- # @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
27
- def test_sklearnex_import_dbscan():
26
+
27
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
28
+ def test_sklearnex_import_dbscan(dataframe, queue):
28
29
  from sklearnex.cluster import DBSCAN
29
30
 
30
31
  X = np.array([[1, 2], [2, 2], [2, 3], [8, 7], [8, 8], [25, 80]])
32
+ X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
31
33
  dbscan = DBSCAN(eps=3, min_samples=2).fit(X)
32
34
  assert "sklearnex" in dbscan.__module__
33
35
 
@@ -15,16 +15,28 @@
15
15
  # ===============================================================================
16
16
 
17
17
  import numpy as np
18
+ import pytest
18
19
  from numpy.testing import assert_allclose
19
20
 
21
+ from onedal.tests.utils._dataframes_support import (
22
+ _as_numpy,
23
+ _convert_to_dataframe,
24
+ get_dataframes_and_queues,
25
+ )
26
+
27
+
28
@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
def test_sklearnex_import(dataframe, queue):
    """sklearnex.cluster.KMeans comes from daal4py and predicts the expected labels."""
    from sklearnex.cluster import KMeans

    train = _convert_to_dataframe(
        np.array([[1, 2], [1, 4], [1, 0], [10, 2], [10, 4], [10, 0]]),
        sycl_queue=queue,
        target_df=dataframe,
    )
    kmeans = KMeans(n_clusters=2, random_state=0).fit(train)
    assert "daal4py" in kmeans.__module__

    test_points = _convert_to_dataframe(
        [[0, 0], [12, 3]], sycl_queue=queue, target_df=dataframe
    )
    predicted = kmeans.predict(test_points)
    assert_allclose(np.array([1, 0], dtype=np.int32), _as_numpy(predicted))
@@ -19,7 +19,8 @@ import logging
19
19
 
20
20
  import pytest
21
21
 
22
- from sklearnex import patch_sklearn, unpatch_sklearn
22
+ from daal4py.sklearn._utils import sklearn_check_version
23
+ from sklearnex import config_context, patch_sklearn, unpatch_sklearn
23
24
 
24
25
 
25
26
  def pytest_configure(config):
@@ -61,3 +62,12 @@ def with_sklearnex():
61
62
  patch_sklearn()
62
63
  yield
63
64
  unpatch_sklearn()
65
+
66
+
67
@pytest.fixture
def with_array_api():
    """Run the test with ``array_api_dispatch=True`` when scikit-learn >= 1.2.

    On older scikit-learn versions the fixture yields without touching the
    configuration.
    """
    if not sklearn_check_version("1.2"):
        yield
        return
    with config_context(array_api_dispatch=True):
        yield