scikit-learn-intelex 2024.1.0__py311-none-win_amd64.whl → 2025.1.0__py311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (277)
  1. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
  2. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
  3. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/_daal4py.cp311-win_amd64.pyd +0 -0
  4. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/doc/third-party-programs.txt +424 -0
  5. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +19 -0
  6. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mb/model_builders.py +377 -0
  7. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp311-win_amd64.pyd +0 -0
  8. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
  9. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +248 -0
  10. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
  11. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
  12. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
  13. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +597 -0
  14. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  15. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition}/__init__.py +3 -3
  16. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +524 -0
  17. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +196 -0
  18. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +337 -0
  19. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
  20. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1397 -0
  21. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  22. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn}/linear_model/__init__.py +29 -29
  23. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +848 -0
  24. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +272 -0
  25. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +325 -0
  26. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +2 -2
  27. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
  28. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  29. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +1026 -0
  30. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
  31. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +208 -0
  32. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  33. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold}/__init__.py +4 -2
  34. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +405 -0
  35. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
  36. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +236 -0
  37. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
  38. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
  39. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
  40. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  41. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
  42. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  43. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/_models_info.py +13 -22
  44. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  45. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/test_patching.py +10 -42
  46. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch}/tests/utils/_launch_algorithms.py +4 -5
  47. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
  48. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +503 -0
  49. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +139 -0
  50. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +74 -0
  51. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  52. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  53. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
  54. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +734 -0
  55. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
  56. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +75 -0
  57. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
  58. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +693 -0
  59. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/__init__.py +83 -0
  60. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_config.py +54 -0
  61. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_device_offload.py +222 -0
  62. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp311-win_amd64.pyd +0 -0
  63. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp311-win_amd64.pyd +0 -0
  64. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
  65. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +107 -0
  66. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +160 -0
  67. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +298 -0
  68. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +196 -0
  69. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
  70. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +110 -0
  71. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +564 -0
  72. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +115 -0
  73. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
  74. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
  75. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
  76. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_base.py +38 -0
  77. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
  78. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
  79. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_policy.py +59 -0
  80. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/_spmd_policy.py +30 -0
  81. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +125 -0
  82. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/common/tests/test_policy.py +76 -0
  83. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance}/__init__.py +3 -2
  84. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +125 -0
  85. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +146 -0
  86. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
  87. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +122 -0
  88. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +19 -0
  89. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +154 -0
  90. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +126 -0
  91. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +414 -0
  92. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition}/__init__.py +3 -2
  93. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +204 -0
  94. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +186 -0
  95. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +198 -0
  96. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
  97. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +727 -0
  98. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
  99. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
  100. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +258 -0
  101. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +329 -0
  102. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +249 -0
  103. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +168 -0
  104. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +107 -0
  105. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +250 -0
  106. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
  107. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
  108. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
  109. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +767 -0
  110. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
  111. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
  112. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +25 -0
  113. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +153 -0
  114. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
  115. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
  116. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/svm.py +556 -0
  117. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +351 -0
  118. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
  119. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
  120. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +176 -0
  121. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
  122. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/test_common.py +57 -0
  123. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +162 -0
  124. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +102 -0
  125. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/__init__.py +49 -0
  126. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +81 -0
  127. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/_dpep_helpers.py +56 -0
  128. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/onedal/utils/validation.py +440 -0
  129. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__init__.py +10 -7
  130. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_config.py +22 -16
  131. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +126 -0
  132. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/_utils.py +27 -4
  133. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  134. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +230 -0
  135. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +345 -0
  136. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +270 -0
  137. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +404 -0
  138. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +1 -1
  139. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +19 -10
  140. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +395 -0
  141. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +8 -6
  142. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +159 -0
  143. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
  144. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
  145. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +398 -0
  146. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +237 -0
  147. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +425 -0
  148. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +25 -9
  149. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +241 -60
  150. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +250 -188
  151. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +39 -21
  152. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +16 -2
  153. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
  154. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +13 -0
  155. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +482 -0
  156. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +425 -0
  157. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +341 -0
  158. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex}/linear_model/logistic_regression.py +194 -133
  159. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +7 -0
  160. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +207 -0
  161. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +153 -0
  162. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
  163. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +134 -0
  164. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +4 -0
  165. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +5 -0
  166. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +3 -0
  167. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +5 -0
  168. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +1 -1
  169. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +236 -0
  170. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +53 -6
  171. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +51 -155
  172. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +46 -149
  173. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +55 -100
  174. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +16 -18
  175. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +1 -3
  176. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +138 -0
  177. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
  178. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  179. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +233 -0
  180. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +266 -0
  181. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model}/__init__.py +19 -19
  182. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/ridge.py +424 -0
  183. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_ridge.py +102 -0
  184. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +1 -0
  185. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
  186. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  187. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +107 -0
  188. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +307 -0
  189. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +97 -0
  190. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +172 -0
  191. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
  192. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +21 -0
  193. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +37 -0
  194. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +107 -0
  195. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +184 -0
  196. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition}/__init__.py +3 -2
  197. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +11 -12
  198. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +269 -0
  199. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +128 -0
  200. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  201. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +265 -0
  202. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -1
  203. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py → scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +14 -18
  204. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  205. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +329 -0
  206. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +145 -0
  207. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +162 -0
  208. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +288 -0
  209. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +339 -0
  210. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +172 -78
  211. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +74 -70
  212. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +170 -77
  213. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +66 -66
  214. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +12 -20
  215. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +390 -0
  216. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +123 -0
  217. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +379 -0
  218. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +276 -0
  219. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +108 -0
  220. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  221. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +385 -0
  222. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +321 -0
  223. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +44 -0
  224. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +371 -0
  225. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
  226. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +82 -0
  227. scikit_learn_intelex-2025.1.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_finite.py +89 -0
  228. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/METADATA +231 -230
  229. scikit_learn_intelex-2025.1.0.dist-info/RECORD +257 -0
  230. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/WHEEL +1 -1
  231. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +0 -223
  232. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -17
  233. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -30
  234. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  235. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  236. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -388
  237. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -17
  238. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -82
  239. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -28
  240. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -436
  241. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -84
  242. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -376
  243. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -98
  244. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -376
  245. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  246. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +0 -188
  247. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +0 -225
  248. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -227
  249. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -428
  250. scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
  251. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  252. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  253. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  254. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  255. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  256. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  257. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  258. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  259. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  260. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  261. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  262. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  263. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  264. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  265. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  266. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  267. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  268. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  269. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  270. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  271. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  272. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  273. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  274. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  275. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2025.1.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  276. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/LICENSE.txt +0 -0
  277. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2025.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,210 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
