scikit-learn-intelex 2025.10.0__py313-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (267)
  1. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__init__.py +73 -0
  2. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/__main__.py +58 -0
  3. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/_daal4py.cp313-win_amd64.pyd +0 -0
  4. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/__init__.py +94 -0
  5. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/gbt_convertors.py +1199 -0
  6. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/logistic_regression_builders.py +211 -0
  7. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mb/tree_based_builders.py +425 -0
  8. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/mpi_transceiver.cp313-win_amd64.pyd +0 -0
  9. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/__init__.py +40 -0
  10. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_n_jobs_support.py +252 -0
  11. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/_utils.py +245 -0
  12. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/__init__.py +20 -0
  13. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/dbscan.py +165 -0
  14. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/k_means.py +565 -0
  15. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/cluster/tests/test_dbscan.py +109 -0
  16. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/__init__.py +19 -0
  17. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/decomposition/_pca.py +528 -0
  18. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/AdaBoostClassifier.py +192 -0
  19. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/GBTDAAL.py +333 -0
  20. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/__init__.py +27 -0
  21. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/_forest.py +1285 -0
  22. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/ensemble/tests/test_decision_forest.py +206 -0
  23. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/__init__.py +29 -0
  24. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_coordinate_descent.py +826 -0
  25. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_linear.py +269 -0
  26. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/_ridge.py +290 -0
  27. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/coordinate_descent.py +17 -0
  28. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/linear.py +17 -0
  29. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_loss.py +195 -0
  30. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/logistic_path.py +561 -0
  31. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/ridge.py +17 -0
  32. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_enet.py +157 -0
  33. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_linear.py +198 -0
  34. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/linear_model/tests/test_ridge.py +69 -0
  35. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/__init__.py +19 -0
  36. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/manifold/_t_sne.py +432 -0
  37. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/__init__.py +20 -0
  38. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_pairwise.py +259 -0
  39. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/metrics/_ranking.py +210 -0
  40. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/__init__.py +19 -0
  41. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/_split.py +309 -0
  42. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/model_selection/tests/test_split.py +56 -0
  43. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/__init__.py +0 -0
  44. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/dispatcher.py +232 -0
  45. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/_models_info.py +164 -0
  46. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_monkeypatch.py +71 -0
  47. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/test_patching.py +90 -0
  48. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/monkeypatch/tests/utils/_launch_algorithms.py +117 -0
  49. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/__init__.py +21 -0
  50. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_base.py +493 -0
  51. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_classification.py +136 -0
  52. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_regression.py +72 -0
  53. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/_unsupervised.py +55 -0
  54. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/neighbors/tests/test_kneighbors.py +113 -0
  55. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/__init__.py +19 -0
  56. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/svm/svm.py +736 -0
  57. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/__init__.py +21 -0
  58. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/base.py +80 -0
  59. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/tests/test_utils.py +51 -0
  60. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/daal4py/sklearn/utils/validation.py +772 -0
  61. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/__init__.py +151 -0
  62. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_config.py +80 -0
  63. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_device_offload.py +199 -0
  64. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_dpc.cp313-win_amd64.pyd +0 -0
  65. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/_onedal_py_host.cp313-win_amd64.pyd +0 -0
  66. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/__init__.py +20 -0
  67. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/basic_statistics.py +151 -0
  68. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/incremental_basic_statistics.py +165 -0
  69. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_basic_statistics.py +241 -0
  70. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/test_incremental_basic_statistics.py +279 -0
  71. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/basic_statistics/tests/utils.py +50 -0
  72. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/__init__.py +27 -0
  73. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/dbscan.py +80 -0
  74. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans.py +582 -0
  75. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/kmeans_init.py +145 -0
  76. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_dbscan.py +125 -0
  77. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans.py +88 -0
  78. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/cluster/tests/test_kmeans_init.py +93 -0
  79. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_backend.py +258 -0
  80. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_estimator_checks.py +47 -0
  81. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/_mixin.py +62 -0
  82. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/hyperparameters.py +148 -0
  83. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/common/tests/test_sycl.py +148 -0
  84. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/__init__.py +20 -0
  85. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/covariance.py +121 -0
  86. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/incremental_covariance.py +151 -0
  87. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_covariance.py +50 -0
  88. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/covariance/tests/test_incremental_covariance.py +190 -0
  89. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/__init__.py +29 -0
  90. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_data_conversion.py +158 -0
  91. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_dlpack.py +64 -0
  92. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/_sycl_usm.py +63 -0
  93. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/common.py +131 -0
  94. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/datatypes/tests/test_data.py +686 -0
  95. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/__init__.py +20 -0
  96. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/incremental_pca.py +218 -0
  97. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/pca.py +124 -0
  98. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/decomposition/tests/test_incremental_pca.py +291 -0
  99. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/__init__.py +19 -0
  100. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/dummy/dummy.py +137 -0
  101. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/__init__.py +29 -0
  102. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/forest.py +781 -0
  103. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/ensemble/tests/test_random_forest.py +97 -0
  104. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/__init__.py +27 -0
  105. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/incremental_linear_model.py +201 -0
  106. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/linear_model.py +230 -0
  107. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/logistic_regression.py +293 -0
  108. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_linear_regression.py +213 -0
  109. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_incremental_ridge_regression.py +171 -0
  110. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_linear_regression.py +252 -0
  111. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_logistic_regression.py +95 -0
  112. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/linear_model/tests/test_ridge.py +95 -0
  113. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/__init__.py +19 -0
  114. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/neighbors.py +690 -0
  115. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/neighbors/tests/test_knn_classification.py +49 -0
  116. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/__init__.py +27 -0
  117. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/get_tree.py +27 -0
  118. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/kernel_functions.py +202 -0
  119. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/primitives/tests/test_kernel_functions.py +159 -0
  120. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/__init__.py +19 -0
  121. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/svm.py +592 -0
  122. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_csr_svm.py +352 -0
  123. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvc.py +204 -0
  124. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_nusvr.py +210 -0
  125. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svc.py +168 -0
  126. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/svm/tests/test_svr.py +243 -0
  127. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/test_common.py +71 -0
  128. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_dataframes_support.py +179 -0
  129. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/tests/utils/_device_selection.py +94 -0
  130. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_array_api.py +98 -0
  131. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_sycl_queue_manager.py +213 -0
  132. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/_third_party.py +220 -0
  133. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/tests/test_validation.py +142 -0
  134. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/onedal/utils/validation.py +503 -0
  135. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__init__.py +69 -0
  136. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/__main__.py +58 -0
  137. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_config.py +163 -0
  138. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_device_offload.py +205 -0
  139. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/_utils.py +219 -0
  140. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/base.py +109 -0
  141. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +20 -0
  142. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +241 -0
  143. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +338 -0
  144. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_basic_statistics.py +405 -0
  145. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +455 -0
  146. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/__init__.py +20 -0
  147. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +199 -0
  148. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/k_means.py +399 -0
  149. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +38 -0
  150. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +157 -0
  151. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/conftest.py +82 -0
  152. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
  153. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +440 -0
  154. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +307 -0
  155. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/__init__.py +19 -0
  156. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +558 -0
  157. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +164 -0
  158. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dispatcher.py +572 -0
  159. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +629 -0
  160. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/__init__.py +19 -0
  161. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/_dummy.py +615 -0
  162. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/dummy/tests/test_dummy.py +62 -0
  163. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +29 -0
  164. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1799 -0
  165. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +196 -0
  166. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/__main__.py +72 -0
  167. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/glob/dispatcher.py +101 -0
  168. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/__init__.py +32 -0
  169. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +44 -0
  170. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_linear.py +427 -0
  171. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/incremental_ridge.py +407 -0
  172. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +363 -0
  173. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +466 -0
  174. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/ridge.py +407 -0
  175. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_linear.py +267 -0
  176. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_incremental_ridge.py +214 -0
  177. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +167 -0
  178. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +565 -0
  179. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_ridge.py +256 -0
  180. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/__init__.py +19 -0
  181. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/t_sne.py +28 -0
  182. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +256 -0
  183. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/__init__.py +23 -0
  184. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/pairwise.py +20 -0
  185. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/ranking.py +20 -0
  186. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +39 -0
  187. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/__init__.py +21 -0
  188. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/split.py +20 -0
  189. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +34 -0
  190. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/__init__.py +27 -0
  191. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +189 -0
  192. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/common.py +313 -0
  193. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +189 -0
  194. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +167 -0
  195. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +170 -0
  196. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +82 -0
  197. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/__init__.py +17 -0
  198. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +19 -0
  199. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +261 -0
  200. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +112 -0
  201. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +19 -0
  202. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/incremental_pca.py +406 -0
  203. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_incremental_pca.py +390 -0
  204. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +25 -0
  205. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +20 -0
  206. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +24 -0
  207. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/incremental_basic_statistics.py +30 -0
  208. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_basic_statistics_spmd.py +117 -0
  209. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/basic_statistics/tests/test_incremental_basic_statistics_spmd.py +314 -0
  210. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +30 -0
  211. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +26 -0
  212. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +21 -0
  213. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_dbscan_spmd.py +108 -0
  214. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/cluster/tests/test_kmeans_spmd.py +180 -0
  215. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +20 -0
  216. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +24 -0
  217. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/incremental_covariance.py +29 -0
  218. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_covariance_spmd.py +120 -0
  219. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/tests/test_incremental_covariance_spmd.py +200 -0
  220. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +20 -0
  221. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/incremental_pca.py +23 -0
  222. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +23 -0
  223. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_incremental_pca_spmd.py +276 -0
  224. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/decomposition/tests/test_pca_spmd.py +146 -0
  225. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +19 -0
  226. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +299 -0
  227. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/ensemble/tests/test_forest_spmd.py +299 -0
  228. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +21 -0
  229. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/incremental_linear_model.py +28 -0
  230. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +24 -0
  231. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  232. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_incremental_linear_spmd.py +345 -0
  233. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_linear_regression_spmd.py +162 -0
  234. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/tests/test_logistic_regression_spmd.py +169 -0
  235. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +23 -0
  236. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/spmd/neighbors/tests/test_neighbors_spmd.py +433 -0
  237. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/__init__.py +29 -0
  238. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/_common.py +403 -0
  239. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvc.py +278 -0
  240. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/nusvr.py +158 -0
  241. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svc.py +306 -0
  242. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/svr.py +155 -0
  243. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +124 -0
  244. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_common.py +607 -0
  245. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_config.py +256 -0
  246. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_hyperparameters.py +111 -0
  247. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +335 -0
  248. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +269 -0
  249. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +111 -0
  250. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +48 -0
  251. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +418 -0
  252. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability.py +335 -0
  253. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/__init__.py +48 -0
  254. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/base.py +420 -0
  255. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/tests/utils/spmd.py +198 -0
  256. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/__init__.py +19 -0
  257. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/_array_api.py +217 -0
  258. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/class_weight.py +100 -0
  259. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/parallel.py +97 -0
  260. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_class_weight.py +69 -0
  261. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/tests/test_validation.py +238 -0
  262. scikit_learn_intelex-2025.10.0.data/data/Lib/site-packages/sklearnex/utils/validation.py +212 -0
  263. scikit_learn_intelex-2025.10.0.dist-info/LICENSE.txt +202 -0
  264. scikit_learn_intelex-2025.10.0.dist-info/METADATA +182 -0
  265. scikit_learn_intelex-2025.10.0.dist-info/RECORD +267 -0
  266. scikit_learn_intelex-2025.10.0.dist-info/WHEEL +5 -0
  267. scikit_learn_intelex-2025.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,165 @@
