scikit-learn-intelex 2024.2.0__py310-none-win_amd64.whl → 2024.3.0__py310-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (107)
  1. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__init__.py +9 -7
  2. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -1
  3. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/conftest.py +63 -0
  4. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +338 -0
  5. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +22 -8
  6. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +72 -41
  7. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +10 -14
  8. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +15 -19
  9. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +13 -2
  10. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -2
  11. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +39 -2
  12. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +7 -9
  13. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +6 -9
  14. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +5 -8
  15. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -5
  16. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  17. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  18. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +4 -0
  19. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +4 -0
  20. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +155 -0
  21. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +8 -3
  22. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +268 -0
  23. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  24. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +361 -0
  25. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/METADATA +2 -2
  26. scikit_learn_intelex-2024.3.0.dist-info/RECORD +98 -0
  27. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  28. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  29. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -19
  30. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -374
  31. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -170
  32. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -240
  33. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -136
  34. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -118
  35. scikit_learn_intelex-2024.2.0.dist-info/RECORD +0 -101
  36. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  37. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  38. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  39. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  40. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  41. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  42. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  43. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  44. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  45. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  46. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  47. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -0
  48. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -0
  49. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  50. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  51. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  52. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  53. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  54. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  55. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  56. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -0
  57. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  58. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  59. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -0
  60. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  61. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  62. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  63. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  64. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  65. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  66. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  67. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  69. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  70. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  71. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  72. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  73. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  74. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +0 -0
  75. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  76. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +0 -0
  77. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  78. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  79. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  80. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  81. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  83. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  84. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  86. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  87. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  88. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  91. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  92. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  93. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  94. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
  96. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +3 -3
  97. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +0 -0
  98. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  99. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  100. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -0
  101. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  102. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  104. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  105. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/LICENSE.txt +0 -0
  106. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/WHEEL +0 -0
  107. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/top_level.txt +0 -0
@@ -453,14 +453,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
453
453
 
454
454
  # The estimator is checked against the class attribute for conformance.
455
455
  # This should only trigger if the user uses this class directly.
456
- if (
457
- self.estimator.__class__ == DecisionTreeClassifier
458
- and self._onedal_factory != onedal_RandomForestClassifier
456
+ if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
457
+ self._onedal_factory, onedal_RandomForestClassifier
459
458
  ):
460
459
  self._onedal_factory = onedal_RandomForestClassifier
461
- elif (
462
- self.estimator.__class__ == ExtraTreeClassifier
463
- and self._onedal_factory != onedal_ExtraTreesClassifier
460
+ elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
461
+ self._onedal_factory, onedal_ExtraTreesClassifier
464
462
  ):
465
463
  self._onedal_factory = onedal_ExtraTreesClassifier
466
464
 
@@ -747,7 +745,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
747
745
  or self.estimator.__class__ == DecisionTreeClassifier,
748
746
  "ExtraTrees only supported starting from oneDAL version 2023.1",
749
747
  ),
750
- (sample_weight is not None, "sample_weight is not supported."),
748
+ (sample_weight is None, "sample_weight is not supported."),
751
749
  ]
752
750
  )
753
751
 
@@ -843,14 +841,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
843
841
 
844
842
  # The splitter is checked against the class attribute for conformance
845
843
  # This should only trigger if the user uses this class directly.
846
- if (
847
- self.estimator.__class__ == DecisionTreeRegressor
848
- and self._onedal_factory != onedal_RandomForestRegressor
844
+ if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
845
+ self._onedal_factory, onedal_RandomForestRegressor
849
846
  ):
850
847
  self._onedal_factory = onedal_RandomForestRegressor
851
- elif (
852
- self.estimator.__class__ == ExtraTreeRegressor
853
- and self._onedal_factory != onedal_ExtraTreesRegressor
848
+ elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
849
+ self._onedal_factory, onedal_ExtraTreesRegressor
854
850
  ):
855
851
  self._onedal_factory = onedal_ExtraTreesRegressor
856
852
 
@@ -1056,7 +1052,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1056
1052
  or self.estimator.__class__ == DecisionTreeClassifier,