20
+ from sklearn import datasets
21
+ from sklearn.metrics.pairwise import rbf_kernel
22
+ from sklearn.svm import NuSVR as SklearnNuSVR
23
+
24
+ from onedal.svm import NuSVR
25
+ from onedal.tests.utils._device_selection import (
26
+ get_queues,
27
+ pass_if_not_implemented_for_gpu,
28
+ )
29
+
30
# Keyword arguments shared by every ``datasets.make_regression`` call below.
synth_params = dict(n_samples=500, n_features=100, random_state=42)
31
+
32
+
33
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_diabetes_simple(queue):
    """A linear NuSVR fitted on the diabetes data clears a minimal training score."""
    data = datasets.load_diabetes()
    model = NuSVR(kernel="linear", C=10.0)
    model.fit(data.data, data.target, queue=queue)
    train_score = model.score(data.data, data.target, queue=queue)
    assert train_score > 0.02
40
+
41
+
42
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_input_format_for_diabetes(queue):
    """C-ordered and Fortran-ordered copies of the same data give the same fit."""
    diabetes = datasets.load_diabetes()

    def fit_and_predict(features):
        # A fresh estimator per memory layout; returns (dual_coef_, predictions).
        model = NuSVR(kernel="linear", C=10.0)
        model.fit(features, diabetes.target, queue=queue)
        return model.dual_coef_, model.predict(features, queue=queue)

    c_order = np.asanyarray(diabetes.data, dtype="float", order="C")
    assert c_order.flags.c_contiguous
    assert not c_order.flags.f_contiguous
    assert not c_order.flags.fnc
    dual_c, pred_c = fit_and_predict(c_order)

    f_order = np.asanyarray(diabetes.data, dtype="float", order="F")
    assert not f_order.flags.c_contiguous
    assert f_order.flags.f_contiguous
    assert f_order.flags.fnc
    dual_f, pred_f = fit_and_predict(f_order)

    assert_allclose(dual_c, dual_f)
    assert_allclose(pred_c, pred_f)