1
+ # Copyright 2024 Intel Corporation
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from .._device_offload import supports_queue
17
+ from ..common._backend import bind_default_backend
18
+ from ..datatypes import from_table, return_type_constructor, to_table
19
+ from ..utils import _sycl_queue_manager as QM
20
+ from .basic_statistics import BasicStatistics
21
+
22
+
23
class IncrementalBasicStatistics(BasicStatistics):
    """Incremental oneDAL low order moments estimator.

    Calculate basic statistics for data split into batches.

    Parameters
    ----------
    result_options : str or list, default='all'
        List of statistics to compute.

    algorithm : str, default='by_default'
        Method for statistics computation.

    Attributes
    ----------
    min : ndarray of shape (n_features,)
        Minimum of each feature over all samples.

    max : ndarray of shape (n_features,)
        Maximum of each feature over all samples.

    sum : ndarray of shape (n_features,)
        Sum of each feature over all samples.

    mean : ndarray of shape (n_features,)
        Mean of each feature over all samples.

    variance : ndarray of shape (n_features,)
        Variance of each feature over all samples.

    variation : ndarray of shape (n_features,)
        Variation of each feature over all samples.

    sum_squares : ndarray of shape (n_features,)
        Sum of squares for each feature over all samples.

    standard_deviation : ndarray of shape (n_features,)
        Standard deviation of each feature over all samples.

    sum_squares_centered : ndarray of shape (n_features,)
        Centered sum of squares for each feature over all samples.

    second_order_raw_moment : ndarray of shape (n_features,)
        Second order moment of each feature over all samples.

    Notes
    -----
    Attributes are populated only for corresponding result options.
    """

    def __init__(self, result_options="all", algorithm="by_default"):
        super().__init__(result_options, algorithm)
        self._reset()
        self._queue = None

    @bind_default_backend("basic_statistics")
    def partial_compute_result(self): ...

    @bind_default_backend("basic_statistics")
    def partial_compute(self, *args, **kwargs): ...

    @bind_default_backend("basic_statistics")
    def finalize_compute(self, *args, **kwargs): ...

    def _reset(self):
        """Discard all accumulated state so the next batch starts a new fit."""
        self._need_to_finalize = False
        self._outtype = None
        self._queue = None
        # The cached oneDAL parameters embed the dtype of the first batch of the
        # previous fit; drop them so partial_fit rebuilds them from the next batch.
        self.__dict__.pop("_onedal_params", None)
        # get the _partial_result pointer from backend
        self._partial_result = self.partial_compute_result()

    def __getstate__(self):
        # Since finalize_fit can't be dispatched without directly provided queue
        # and the dispatching policy can't be serialized, the computation is finalized
        # here and the policy is not saved in serialized data.
        # Note: this finalizes the live object being pickled, not only the copy.
        self.finalize_fit()
        data = self.__dict__.copy()
        data.pop("_queue", None)

        return data

    @supports_queue
    def partial_fit(self, X, sample_weight=None, queue=None):
        """Generate partial statistics from batch data in `_partial_result`.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data batch, where `n_samples` is the number of samples
            in the batch, and `n_features` is the number of features.

        sample_weight : array-like of shape (n_samples,), default=None
            Individual weights for each sample.

        queue : SyclQueue or None, default=None
            SYCL Queue object for device code execution. Default
            value None causes computation on host.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        # finalize_fit dispatches the final reduction on the queue of the most
        # recent batch, so remember it here.
        self._queue = queue
        # The output array type is taken from the first batch after a reset.
        if not self._outtype:
            self._outtype = return_type_constructor(X)

        X_table, sample_weight_table = to_table(X, sample_weight, queue=queue)

        # Parameters (including the floating point type) are fixed by the first
        # batch seen since the last reset.
        if not hasattr(self, "_onedal_params"):
            self._onedal_params = self._get_onedal_params(False, dtype=X.dtype)

        self._partial_result = self.partial_compute(
            self._onedal_params, self._partial_result, X_table, sample_weight_table
        )

        self._need_to_finalize = True
        return self

    def finalize_fit(self):
        """Finalize statistics from the current `_partial_result`.

        Sets one ``<option>_`` attribute per requested result option. Calling
        this when no batch is pending is a no-op.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        if not self._need_to_finalize:
            return self

        with QM.manage_global_queue(self._queue):
            result = self.finalize_compute(self._onedal_params, self._partial_result)

        for opt in self.options:
            # Backend results are single-row tables; keep the 1-D row.
            setattr(
                self,
                opt + "_",
                from_table(getattr(result, opt), like=self._outtype)[0, :],
            )

        self._outtype = None
        self._need_to_finalize = False
        return self