1057
1053
  "ExtraTrees only supported starting from oneDAL version 2023.1",
1058
1054
  ),
1059
- (sample_weight is not None, "sample_weight is not supported."),
1055
+ (sample_weight is None, "sample_weight is not supported."),
1060
1056
  ]
1061
1057
  )
1062
1058
 
@@ -45,11 +45,7 @@ def test_sklearnex_import_rf_classifier(dataframe, queue):
45
45
  assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
46
46
 
47
47
 
48
- # TODO:
49
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
50
- @pytest.mark.parametrize(
51
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
52
- )
48
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
53
49
  def test_sklearnex_import_rf_regression(dataframe, queue):
54
50
  from sklearnex.ensemble import RandomForestRegressor
55
51
 
@@ -59,17 +55,17 @@ def test_sklearnex_import_rf_regression(dataframe, queue):
59
55
  rf = RandomForestRegressor(max_depth=2, random_state=0).fit(X, y)
60
56
  assert "sklearnex" in rf.__module__
61
57
  pred = _as_numpy(rf.predict([[0, 0, 0, 0]]))
62
- if daal_check_version((2024, "P", 0)):
63
- assert_allclose([-6.971], pred, atol=1e-2)
58
+
59
+ if queue is not None and queue.sycl_device.is_gpu:
60
+ assert_allclose([-0.011208], pred, atol=1e-2)
64
61
  else:
65
- assert_allclose([-6.839], pred, atol=1e-2)
62
+ if daal_check_version((2024, "P", 0)):
63
+ assert_allclose([-6.971], pred, atol=1e-2)
64
+ else:
65
+ assert_allclose([-6.839], pred, atol=1e-2)
66
66
 
67
67
 
68
- # TODO:
69
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
70
- @pytest.mark.parametrize(
71
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
72
- )
68
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
73
69
  def test_sklearnex_import_et_classifier(dataframe, queue):
74
70
  from sklearnex.ensemble import ExtraTreesClassifier
75
71
 
@@ -90,11 +86,7 @@ def test_sklearnex_import_et_classifier(dataframe, queue):
90
86
  assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
91
87
 
92
88
 
93
- # TODO:
94
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
95
- @pytest.mark.parametrize(
96
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
97
- )
89
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
98
90
  def test_sklearnex_import_et_regression(dataframe, queue):
99
91
  from sklearnex.ensemble import ExtraTreesRegressor
100
92
 
@@ -114,4 +106,8 @@ def test_sklearnex_import_et_regression(dataframe, queue):
114
106
  ]
115
107
  )
116
108
  )