68
+
69
+
70
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_predict(queue):
    """predict() equals the decision function rebuilt from the fitted attributes."""
    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    def check(model, gram):
        # gram: kernel values between X and the model's support vectors.
        manual = np.dot(gram, model.dual_coef_.T) + model.intercept_
        assert_array_almost_equal(
            manual.ravel(), model.predict(X, queue=queue).ravel()
        )

    linear_model = NuSVR(kernel="linear", C=0.1).fit(X, y, queue=queue)
    check(linear_model, np.dot(X, linear_model.support_vectors_.T))

    rbf_model = NuSVR(kernel="rbf", gamma=1).fit(X, y, queue=queue)
    check(
        rbf_model,
        rbf_kernel(X, rbf_model.support_vectors_, gamma=rbf_model.gamma),
    )
88
+
89
+
90
def _test_diabetes_compare_with_sklearn(queue, kernel):
    """Fit oneDAL and scikit-learn NuSVR on diabetes and compare the models.

    Checks that the oneDAL score is not worse than scikit-learn's (up to 1e-5),
    and that intercepts, support-vector array shapes and dual coefficients agree
    within the tolerances below.
    """
    diabetes = datasets.load_diabetes()
    clf_onedal = NuSVR(kernel=kernel, nu=0.25, C=10.0)
    clf_onedal.fit(diabetes.data, diabetes.target, queue=queue)
    result = clf_onedal.score(diabetes.data, diabetes.target, queue=queue)

    clf_sklearn = SklearnNuSVR(kernel=kernel, nu=0.25, C=10.0)
    clf_sklearn.fit(diabetes.data, diabetes.target)
    expected = clf_sklearn.score(diabetes.data, diabetes.target)

    assert result > expected - 1e-5
    assert_allclose(clf_sklearn.intercept_, clf_onedal.intercept_, atol=1e-3)
    # Compare the oneDAL model against scikit-learn (not scikit-learn with itself).
    assert_array_equal(
        clf_sklearn.support_vectors_.shape, clf_onedal.support_vectors_.shape
    )
    assert_allclose(clf_sklearn.dual_coef_, clf_onedal.dual_coef_, atol=1e-2)
106
+
107
+
108
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
def test_diabetes_compare_with_sklearn(queue, kernel):
    """Compare oneDAL NuSVR with scikit-learn on diabetes for each kernel."""
    # The sigmoid case stays parametrized but is always skipped.
    if kernel != "sigmoid":
        _test_diabetes_compare_with_sklearn(queue, kernel)
    else:
        pytest.skip("Sparse sigmoid kernel function is buggy.")
115
+
116
+
117
def _test_synth_rbf_compare_with_sklearn(queue, C, nu, gamma):
    """Fit RBF NuSVR with oneDAL and scikit-learn on synthetic regression data.

    The oneDAL score must exceed 0.4 and match scikit-learn's within 1e-3.
    """
    x, y = datasets.make_regression(**synth_params)
    hyper = dict(kernel="rbf", gamma=gamma, C=C, nu=nu)

    onedal_model = NuSVR(**hyper)
    onedal_model.fit(x, y, queue=queue)
    result = onedal_model.score(x, y, queue=queue)

    sklearn_model = SklearnNuSVR(**hyper)
    sklearn_model.fit(x, y)
    expected = sklearn_model.score(x, y)

    assert result > 0.4
    assert abs(result - expected) < 1e-3
130
+
131
+
132
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("gamma", ["scale", "auto"])
@pytest.mark.parametrize("C", [100.0, 1000.0])
@pytest.mark.parametrize("nu", [0.25, 0.75])
def test_synth_rbf_compare_with_sklearn(queue, C, nu, gamma):
    """Compare RBF NuSVR with scikit-learn over a grid of C, nu and gamma."""
    _test_synth_rbf_compare_with_sklearn(queue, C=C, nu=nu, gamma=gamma)
139
+
140
+
141
def _test_synth_linear_compare_with_sklearn(queue, C, nu):
    """Fit linear NuSVR with oneDAL and scikit-learn and compare their scores.

    Only agreement within 1e-3 is checked. The linear kernel fits this synthetic
    regression poorly (low R2 score), so there is no absolute score threshold.
    """
    x, y = datasets.make_regression(**synth_params)

    onedal_clf = NuSVR(kernel="linear", C=C, nu=nu).fit(x, y, queue=queue)
    result = onedal_clf.score(x, y, queue=queue)

    sklearn_clf = SklearnNuSVR(kernel="linear", C=C, nu=nu).fit(x, y)
    expected = sklearn_clf.score(x, y)

    assert abs(result - expected) < 1e-3