@@ -0,0 +1,241 @@
1
+ # ==============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose
20
+ from scipy import sparse as sp
21
+
22
+ from onedal.basic_statistics import BasicStatistics
23
+ from onedal.basic_statistics.tests.utils import options_and_tests
24
+ from onedal.tests.utils._device_selection import get_queues
25
+
26
# Each entry is (result option, name of the SciPy sparse reduction used as the
# reference, (float32 tolerance, float64 tolerance)). For every option checked
# on CSR input the SciPy method shares the option's name; min/max must match
# exactly.
options_and_tests_csr = [
    (option, option, tolerances)
    for option, tolerances in (
        ("sum", (5e-6, 1e-9)),
        ("min", (0, 0)),
        ("max", (0, 0)),
        ("mean", (5e-6, 1e-9)),
    )
]
32
+
33
+
34
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("result_option", options_and_tests.keys())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_single_option_on_random_data(
    queue, result_option, row_count, column_count, weighted, dtype
):
    """Check a single requested statistic against its NumPy reference."""
    function, (fp32tol, fp64tol) = options_and_tests[result_option]
    seed = 77
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)
    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
    else:
        weights = None

    basicstat = BasicStatistics(result_options=result_option)

    result = basicstat.fit(data, sample_weight=weights, queue=queue)

    res = getattr(result, result_option + "_")
    # Scale each row by its weight. Broadcasting yields the same values as
    # np.diag(weights) @ data without building a dense (row_count, row_count)
    # matrix and running an O(n^2 * m) matmul.
    reference = data if weights is None else weights[:, np.newaxis] * data
    gtr = function(reference)

    tol = fp32tol if res.dtype == np.float32 else fp64tol
    assert_allclose(gtr, res, atol=tol)