117
- assert_allclose([0.445], pred, atol=1e-2)
109
+
110
+ if queue is not None and queue.sycl_device.is_gpu:
111
+ assert_allclose([1.909769], pred, atol=1e-2)
112
+ else:
113
+ assert_allclose([0.445], pred, atol=1e-2)
@@ -185,7 +185,10 @@ if daal_check_version((2024, "P", 1)):
185
185
  [
186
186
  (self.penalty == "l2", "Only l2 penalty is supported."),
187
187
  (self.dual == False, "dual=True is not supported."),
188
- (self.intercept_scaling == 1, "Intercept scaling is not supported."),
188
+ (
189
+ self.intercept_scaling == 1,
190
+ "Intercept scaling is not supported.",
191
+ ),
189
192
  (self.class_weight is None, "Class weight is not supported"),
190
193
  (self.solver == "newton-cg", "Only newton-cg solver is supported."),
191
194
  (
@@ -230,7 +233,10 @@ if daal_check_version((2024, "P", 1)):
230
233
  (n_samples > 0, "Number of samples is less than 1."),
231
234
  (not issparse(*data), "Sparse input is not supported."),
232
235
  (not model_is_sparse, "Sparse coefficients are not supported."),
233
- (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
236
+ (
237
+ hasattr(self, "_onedal_estimator"),
238
+ "oneDAL model was not trained.",
239
+ ),
234
240
  ]
235
241
  )
236
242
  if not dal_ready:
@@ -324,6 +330,11 @@ if daal_check_version((2024, "P", 1)):
324
330
  assert hasattr(self, "_onedal_estimator")
325
331
  return self._onedal_estimator.predict_log_proba(X, queue=queue)
326
332
 
333
+ fit.__doc__ = sklearn_LogisticRegression.fit.__doc__
334
+ predict.__doc__ = sklearn_LogisticRegression.predict.__doc__
335
+ predict_proba.__doc__ = sklearn_LogisticRegression.predict_proba.__doc__
336
+ predict_log_proba.__doc__ = sklearn_LogisticRegression.predict_log_proba.__doc__
337
+
327
338
  else:
328
339
  LogisticRegression = LogisticRegression_daal4py
329
340
 
@@ -14,9 +14,7 @@
14
14
  # limitations under the License.
15
15
  # ===============================================================================
16
16
 
17
- import numpy as np
18
17
  import pytest
19
- from numpy.testing import assert_allclose
20
18
  from sklearn.datasets import load_breast_cancer, load_iris
21
19
  from sklearn.metrics import accuracy_score
22
20
  from sklearn.model_selection import train_test_split
@@ -137,11 +137,50 @@ class LocalOutlierFactor(KNeighborsDispatchingBase, sklearn_LocalOutlierFactor):
137
137
  @available_if(sklearn_LocalOutlierFactor._check_novelty_fit_predict)
138
138
  @wrap_output_data
139
139
  def fit_predict(self, X, y=None):
140
+ """Fit the model to the training set X and return the labels.
141
+
142
+ **Not available for novelty detection (when novelty is set to True).**
143
+ Label is 1 for an inlier and -1 for an outlier according to the LOF
144
+ score and the contamination parameter.
145
+
146
+ Parameters
147
+ ----------
148
+ X : {array-like, sparse matrix} of shape (n_samples, n_features), default=None
149
+ The query sample or samples to compute the Local Outlier Factor
150
+ w.r.t. the training samples.
151
+
152
+ y : Ignored
153
+ Not used, present for API consistency by convention.
154
+
155
+ Returns
156
+ -------
157
+ is_inlier : ndarray of shape (n_samples,)
158
+ Returns -1 for anomalies/outliers and 1 for inliers.
159
+ """
140
160
  return self.fit(X)._predict()
141
161
 
142
162
  @available_if(sklearn_LocalOutlierFactor._check_novelty_predict)
143
163
  @wrap_output_data
144
164
  def predict(self, X=None):
165
+ """Predict the labels (1 inlier, -1 outlier) of X according to LOF.
166
+
167
+ **Only available for novelty detection (when novelty is set to True).**
168
+ This method allows to generalize prediction to *new observations* (not
169
+ in the training set). Note that the result of ``clf.fit(X)`` then
170
+ ``clf.predict(X)`` with ``novelty=True`` may differ from the result
171
+ obtained by ``clf.fit_predict(X)`` with ``novelty=False``.
172
+
173
+ Parameters
174
+ ----------
175
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
176
+ The query sample or samples to compute the Local Outlier Factor
177
+ w.r.t. the training samples.
178
+
179
+ Returns
180
+ -------
181
+ is_inlier : ndarray of shape (n_samples,)
182
+ Returns -1 for anomalies/outliers and +1 for inliers.
183
+ """
145
184
  return self._predict(X)
146
185
 
147
186
  @wrap_output_data
@@ -162,6 +201,4 @@ class LocalOutlierFactor(KNeighborsDispatchingBase, sklearn_LocalOutlierFactor):
162
201
  )
163
202
 
164
203
  fit.__doc__ = sklearn_LocalOutlierFactor.fit.__doc__
165
- fit_predict.__doc__ = sklearn_LocalOutlierFactor.fit_predict.__doc__
166
- predict.__doc__ = sklearn_LocalOutlierFactor.predict.__doc__
167
204
  kneighbors.__doc__ = sklearn_LocalOutlierFactor.kneighbors.__doc__
@@ -14,20 +14,12 @@
14
14
  # limitations under the License.
15
15
  # ===============================================================================
16
16
 