156
+
157
+
158
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("C", [0.001, 0.1])
@pytest.mark.parametrize("nu", [0.25, 0.75])
def test_synth_linear_compare_with_sklearn(queue, C, nu):
    """Compare linear NuSVR with scikit-learn over a grid of C and nu."""
    _test_synth_linear_compare_with_sklearn(queue, C=C, nu=nu)
164
+
165
+
166
def _test_synth_poly_compare_with_sklearn(queue, params):
    """Fit polynomial NuSVR with oneDAL and scikit-learn and compare scores.

    ``params`` is forwarded unchanged to both estimators. The oneDAL score must
    exceed 0.5 and match scikit-learn's within 1e-3.
    """
    x, y = datasets.make_regression(**synth_params)

    onedal_clf = NuSVR(kernel="poly", **params).fit(x, y, queue=queue)
    result = onedal_clf.score(x, y, queue=queue)

    sklearn_clf = SklearnNuSVR(kernel="poly", **params).fit(x, y)
    expected = sklearn_clf.score(x, y)

    assert result > 0.5
    assert abs(result - expected) < 1e-3
179
+
180
+
181
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize(
    "params",
    [
        dict(degree=2, coef0=0.1, gamma="scale", C=100, nu=0.25),
        dict(degree=3, coef0=0.0, gamma="scale", C=1000, nu=0.75),
    ],
)
def test_synth_poly_compare_with_sklearn(queue, params):
    """Compare polynomial NuSVR with scikit-learn for each parameter set."""
    _test_synth_poly_compare_with_sklearn(queue, params=params)
192
+
193
+
194
@pass_if_not_implemented_for_gpu(reason="nusvr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_pickle(queue):
    """A pickled and restored NuSVR keeps its exact type and its predictions."""
    import pickle

    diabetes = datasets.load_diabetes()

    clf = NuSVR(kernel="rbf", C=10.0)
    clf.fit(diabetes.data, diabetes.target, queue=queue)
    expected = clf.predict(diabetes.data, queue=queue)

    dump = pickle.dumps(clf)
    clf2 = pickle.loads(dump)

    # Exact type identity is intended here, so compare with `is`, not `==`.
    assert type(clf2) is type(clf)
    result = clf2.predict(diabetes.data, queue=queue)
    assert_array_equal(expected, result)
@@ -0,0 +1,176 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from os import environ
18
+
19
# scikit-learn requires SciPy's array API support to be enabled manually when
# the `array-api-compat` package is present in the environment. The flag is set
# before the sklearn imports below (presumably because it is read at import
# time; confirm before moving it).
# TODO: create generic approach to handle this for all tests
environ.update(SCIPY_ARRAY_API="1")
23
+
24
+
25
+ import numpy as np
26
+ import pytest
27
+ import sklearn.utils.estimator_checks
28
+ from numpy.testing import assert_array_almost_equal, assert_array_equal
29
+ from sklearn import datasets
30
+ from sklearn.datasets import make_blobs
31
+ from sklearn.metrics.pairwise import rbf_kernel
32
+ from sklearn.model_selection import train_test_split
33
+
34
+ from onedal.svm import SVC
35
+ from onedal.tests.utils._device_selection import (
36
+ get_queues,
37
+ pass_if_not_implemented_for_gpu,
38
+ )
39
+
40
+
41
def _test_libsvm_parameters(queue, array_constr, dtype):
    """Check the fitted attributes of a linear SVC on a six-point toy dataset.

    Parameters
    ----------
    queue : execution queue (or None) passed to every fit/predict call.
    array_constr : callable building an array from nested lists and ``dtype``.
    dtype : element type for the features and labels.
    """
    X = array_constr([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=dtype)
    y = array_constr([1, 1, 1, 2, 2, 2], dtype=dtype)

    clf = SVC(kernel="linear").fit(X, y, queue=queue)
    assert_array_equal(clf.dual_coef_, [[-0.25, 0.25]])
    assert_array_equal(clf.support_, [1, 3])
    assert_array_equal(clf.support_vectors_, (X[1], X[3]))
    assert_array_equal(clf.intercept_, [0.0])
    # Predict on the same queue used for fitting, like the other tests do.
    assert_array_equal(clf.predict(X, queue=queue), y)
51
+
52
+
53
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("array_constr", [np.array])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_libsvm_parameters(queue, array_constr, dtype):
    """Run the toy linear-SVC attribute checks; GPU queues are skipped."""
    on_gpu = bool(queue) and queue.sycl_device.is_gpu
    if on_gpu:
        pytest.skip("Sporadic failures on GPU sycl_queue.")
    _test_libsvm_parameters(queue, array_constr, dtype)
60
+
61
+
62
@pass_if_not_implemented_for_gpu(reason="class weights are not implemented")
@pytest.mark.parametrize(
    "queue",
    get_queues("cpu")
    + [
        pytest.param(
            get_queues("gpu"),
            marks=pytest.mark.xfail(
                reason="class weights are not implemented but the error is not raised"
            ),
        )
    ],
)
def test_class_weight(queue):
    """Down-weighting class 1 makes the SVC predict class 2 for every sample."""
    # NOTE(review): get_queues() returns a list (it is concatenated with `+`
    # above), so the GPU entry passes a whole list, not a single queue, as
    # `queue`. Confirm this is intended.
    features = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]])
    labels = np.array([1, 1, 1, 2, 2, 2])

    model = SVC(class_weight={1: 0.1})
    model.fit(features, labels, queue=queue)
    assert_array_almost_equal(model.predict(features, queue=queue), [2] * 6)