68
+
69
+
70
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_on_random_data(queue, row_count, column_count, weighted, dtype):
    """Request several statistics at once and compare each with NumPy."""
    seed = 42
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)

    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
    else:
        weights = None

    options = ("mean", "max", "sum")
    basicstat = BasicStatistics(result_options=list(options))

    result = basicstat.fit(data, sample_weight=weights, queue=queue)

    # Row-wise weighting via broadcasting (same values as np.diag(weights) @ data,
    # without the dense (row_count, row_count) intermediate).
    reference = data if weights is None else weights[:, np.newaxis] * data

    tol = 5e-4 if result.mean_.dtype == np.float32 else 1e-7
    for option in options:
        gtr = options_and_tests[option][0](reference)
        assert_allclose(gtr, getattr(result, option + "_"), atol=tol, err_msg=option)
110
+
111
+
112
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_all_option_on_random_data(queue, row_count, column_count, weighted, dtype):
    """Compute every statistic with result_options='all' and check each one."""
    seed = 77
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)
    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
    else:
        weights = None

    basicstat = BasicStatistics(result_options="all")

    result = basicstat.fit(data, sample_weight=weights, queue=queue)

    # Row-wise weighting via broadcasting (same values as np.diag(weights) @ data,
    # without the dense (row_count, row_count) intermediate).
    reference = data if weights is None else weights[:, np.newaxis] * data

    for result_option, (function, (fp32tol, fp64tol)) in options_and_tests.items():
        res = getattr(result, result_option + "_")
        gtr = function(reference)
        tol = fp32tol if res.dtype == np.float32 else fp64tol
        assert_allclose(gtr, res, atol=tol, err_msg=result_option)