17
- import warnings
18
-
19
- from sklearn.neighbors._ball_tree import BallTree
20
- from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
21
- from sklearn.neighbors._kd_tree import KDTree
22
-
23
17
  from daal4py.sklearn._n_jobs_support import control_n_jobs
24
18
  from daal4py.sklearn._utils import sklearn_check_version
25
19
 
26
20
  if not sklearn_check_version("1.2"):
27
21
  from sklearn.neighbors._base import _check_weights
28
22
 
29
- import numpy as np
30
- from sklearn.neighbors._base import VALID_METRICS
31
23
  from sklearn.neighbors._classification import (
32
24
  KNeighborsClassifier as sklearn_KNeighborsClassifier,
33
25
  )
@@ -35,7 +27,6 @@ from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestN
35
27
  from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
36
28
 
37
29
  from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier
38
- from onedal.utils import _check_array, _num_features, _num_samples
39
30
 
40
31
  from .._device_offload import dispatch, wrap_output_data
41
32
  from .common import KNeighborsDispatchingBase
@@ -143,6 +134,7 @@ else:
143
134
 
144
135
  @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "kneighbors"])
145
136
  class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
137
+ __doc__ = sklearn_KNeighborsClassifier.__doc__
146
138
  if sklearn_check_version("1.2"):
147
139
  _parameter_constraints: dict = {**KNeighborsClassifier_._parameter_constraints}
148
140
 
@@ -330,3 +322,9 @@ class KNeighborsClassifier(KNeighborsClassifier_, KNeighborsDispatchingBase):
330
322
  self._fit_method = self._onedal_estimator._fit_method
331
323
  self.outputs_2d_ = self._onedal_estimator.outputs_2d_
332
324
  self._tree = self._onedal_estimator._tree
325
+
326
+ fit.__doc__ = sklearn_KNeighborsClassifier.fit.__doc__
327
+ predict.__doc__ = sklearn_KNeighborsClassifier.predict.__doc__
328
+ predict_proba.__doc__ = sklearn_KNeighborsClassifier.predict_proba.__doc__
329
+ kneighbors.__doc__ = sklearn_KNeighborsClassifier.kneighbors.__doc__
330
+ radius_neighbors.__doc__ = sklearn_NearestNeighbors.radius_neighbors.__doc__
@@ -14,20 +14,12 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
- import warnings
18
-
19
- from sklearn.neighbors._ball_tree import BallTree
20
- from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
21
- from sklearn.neighbors._kd_tree import KDTree
22
-
23
17
  from daal4py.sklearn._n_jobs_support import control_n_jobs
24
18
  from daal4py.sklearn._utils import sklearn_check_version
25
19
 
26
20
  if not sklearn_check_version("1.2"):
27
21
  from sklearn.neighbors._base import _check_weights
28
22
 
29
- import numpy as np
30
- from sklearn.neighbors._base import VALID_METRICS
31
23
  from sklearn.neighbors._regression import (
32
24
  KNeighborsRegressor as sklearn_KNeighborsRegressor,
33
25
  )
@@ -35,7 +27,6 @@ from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestN
35
27
  from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
36
28
 
37
29
  from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor
38
- from onedal.utils import _check_array, _num_features, _num_samples
39
30
 
40
31
  from .._device_offload import dispatch, wrap_output_data
41
32
  from .common import KNeighborsDispatchingBase
@@ -139,6 +130,7 @@ else:
139
130
 
140
131
  @control_n_jobs(decorated_methods=["fit", "predict", "kneighbors"])
141
132
  class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
133
+ __doc__ = sklearn_KNeighborsRegressor.__doc__
142
134
  if sklearn_check_version("1.2"):
143
135
  _parameter_constraints: dict = {**KNeighborsRegressor_._parameter_constraints}
144
136
 
@@ -306,3 +298,8 @@ class KNeighborsRegressor(KNeighborsRegressor_, KNeighborsDispatchingBase):
306
298
  self._y = self._onedal_estimator._y
307
299
  self._fit_method = self._onedal_estimator._fit_method
308
300
  self._tree = self._onedal_estimator._tree