82
+
83
+
84
@pytest.mark.parametrize("queue", get_queues())
def test_sample_weight(queue):
    """With unit sample weights, the linear SVC on this toy set has zero intercept."""
    if queue and queue.sycl_device.is_gpu:
        pytest.skip("Sporadic failures on GPU sycl_queue.")
    X = np.array([[-2, 0], [-1, -1], [0, -2], [0, 2], [1, 1], [2, 2]])
    y = np.array([1, 1, 1, 2, 2, 2])
    unit_weights = [1] * len(y)

    clf = SVC(kernel="linear").fit(X, y, sample_weight=unit_weights, queue=queue)
    assert_array_almost_equal(clf.intercept_, [0.0])
94
+
95
+
96
@pytest.mark.parametrize("queue", get_queues())
def test_decision_function(queue):
    """OvO decision_function matches the RBF expansion over the support vectors."""
    X = np.array([[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]], dtype=np.float32)
    Y = np.array([1, 1, 1, 2, 2, 2], dtype=np.float32)

    clf = SVC(kernel="rbf", gamma=1, decision_function_shape="ovo").fit(
        X, Y, queue=queue
    )

    kernel_matrix = rbf_kernel(X, clf.support_vectors_, gamma=clf.gamma)
    manual = np.dot(kernel_matrix, clf.dual_coef_.T) + clf.intercept_
    assert_array_almost_equal(manual.ravel(), clf.decision_function(X, queue=queue))
107
+
108
+
109
@pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_iris(queue):
    """A linear SVC scores above 0.9 on iris and reports its classes sorted."""
    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    clf = SVC(kernel="linear").fit(X, y, queue=queue)
    assert clf.score(X, y, queue=queue) > 0.9
    assert_array_equal(clf.classes_, np.sort(clf.classes_))
116
+
117
+
118
@pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_decision_function_shape(queue):
    """OvO output has one column per class pair; invalid shapes are rejected."""
    X, y = make_blobs(n_samples=80, centers=5, random_state=0)
    X_train, _, y_train, _ = train_test_split(X, y, random_state=0)

    # decision_function_shape="ovo": 5 classes give 5 * 4 / 2 = 10 columns.
    ovo_clf = SVC(kernel="linear", decision_function_shape="ovo")
    ovo_clf.fit(X_train, y_train, queue=queue)
    scores = ovo_clf.decision_function(X_train, queue=queue)
    assert scores.shape == (len(X_train), 10)

    with pytest.raises(ValueError, match="must be either 'ovr' or 'ovo'"):
        SVC(decision_function_shape="bad").fit(X_train, y_train, queue=queue)