145
+
146
+
147
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("result_option", options_and_tests.keys())
@pytest.mark.parametrize("data_size", [100, 1000])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_1d_input_on_random_data(queue, result_option, data_size, weighted, dtype):
    """Fit on a 1-D sample vector and compare one statistic with NumPy."""
    reference_fn, (fp32tol, fp64tol) = options_and_tests[result_option]
    rng = np.random.default_rng(77)
    samples = rng.uniform(low=-0.3, high=+0.7, size=data_size).astype(dtype=dtype)
    sample_weight = (
        rng.uniform(low=-0.5, high=+1.0, size=data_size).astype(dtype=dtype)
        if weighted
        else None
    )

    fitted = BasicStatistics(result_options=result_option).fit(
        samples, sample_weight=sample_weight, queue=queue
    )
    computed = getattr(fitted, result_option + "_")

    # 1-D input: weighting is a plain element-wise product.
    expected = reference_fn(
        samples if sample_weight is None else sample_weight * samples
    )
    atol = fp64tol if computed.dtype != np.float32 else fp32tol
    assert_allclose(expected, computed, atol=atol)
179
+
180
+
181
@pytest.mark.skipif(not hasattr(sp, "random_array"), reason="requires scipy>=1.12.0")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_basic_csr(queue, dtype):
    """Column means of a random CSR matrix must match SciPy's ``mean(axis=0)``."""
    shape = (5000, 3008)
    csr = sp.random_array(
        shape=shape,
        density=0.01,
        format="csr",
        dtype=dtype,
        random_state=np.random.default_rng(42),
    )

    computed_mean = BasicStatistics(result_options="mean").fit(csr, queue=queue).mean_
    expected_mean = csr.mean(axis=0)
    rtol = 1e-9 if computed_mean.dtype != np.float32 else 5e-6
    assert_allclose(expected_mean, computed_mean, rtol=rtol)