301
+
302
+ fit.__doc__ = sklearn_KNeighborsRegressor.__doc__
303
+ predict.__doc__ = sklearn_KNeighborsRegressor.predict.__doc__
304
+ kneighbors.__doc__ = sklearn_KNeighborsRegressor.kneighbors.__doc__
305
+ radius_neighbors.__doc__ = sklearn_NearestNeighbors.radius_neighbors.__doc__
@@ -19,21 +19,13 @@ try:
19
19
  except ImportError:
20
20
  from distutils.version import LooseVersion as Version
21
21
 
22
- import warnings
23
-
24
- import numpy as np
25
22
  from sklearn import __version__ as sklearn_version
26
- from sklearn.neighbors._ball_tree import BallTree
27
- from sklearn.neighbors._base import VALID_METRICS
28
- from sklearn.neighbors._base import NeighborsBase as sklearn_NeighborsBase
29
- from sklearn.neighbors._kd_tree import KDTree
30
23
  from sklearn.neighbors._unsupervised import NearestNeighbors as sklearn_NearestNeighbors
31
24
  from sklearn.utils.validation import _deprecate_positional_args, check_is_fitted
32
25
 
33
26
  from daal4py.sklearn._n_jobs_support import control_n_jobs
34
27
  from daal4py.sklearn._utils import sklearn_check_version
35
28
  from onedal.neighbors import NearestNeighbors as onedal_NearestNeighbors
36
- from onedal.utils import _check_array, _num_features, _num_samples
37
29
 
38
30
  from .._device_offload import dispatch, wrap_output_data
39
31
  from .common import KNeighborsDispatchingBase
@@ -98,6 +90,7 @@ else:
98
90
 
99
91
  @control_n_jobs(decorated_methods=["fit", "kneighbors"])
100
92
  class NearestNeighbors(NearestNeighbors_, KNeighborsDispatchingBase):
93
+ __doc__ = sklearn_NearestNeighbors.__doc__
101
94
  if sklearn_check_version("1.2"):
102
95
  _parameter_constraints: dict = {**NearestNeighbors_._parameter_constraints}
103
96
 
@@ -219,3 +212,7 @@ class NearestNeighbors(NearestNeighbors_, KNeighborsDispatchingBase):
219
212
  self._fit_X = self._onedal_estimator._fit_X
220
213
  self._fit_method = self._onedal_estimator._fit_method
221
214
  self._tree = self._onedal_estimator._tree
215
+
216
+ fit.__doc__ = sklearn_NearestNeighbors.__doc__
217
+ kneighbors.__doc__ = sklearn_NearestNeighbors.kneighbors.__doc__
218
+ radius_neighbors.__doc__ = sklearn_NearestNeighbors.radius_neighbors.__doc__
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
  # ===============================================================================
16
16
 
17
- import numpy as np
18
17
  import pytest
19
18
  from numpy.testing import assert_allclose
20
19
 
@@ -33,7 +32,6 @@ from sklearnex.neighbors import (
33
32
 
34
33
  @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
35
34
  def test_sklearnex_import_knn_classifier(dataframe, queue):
36
-
37
35
  X = _convert_to_dataframe([[0], [1], [2], [3]], sycl_queue=queue, target_df=dataframe)
38
36
  y = _convert_to_dataframe([0, 0, 1, 1], sycl_queue=queue, target_df=dataframe)
39
37
  neigh = KNeighborsClassifier(n_neighbors=3).fit(X, y)
@@ -45,7 +43,6 @@ def test_sklearnex_import_knn_classifier(dataframe, queue):
45
43
 
46
44
  @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
47
45
  def test_sklearnex_import_knn_regression(dataframe, queue):
48
-
49
46
  X = _convert_to_dataframe([[0], [1], [2], [3]], sycl_queue=queue, target_df=dataframe)
50
47
  y = _convert_to_dataframe([0, 0, 1, 1], sycl_queue=queue, target_df=dataframe)
51
48
  neigh = KNeighborsRegressor(n_neighbors=2).fit(X, y)
@@ -61,7 +58,6 @@ def test_sklearnex_import_knn_regression(dataframe, queue):
61
58
  [LocalOutlierFactor, NearestNeighbors],
62
59
  )