133
+
134
+
135
@pass_if_not_implemented_for_gpu(reason="multiclass svm is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_pickle(queue):
    """A pickled and restored SVC keeps its exact type and decision values."""
    import pickle

    iris = datasets.load_iris()
    clf = SVC(kernel="linear").fit(iris.data, iris.target, queue=queue)
    expected = clf.decision_function(iris.data, queue=queue)

    dump = pickle.dumps(clf)
    clf2 = pickle.loads(dump)

    # Exact type identity is intended here, so compare with `is`, not `==`.
    assert type(clf2) is type(clf)
    result = clf2.decision_function(iris.data, queue=queue)
    assert_array_equal(expected, result)
150
+
151
+
152
@pass_if_not_implemented_for_gpu(reason="sigmoid kernel is not implemented")
@pytest.mark.parametrize(
    "queue",
    get_queues("cpu")
    + [
        pytest.param(
            get_queues("gpu"),
            marks=pytest.mark.xfail(
                reason="raises Unimplemented error with inconsistent error message"
            ),
        )
    ],
)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_svc_sigmoid(queue, dtype):
    """Fit a sigmoid-kernel SVC on a toy set and check its fitted attributes."""
    # NOTE(review): get_queues() returns a list (it is concatenated with `+`
    # above), so the GPU entry passes a whole list, not a single queue, as
    # `queue`. Confirm this is intended.
    X_train = np.array([[-1, 2], [0, 0], [2, -1], [1, 1], [1, 2], [2, 1]], dtype=dtype)
    X_test = np.array([[0, 2], [0.5, 0.5], [0.3, 0.1], [2, 0], [-1, -1]], dtype=dtype)
    y_train = np.array([1, 1, 1, 2, 2, 2], dtype=dtype)
    svc = SVC(kernel="sigmoid").fit(X_train, y_train, queue=queue)

    # Every training point ends up as a support vector.
    assert_array_equal(svc.dual_coef_, [[-1, -1, -1, 1, 1, 1]])
    assert_array_equal(svc.support_, list(range(6)))
    assert_array_equal(svc.predict(X_test, queue=queue), [2, 2, 1, 2, 1])
@@ -0,0 +1,243 @@
1
+ # ==============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
20
+ from sklearn import datasets
21
+ from sklearn.metrics.pairwise import rbf_kernel
22
+ from sklearn.svm import SVR as SklearnSVR
23
+
24
+ from onedal.svm import SVR
25
+ from onedal.tests.utils._device_selection import (
26
+ get_queues,
27
+ pass_if_not_implemented_for_gpu,
28
+ )
29
+
30
# Keyword arguments shared by every ``datasets.make_regression`` call below.
synth_params = dict(n_samples=500, n_features=100, random_state=42)
31
+
32
+
33
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_run_to_run_fit(queue):
    """Ten refits on the same data reproduce the first fit (within allclose)."""
    diabetes = datasets.load_diabetes()

    def fitted():
        model = SVR(kernel="linear", C=10.0)
        model.fit(diabetes.data, diabetes.target, queue=queue)
        return model

    reference = fitted()
    for _ in range(10):
        candidate = fitted()
        for attr in ("intercept_", "support_vectors_", "dual_coef_"):
            assert_allclose(getattr(reference, attr), getattr(candidate, attr))
46
+
47
+
48
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_diabetes_simple(queue):
    """A linear SVR fitted on the diabetes data clears a minimal training score."""
    data = datasets.load_diabetes()
    model = SVR(kernel="linear", C=10.0)
    model.fit(data.data, data.target, queue=queue)
    train_score = model.score(data.data, data.target, queue=queue)
    assert train_score > 0.02
55
+
56
+
57
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_input_format_for_diabetes(queue):
    """C-ordered and Fortran-ordered copies of the same data give the same fit."""
    diabetes = datasets.load_diabetes()

    def fit_and_predict(features):
        # A fresh estimator per memory layout; returns (dual_coef_, predictions).
        model = SVR(kernel="linear", C=10.0)
        model.fit(features, diabetes.target, queue=queue)
        return model.dual_coef_, model.predict(features, queue=queue)

    c_order = np.asanyarray(diabetes.data, dtype="float", order="C")
    assert c_order.flags.c_contiguous
    assert not c_order.flags.f_contiguous
    assert not c_order.flags.fnc
    dual_c, pred_c = fit_and_predict(c_order)

    f_order = np.asanyarray(diabetes.data, dtype="float", order="F")
    assert not f_order.flags.c_contiguous
    assert f_order.flags.f_contiguous
    assert f_order.flags.fnc
    dual_f, pred_f = fit_and_predict(f_order)

    assert_allclose(dual_c, dual_f)
    assert_allclose(pred_c, pred_f)
83
+
84
+
85
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_predict(queue):
    """predict() equals the decision function rebuilt from the fitted attributes."""
    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    def check(model, gram):
        # gram: kernel values between X and the model's support vectors.
        manual = np.dot(gram, model.dual_coef_.T) + model.intercept_
        assert_array_almost_equal(
            manual.ravel(), model.predict(X, queue=queue).ravel()
        )

    linear_model = SVR(kernel="linear", C=0.1).fit(X, y, queue=queue)
    check(linear_model, np.dot(X, linear_model.support_vectors_.T))

    rbf_model = SVR(kernel="rbf", gamma=1).fit(X, y, queue=queue)
    check(
        rbf_model,
        rbf_kernel(X, rbf_model.support_vectors_, gamma=rbf_model.gamma),
    )
103
+
104
+
105
def _test_diabetes_compare_with_sklearn(queue, kernel):
    """Fit oneDAL and scikit-learn SVR on diabetes and compare the models.

    Checks that the oneDAL score is not worse than scikit-learn's (up to 1e-5),
    and that intercepts, support-vector array shapes and dual coefficients agree
    within the tolerances below.
    """
    diabetes = datasets.load_diabetes()
    clf_onedal = SVR(kernel=kernel, C=10.0, gamma=2)
    clf_onedal.fit(diabetes.data, diabetes.target, queue=queue)
    result = clf_onedal.score(diabetes.data, diabetes.target, queue=queue)

    clf_sklearn = SklearnSVR(kernel=kernel, C=10.0, gamma=2)
    clf_sklearn.fit(diabetes.data, diabetes.target)
    expected = clf_sklearn.score(diabetes.data, diabetes.target)

    assert result > expected - 1e-5
    assert_allclose(clf_sklearn.intercept_, clf_onedal.intercept_, atol=1e-3)
    # Compare the oneDAL model against scikit-learn (not scikit-learn with itself).
    assert_array_equal(
        clf_sklearn.support_vectors_.shape, clf_onedal.support_vectors_.shape
    )
    assert_allclose(clf_sklearn.dual_coef_, clf_onedal.dual_coef_, atol=1e-1)
121
+
122
+
123
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("kernel", ["linear", "rbf", "poly", "sigmoid"])
def test_diabetes_compare_with_sklearn(queue, kernel):
    """Compare oneDAL SVR with scikit-learn on diabetes for each kernel."""
    # The sigmoid case stays parametrized but is always skipped.
    if kernel != "sigmoid":
        _test_diabetes_compare_with_sklearn(queue, kernel)
    else:
        pytest.skip("Sparse sigmoid kernel function is buggy.")
130
+
131
+
132
def _test_synth_rbf_compare_with_sklearn(queue, C, gamma):
    """Fit RBF SVR with oneDAL and scikit-learn on synthetic regression data.

    The oneDAL score must exceed 0.4 and be no worse than scikit-learn's
    (up to 1e-5).
    """
    x, y = datasets.make_regression(**synth_params)
    hyper = dict(kernel="rbf", gamma=gamma, C=C)

    onedal_model = SVR(**hyper)
    onedal_model.fit(x, y, queue=queue)
    result = onedal_model.score(x, y, queue=queue)

    sklearn_model = SklearnSVR(**hyper)
    sklearn_model.fit(x, y)
    expected = sklearn_model.score(x, y)

    assert result > 0.4
    assert result > expected - 1e-5
144
+
145
+
146
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("gamma", ["scale", "auto"])
@pytest.mark.parametrize("C", [100.0, 1000.0])
def test_synth_rbf_compare_with_sklearn(queue, C, gamma):
    """Compare RBF SVR with scikit-learn over a grid of C and gamma."""
    _test_synth_rbf_compare_with_sklearn(queue, C=C, gamma=gamma)
152
+
153
+
154
def _test_synth_linear_compare_with_sklearn(queue, C):
    """Fit linear SVR with oneDAL and scikit-learn and compare their scores.

    The oneDAL score must be no worse than scikit-learn's (up to 1e-3). The
    linear kernel fits this synthetic regression poorly (low R2 score), so
    there is no absolute score threshold.
    """
    x, y = datasets.make_regression(**synth_params)

    onedal_clf = SVR(kernel="linear", C=C).fit(x, y, queue=queue)
    result = onedal_clf.score(x, y, queue=queue)

    sklearn_clf = SklearnSVR(kernel="linear", C=C).fit(x, y)
    expected = sklearn_clf.score(x, y)

    assert result > expected - 1e-3
168
+
169
+
170
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("C", [0.001, 0.1])
def test_synth_linear_compare_with_sklearn(queue, C):
    """Compare linear SVR with scikit-learn for each C."""
    _test_synth_linear_compare_with_sklearn(queue, C=C)
175
+
176
+
177
def _test_synth_poly_compare_with_sklearn(queue, params):
    """Fit polynomial SVR with oneDAL and scikit-learn and compare scores.

    ``params`` is forwarded unchanged to both estimators. The oneDAL score must
    exceed 0.5 and be no worse than scikit-learn's (up to 1e-5).
    """
    x, y = datasets.make_regression(**synth_params)

    onedal_clf = SVR(kernel="poly", **params).fit(x, y, queue=queue)
    result = onedal_clf.score(x, y, queue=queue)

    sklearn_clf = SklearnSVR(kernel="poly", **params).fit(x, y)
    expected = sklearn_clf.score(x, y)

    assert result > 0.5
    assert result > expected - 1e-5
189
+
190
+
191
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize(
    "params",
    [
        dict(degree=2, coef0=0.1, gamma="scale", C=100),
        dict(degree=3, coef0=0.0, gamma="scale", C=1000),
    ],
)
def test_synth_poly_compare_with_sklearn(queue, params):
    """Compare polynomial SVR with scikit-learn for each parameter set."""
    _test_synth_poly_compare_with_sklearn(queue, params=params)
202
+
203
+
204
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_sided_sample_weight(queue):
    """Sample weights move the prediction at (-1, 1) below, above, or onto 1.5."""
    clf = SVR(C=1e-2, kernel="linear")

    X = [[-2, 0], [-1, -1], [0, -2], [0, 2], [1, 1], [2, 0]]
    Y = [1, 1, 1, 2, 2, 2]

    def predict_with(weights):
        # Refit the same estimator with new weights, then predict the probe point.
        clf.fit(X, Y, sample_weight=weights, queue=queue)
        return clf.predict([[-1.0, 1.0]], queue=queue)

    assert predict_with([10.0, 0.1, 0.1, 0.1, 0.1, 10]) < 1.5
    assert predict_with([1.0, 0.1, 10.0, 10.0, 0.1, 0.1]) > 1.5
    assert predict_with([1] * 6) == pytest.approx(1.5)
226
+
227
+
228
@pass_if_not_implemented_for_gpu(reason="svr is not implemented")
@pytest.mark.parametrize("queue", get_queues())
def test_pickle(queue):
    """A pickled and restored SVR keeps its exact type and its predictions."""
    import pickle

    diabetes = datasets.load_diabetes()
    clf = SVR(kernel="rbf", C=10.0)
    clf.fit(diabetes.data, diabetes.target, queue=queue)
    expected = clf.predict(diabetes.data, queue=queue)

    dump = pickle.dumps(clf)
    clf2 = pickle.loads(dump)

    # Exact type identity is intended here, so compare with `is`, not `==`.
    assert type(clf2) is type(clf)
    result = clf2.predict(diabetes.data, queue=queue)
    assert_array_equal(expected, result)
@@ -0,0 +1,57 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import importlib
18
+ import os
19
+ from glob import glob
20
+
21
+
22
+ def _check_primitive_usage_ban(primitive_name, package, allowed_locations=None):
23
+ """This test blocks the usage of the primitive in
24
+ in certain files.
25
+ """
26
+
27
+ loc = importlib.util.find_spec(package).origin
28
+
29
+ path = loc.replace("__init__.py", "")
30
+ files = [y for x in os.walk(path) for y in glob(os.path.join(x[0], "*.py"))]
31
+
32
+ output = []
33
+
34
+ for f in files:
35
+ if open(f, "r").read().find(primitive_name) != -1:
36
+ output += [f.replace(path, package + os.sep)]
37
+
38
+ # remove this file from the list
39
+ if allowed_locations:
40
+ for allowed in allowed_locations:
41
+ output = [i for i in output if allowed not in i]
42
+
43
+ return output
44
+
45
+
46
def test_sklearn_check_version_ban():
    """Block the use of sklearn_check_version in onedal files.

    The versioning should occur in the sklearnex package for clarity and
    maintainability.
    """
    # test_common.py is exempt: presumably this very file, which necessarily
    # contains the primitive's name in the call below.
    output = _check_primitive_usage_ban(
        primitive_name="sklearn_check_version",
        package="onedal",
        allowed_locations=["test_common.py"],
    )

    output = "\n".join(output)
    assert output == "", f"sklearn versioning is occurring in: \n{output}"