205
+
206
+
207
@pytest.mark.skipif(not hasattr(sp, "random_array"), reason="requires scipy>=1.12.0")
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("option", options_and_tests_csr)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_options_csr(queue, option, dtype):
    """Check one statistic on CSR input against the matching SciPy reduction."""
    result_option, function, (fp32tol, fp64tol) = option

    # The known oneDAL max bug is GPU-specific, so only skip on GPU queues
    # instead of dropping CPU coverage as well.
    if result_option == "max" and queue is not None and queue.sycl_device.is_gpu:
        pytest.skip("There is a bug in oneDAL's max computations on GPU")

    seed = 42
    row_count, column_count = 20046, 4007

    gen = np.random.default_rng(seed)

    data = sp.random_array(
        shape=(row_count, column_count),
        density=0.002,
        format="csr",
        dtype=dtype,
        random_state=gen,
    )

    basicstat = BasicStatistics(result_options=result_option)
    result = basicstat.fit(data, queue=queue)

    res = getattr(result, result_option + "_")
    gtr = getattr(data, function)(axis=0)
    # Some SciPy reductions (e.g. min/max) return a sparse result; densify it
    # to a flat vector for comparison.
    if sp.issparse(gtr):
        gtr = gtr.toarray().ravel()
    tol = fp32tol if res.dtype == np.float32 else fp64tol

    assert_allclose(gtr, res, rtol=tol)
@@ -0,0 +1,279 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose
20
+
21
+ from onedal.basic_statistics import IncrementalBasicStatistics
22
+ from onedal.basic_statistics.tests.utils import options_and_tests
23
+ from onedal.datatypes import from_table
24
+ from onedal.tests.utils._device_selection import get_queues
25
+
26
+
27
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_on_gold_data(queue, weighted, dtype):
    """Check mean/max/min on a hand-computed 2x2 dataset fed as two batches."""
    X_split = np.array_split(np.array([[0, 0], [1, 1]]).astype(dtype=dtype), 2)
    if weighted:
        weights_split = np.array_split(np.array([1, 0.5]).astype(dtype=dtype), 2)

    incbs = IncrementalBasicStatistics()
    for idx, X_batch in enumerate(X_split):
        if not weighted:
            incbs.partial_fit(X_batch, queue=queue)
        else:
            incbs.partial_fit(X_batch, weights_split[idx], queue=queue)

    result = incbs.finalize_fit()

    # With weights [1, 0.5] the weighted rows are [0, 0] and [0.5, 0.5].
    if weighted:
        expected = {
            "mean": np.array([0.25, 0.25]),
            "max": np.array([0.5, 0.5]),
            "min": np.array([0, 0]),
        }
    else:
        expected = {
            "mean": np.array([0.5, 0.5]),
            "max": np.array([1, 1]),
            "min": np.array([0, 0]),
        }

    assert_allclose(expected["mean"], result.mean_)
    assert_allclose(expected["max"], result.max_)
    assert_allclose(expected["min"], result.min_)
62
+
63
+
64
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("num_batches", [2, 10])
@pytest.mark.parametrize("result_option", options_and_tests.keys())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_single_option_on_random_data(
    queue, num_batches, result_option, row_count, column_count, weighted, dtype
):
    """Fit one statistic batch by batch and compare it with its NumPy reference."""
    function, (fp32tol, fp64tol) = options_and_tests[result_option]
    seed = 77
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)
    data_split = np.array_split(data, num_batches)
    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
        weights_split = np.array_split(weights, num_batches)
    incbs = IncrementalBasicStatistics(result_options=result_option)

    for i, data_batch in enumerate(data_split):
        if weighted:
            incbs.partial_fit(data_batch, weights_split[i], queue=queue)
        else:
            incbs.partial_fit(data_batch, queue=queue)
    result = incbs.finalize_fit()

    res = getattr(result, result_option + "_")
    # Row-wise weighting via broadcasting yields the same values as
    # np.diag(weights) @ data without a dense (row_count, row_count) matrix.
    reference = weights[:, np.newaxis] * data if weighted else data
    gtr = function(reference)

    tol = fp32tol if res.dtype == np.float32 else fp64tol
    assert_allclose(gtr, res, atol=tol)