63
60
  def test_sklearnex_kneighbors(estimator, dataframe, queue):
64
-
65
61
  X = [[0, 0, 2], [1, 0, 0], [0, 0, 1]]
66
62
  X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
67
63
  test = _convert_to_dataframe([[0, 0, 1.3]], sycl_queue=queue, target_df=dataframe)
@@ -74,7 +70,6 @@ def test_sklearnex_kneighbors(estimator, dataframe, queue):
74
70
 
75
71
  @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
76
72
  def test_sklearnex_import_lof(dataframe, queue):
77
-
78
73
  X = [[7, 7, 7], [1, 0, 0], [0, 0, 1], [0, 0, 1]]
79
74
  X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
80
75
  lof = LocalOutlierFactor(n_neighbors=2)
@@ -14,4 +14,4 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
- __all__ = ["cluster", "covariance", "decomposition"]
17
+ __all__ = ["cluster", "covariance"]
@@ -14,8 +14,6 @@
14
14
  # limitations under the License.
15
15
  # ==============================================================================
16
16
 
17
- from abc import ABC
18
-
19
17
  from onedal.spmd.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
20
18
  from onedal.spmd.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
21
19
 
@@ -23,16 +21,9 @@ from ...ensemble import RandomForestClassifier as RandomForestClassifier_Batch
23
21
  from ...ensemble import RandomForestRegressor as RandomForestRegressor_Batch
24
22
 
25
23
 
26
- class BaseForestSPMD(ABC):
27
- def _onedal_classifier(self, **onedal_params):
28
- return onedal_RandomForestClassifier(**onedal_params)
29
-
30
- def _onedal_regressor(self, **onedal_params):
31
- return onedal_RandomForestRegressor(**onedal_params)
32
-
33
-
34
- class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch):
24
+ class RandomForestClassifier(RandomForestClassifier_Batch):
35
25
  __doc__ = RandomForestClassifier_Batch.__doc__
26
+ _onedal_factory = onedal_RandomForestClassifier
36
27
 
37
28
  def _onedal_cpu_supported(self, method_name, *data):
38
29
  # TODO:
@@ -55,8 +46,9 @@ class RandomForestClassifier(BaseForestSPMD, RandomForestClassifier_Batch):
55
46
  return ready
56
47
 
57
48
 
58
- class RandomForestRegressor(BaseForestSPMD, RandomForestRegressor_Batch):
49
+ class RandomForestRegressor(RandomForestRegressor_Batch):
59
50
  __doc__ = RandomForestRegressor_Batch.__doc__
51
+ _onedal_factory = onedal_RandomForestRegressor
60
52
 
61
53
  def _onedal_cpu_supported(self, method_name, *data):
62
54
  # TODO:
@@ -198,6 +198,8 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
198
198
  self._check_proba()
199
199
  return self._predict_proba
200
200
 
201
+ predict_proba.__doc__ = sklearn_NuSVC.predict_proba.__doc__
202
+
201
203
  @wrap_output_data
202
204
  def _predict_proba(self, X):
203
205
  if sklearn_check_version("1.0"):
@@ -232,6 +234,8 @@ class NuSVC(sklearn_NuSVC, BaseSVC):
232
234
  X,
233
235
  )
234
236
 