103
+
104
+
105
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("num_batches", [2, 10])
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_on_random_data(
    queue, num_batches, row_count, column_count, weighted, dtype
):
    """Request several statistics incrementally and compare each with NumPy."""
    seed = 42
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)
    data_split = np.array_split(data, num_batches)
    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
        weights_split = np.array_split(weights, num_batches)

    options = ("mean", "max", "sum")
    incbs = IncrementalBasicStatistics(result_options=list(options))

    for i, data_batch in enumerate(data_split):
        if weighted:
            incbs.partial_fit(data_batch, weights_split[i], queue=queue)
        else:
            incbs.partial_fit(data_batch, queue=queue)
    result = incbs.finalize_fit()

    # Row-wise weighting via broadcasting (same values as np.diag(weights) @ data,
    # without the dense (row_count, row_count) intermediate).
    reference = weights[:, np.newaxis] * data if weighted else data

    tol = 3e-4 if result.mean_.dtype == np.float32 else 1e-7
    for option in options:
        gtr = options_and_tests[option][0](reference)
        assert_allclose(gtr, getattr(result, option + "_"), atol=tol, err_msg=option)
151
+
152
+
153
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("num_batches", [2, 10])
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_all_option_on_random_data(
    queue, num_batches, row_count, column_count, weighted, dtype
):
    """Compute every statistic incrementally with result_options='all'."""
    seed = 77
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)
    data_split = np.array_split(data, num_batches)
    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
        weights_split = np.array_split(weights, num_batches)
    incbs = IncrementalBasicStatistics(result_options="all")

    for i, data_batch in enumerate(data_split):
        if weighted:
            incbs.partial_fit(data_batch, weights_split[i], queue=queue)
        else:
            incbs.partial_fit(data_batch, queue=queue)
    result = incbs.finalize_fit()

    # Row-wise weighting via broadcasting (same values as np.diag(weights) @ data,
    # without the dense (row_count, row_count) intermediate).
    reference = weights[:, np.newaxis] * data if weighted else data

    for result_option, (function, (fp32tol, fp64tol)) in options_and_tests.items():
        res = getattr(result, result_option + "_")
        gtr = function(reference)
        tol = fp32tol if res.dtype == np.float32 else fp64tol
        assert_allclose(gtr, res, atol=tol, err_msg=result_option)
193
+
194
+
195
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_incremental_estimator_pickle(queue, dtype):
    """Pickling finalizes pending batches and round-trips partial and final state."""
    import pickle

    incbs = IncrementalBasicStatistics()

    # Check that estimator can be serialized without any data.
    dump = pickle.dumps(incbs)
    incbs_loaded = pickle.loads(dump)
    seed = 77
    gen = np.random.default_rng(seed)
    X = gen.uniform(low=-0.3, high=+0.7, size=(10, 10)).astype(dtype)
    X_split = np.array_split(X, 2)
    incbs.partial_fit(X_split[0], queue=queue)
    incbs_loaded.partial_fit(X_split[0], queue=queue)

    assert incbs._need_to_finalize is True
    assert incbs_loaded._need_to_finalize is True

    # Check that estimator can be serialized after partial_fit call.
    dump = pickle.dumps(incbs)
    incbs_loaded = pickle.loads(dump)
    # Finalize is called during serialization to make sure partial results are
    # finalized correctly; it also finalizes the pickled (source) object.
    assert incbs._need_to_finalize is False
    assert incbs_loaded._need_to_finalize is False

    partial_result_names = (
        "partial_n_rows",
        "partial_min",
        "partial_max",
        "partial_sum",
        "partial_sum_squares",
        "partial_sum_squares_centered",
    )
    for name in partial_result_names:
        partial = from_table(getattr(incbs._partial_result, name))
        partial_loaded = from_table(getattr(incbs_loaded._partial_result, name))
        assert_allclose(partial, partial_loaded, err_msg=name)

    incbs.partial_fit(X_split[1], queue=queue)
    incbs_loaded.partial_fit(X_split[1], queue=queue)
    assert incbs._need_to_finalize is True
    assert incbs_loaded._need_to_finalize is True

    dump = pickle.dumps(incbs_loaded)
    incbs_loaded = pickle.loads(dump)

    assert incbs._need_to_finalize is True
    assert incbs_loaded._need_to_finalize is False

    incbs.finalize_fit()
    incbs_loaded.finalize_fit()

    # Check that finalized estimator can be serialized.
    dump = pickle.dumps(incbs_loaded)
    incbs_loaded = pickle.loads(dump)

    for result_option, (_, (fp32tol, fp64tol)) in options_and_tests.items():
        res = getattr(incbs, result_option + "_")
        res_loaded = getattr(incbs_loaded, result_option + "_")
        tol = fp32tol if res.dtype == np.float32 else fp64tol
        assert_allclose(res, res_loaded, atol=tol, err_msg=result_option)