237
+ decision_function.__doc__ = sklearn_NuSVC.decision_function.__doc__
238
+
235
239
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
236
240
  onedal_params = {
237
241
  "nu": self.nu,
@@ -200,6 +200,8 @@ class SVC(sklearn_SVC, BaseSVC):
200
200
  self._check_proba()
201
201
  return self._predict_proba
202
202
 
203
+ predict_proba.__doc__ = sklearn_SVC.predict_proba.__doc__
204
+
203
205
  @wrap_output_data
204
206
  def _predict_proba(self, X):
205
207
  sklearn_pred_proba = (
@@ -232,6 +234,8 @@ class SVC(sklearn_SVC, BaseSVC):
232
234
  X,
233
235
  )
234
236
 
237
+ decision_function.__doc__ = sklearn_SVC.decision_function.__doc__
238
+
235
239
  def _onedal_gpu_supported(self, method_name, *data):
236
240
  class_name = self.__class__.__name__
237
241
  patching_status = PatchingConditionsChain(
@@ -0,0 +1,155 @@
1
+ # ==============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ from inspect import isclass
18
+
19
+ import numpy as np
20
+ from sklearn.base import (
21
+ BaseEstimator,
22
+ ClassifierMixin,
23
+ ClusterMixin,
24
+ OutlierMixin,
25
+ RegressorMixin,
26
+ TransformerMixin,
27
+ )
28
+ from sklearn.datasets import load_diabetes, load_iris
29
+ from sklearn.neighbors._base import KNeighborsMixin
30
+
31
+ from onedal.tests.utils._dataframes_support import _convert_to_dataframe
32
+ from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn
33
+ from sklearnex.neighbors import (
34
+ KNeighborsClassifier,
35
+ KNeighborsRegressor,
36
+ LocalOutlierFactor,
37
+ NearestNeighbors,
38
+ )
39
+ from sklearnex.svm import SVC, NuSVC
40
+
41
+
42
def _load_all_models(with_sklearnex=True, estimator=True):
    """Collect the objects named by the sklearnex patch map.

    The sklearn patch state seen on entry is restored before returning,
    whether or not an exception is raised while collecting.

    Parameters
    ----------
    with_sklearnex : bool, default=True
        If true, ``patch_sklearn()`` is called before the lookup; otherwise
        sklearn is unpatched first when any patch is currently active.
    estimator : bool, default=True
        If true, keep only classes that subclass ``BaseEstimator``; if
        false, keep only objects that are not classes.

    Returns
    -------
    dict
        Maps the attribute name taken from each patch-map entry to the
        object found under that name.
    """
    # Snapshot the patch state up front so it can be reinstated afterwards.
    previous_state = sklearn_is_patched(return_map=True)
    was_patched = any(previous_state.values())
    try:
        if with_sklearnex:
            patch_sklearn()
        elif was_patched:
            unpatch_sklearn()

        models = {}
        for patch_infos in get_patch_map().values():
            # The first element of the first record holds the owning object
            # and the attribute name to look up on it.
            target = patch_infos[0][0]
            owner, attr_name = target[0], target[1]
            candidate = getattr(owner, attr_name, None)
            if candidate is None or isclass(candidate) != estimator:
                continue
            if estimator and not issubclass(candidate, BaseEstimator):
                continue
            models[attr_name] = candidate
    finally:
        if with_sklearnex:
            unpatch_sklearn()
        # Either way sklearn is unpatched at this point; re-apply exactly the
        # patches that were active when the function was entered.
        if was_patched:
            patch_sklearn(name=[key for key, active in previous_state.items() if active])

    return models
67
+
68
+
69
# Patch-map lookups resolved once at import time, with and without sklearnex
# patching applied: BaseEstimator subclasses first, then non-class objects.
PATCHED_MODELS = _load_all_models(with_sklearnex=True, estimator=True)
UNPATCHED_MODELS = _load_all_models(with_sklearnex=False, estimator=True)

PATCHED_FUNCTIONS = _load_all_models(with_sklearnex=True, estimator=False)
UNPATCHED_FUNCTIONS = _load_all_models(with_sklearnex=False, estimator=False)
74
+
75
# Each entry is [mixin class, method names tied to that mixin, dataset kind].
# The dataset kind is read by gen_dataset; None means the mixin does not
# influence which dataset is chosen.
mixin_map = [
    [
        ClassifierMixin,
        [
            "decision_function",
            "predict",
            "predict_proba",
            "predict_log_proba",
            "score",
        ],
        "classification",
    ],
    [
        RegressorMixin,
        ["predict", "score"],
        "regression",
    ],
    [
        ClusterMixin,
        ["fit_predict"],
        "classification",
    ],
    [
        TransformerMixin,
        ["fit_transform", "transform", "score"],
        "classification",
    ],
    [
        OutlierMixin,
        ["fit_predict", "predict"],
        "classification",
    ],
    [
        KNeighborsMixin,
        ["kneighbors"],
        None,
    ],
]
87
+
88
+
89
# Estimator instances built with non-default parameters, keyed by the string
# form of each instance.
SPECIAL_INSTANCES = {
    str(instance): instance
    for instance in (
        LocalOutlierFactor(novelty=True),
        SVC(probability=True),
        NuSVC(probability=True),
        KNeighborsClassifier(algorithm="brute"),
        KNeighborsRegressor(algorithm="brute"),
        NearestNeighbors(algorithm="brute"),
    )
}
100
+
101
+
102
def gen_models_info(algorithms):
    """Build ``[estimator key, method name]`` pairs for the given estimators.

    Parameters
    ----------
    algorithms : iterable of str
        Estimator keys. Anything from the first ``"("`` onwards is dropped
        before the lookup in ``PATCHED_MODELS``, so the string form of an
        instance (as used for the ``SPECIAL_INSTANCES`` keys) is accepted
        as well as a bare class name.

    Returns
    -------
    list of list
        One ``[key, method]`` pair for every public method of the estimator
        that is listed in ``mixin_map`` under a mixin the estimator
        inherits from. Pairs follow the order of ``algorithms``; the methods
        of each estimator are sorted by name.

    Raises
    ------
    KeyError
        If a key does not resolve to an entry of ``PATCHED_MODELS``.
    """
    output = []
    for key in algorithms:
        # split handles SPECIAL_INSTANCES or custom inputs
        # custom sklearn inputs must be a dict of estimators
        # with keys set by the __str__ method
        est = PATCHED_MODELS[key.split("(")[0]]

        # Public attribute names only: no leading or trailing underscore.
        candidates = {
            attr
            for attr in dir(est)
            if not attr.startswith("_") and not attr.endswith("_")
        }

        methods = set()
        for mixin, mixin_methods, _ in mixin_map:
            if issubclass(est, mixin):
                methods |= candidates & set(mixin_methods)

        # Set iteration order for strings varies with hash randomisation, so
        # sort to give every process (e.g. pytest-xdist workers) the same
        # parametrization order.
        output += [[key, method] for method in sorted(methods)]
    return output
121
+
122
+
123
def gen_dataset(estimator, queue=None, target_df=None, dtype=np.float64):
    """Load a dataset suited to ``estimator`` and convert it for testing.

    The dataset kind is taken from ``mixin_map``: the last entry with a
    non-None kind whose mixin the estimator's patched class inherits from
    decides. Iris is used for "classification" and when no entry matches;
    the diabetes dataset is used for "regression".

    Parameters
    ----------
    estimator : object
        Estimator instance; its class name is looked up in ``PATCHED_MODELS``.
    queue : object, default=None
        Forwarded to ``_convert_to_dataframe`` as ``sycl_queue``.
    target_df : object, default=None
        Forwarded to ``_convert_to_dataframe`` as ``target_df``.
    dtype : data-type, default=np.float64
        Forwarded to ``_convert_to_dataframe`` as ``dtype`` for both X and y.

    Returns
    -------
    tuple
        The converted ``(X, y)`` pair.

    Raises
    ------
    ValueError
        If the selected dataset kind is neither "classification" nor
        "regression".
    """
    est_cls = PATCHED_MODELS[estimator.__class__.__name__]

    kind = None
    for mixin, _, data_kind in mixin_map:
        if issubclass(est_cls, mixin) and data_kind is not None:
            kind = data_kind

    # load data
    if kind is None or kind == "classification":
        loader = load_iris
    elif kind == "regression":
        loader = load_diabetes
    else:
        raise ValueError("Unknown dataset type")
    X, y = loader(return_X_y=True)

    conversion = {"sycl_queue": queue, "target_df": target_df, "dtype": dtype}
    X = _convert_to_dataframe(X, **conversion)
    y = _convert_to_dataframe(y, **conversion)
    return X, y
141
+
142
+
143
# NumPy scalar types listed for dtype-related tests, grouped by family.
DTYPES = [
    # signed integers
    np.int8, np.int16, np.int32, np.int64,
    # floating point
    np.float16, np.float32, np.float64,
    # unsigned integers
    np.uint8, np.uint16, np.uint32, np.uint64,
]