scikit-learn-intelex 2024.1.0__py311-none-win_amd64.whl → 2024.3.0__py311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112):
  1. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__init__.py +9 -7
  2. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +6 -4
  3. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/conftest.py +63 -0
  4. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
  5. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +130 -0
  6. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  7. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +338 -0
  8. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +22 -8
  9. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +91 -59
  10. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +15 -24
  11. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +15 -19
  12. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +1 -2
  13. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +3 -10
  14. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex}/linear_model/logistic_regression.py +32 -40
  15. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +91 -0
  16. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +1 -1
  17. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +204 -0
  18. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +13 -18
  19. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +12 -17
  20. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +10 -15
  21. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +12 -16
  22. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  23. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +3 -8
  24. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +46 -12
  25. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +1 -0
  26. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +19 -0
  27. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +21 -0
  28. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  29. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
  30. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  31. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +9 -6
  32. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +6 -7
  33. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +9 -6
  34. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +3 -4
  35. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +155 -0
  36. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +9 -7
  37. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +268 -0
  38. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +93 -0
  39. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  40. scikit_learn_intelex-2024.3.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +361 -0
  41. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/METADATA +2 -2
  42. scikit_learn_intelex-2024.3.0.dist-info/RECORD +98 -0
  43. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  44. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  45. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -28
  46. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -436
  47. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -19
  48. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -376
  49. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +0 -19
  50. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  51. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -170
  52. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -227
  53. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -31
  54. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -122
  55. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -118
  56. scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
  57. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  58. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  59. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  60. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  61. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  62. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  63. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  64. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  65. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  66. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  67. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  69. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  70. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  71. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  72. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  73. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  74. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  75. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -0
  76. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  77. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  78. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  79. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  80. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  81. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  82. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  83. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  84. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  85. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  86. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  87. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  88. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  89. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  91. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  92. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  93. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  94. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  95. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  96. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  97. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  98. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  99. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  100. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  101. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  102. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
  104. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  105. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  106. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  107. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  108. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  109. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.3.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  110. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/LICENSE.txt +0 -0
  111. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/WHEEL +0 -0
  112. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.3.0.dist-info}/top_level.txt +0 -0
@@ -42,11 +42,10 @@ from sklearn.utils.validation import (
42
42
  check_X_y,
43
43
  )
44
44
 
45
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
45
46
  from daal4py.sklearn._utils import (
46
47
  check_tree_nodes,
47
- control_n_jobs,
48
48
  daal_check_version,
49
- run_with_n_jobs,
50
49
  sklearn_check_version,
51
50
  )
52
51
  from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
@@ -78,7 +77,6 @@ if sklearn_check_version("1.4"):
78
77
  class BaseForest(ABC):
79
78
  _onedal_factory = None
80
79
 
81
- @run_with_n_jobs
82
80
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
83
81
  if sklearn_check_version("0.24"):
84
82
  X, y = self._validate_data(
@@ -455,14 +453,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
455
453
 
456
454
  # The estimator is checked against the class attribute for conformance.
457
455
  # This should only trigger if the user uses this class directly.
458
- if (
459
- self.estimator.__class__ == DecisionTreeClassifier
460
- and self._onedal_factory != onedal_RandomForestClassifier
456
+ if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
457
+ self._onedal_factory, onedal_RandomForestClassifier
461
458
  ):
462
459
  self._onedal_factory = onedal_RandomForestClassifier
463
- elif (
464
- self.estimator.__class__ == ExtraTreeClassifier
465
- and self._onedal_factory != onedal_ExtraTreesClassifier
460
+ elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
461
+ self._onedal_factory, onedal_ExtraTreesClassifier
466
462
  ):
467
463
  self._onedal_factory = onedal_ExtraTreesClassifier
468
464
 
@@ -749,7 +745,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
749
745
  or self.estimator.__class__ == DecisionTreeClassifier,
750
746
  "ExtraTrees only supported starting from oneDAL version 2023.1",
751
747
  ),
752
- (sample_weight is not None, "sample_weight is not supported."),
748
+ (sample_weight is None, "sample_weight is not supported."),
753
749
  ]
754
750
  )
755
751
 
@@ -787,7 +783,6 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
787
783
 
788
784
  return patching_status
789
785
 
790
- @run_with_n_jobs
791
786
  def _onedal_predict(self, X, queue=None):
792
787
  X = check_array(
793
788
  X,
@@ -802,7 +797,6 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
802
797
  res = self._onedal_estimator.predict(X, queue=queue)
803
798
  return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
804
799
 
805
- @run_with_n_jobs
806
800
  def _onedal_predict_proba(self, X, queue=None):
807
801
  X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
808
802
  check_is_fitted(self, "_onedal_estimator")
@@ -847,14 +841,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
847
841
 
848
842
  # The splitter is checked against the class attribute for conformance
849
843
  # This should only trigger if the user uses this class directly.
850
- if (
851
- self.estimator.__class__ == DecisionTreeRegressor
852
- and self._onedal_factory != onedal_RandomForestRegressor
844
+ if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
845
+ self._onedal_factory, onedal_RandomForestRegressor
853
846
  ):
854
847
  self._onedal_factory = onedal_RandomForestRegressor
855
- elif (
856
- self.estimator.__class__ == ExtraTreeRegressor
857
- and self._onedal_factory != onedal_ExtraTreesRegressor
848
+ elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
849
+ self._onedal_factory, onedal_ExtraTreesRegressor
858
850
  ):
859
851
  self._onedal_factory = onedal_ExtraTreesRegressor
860
852
 
@@ -1060,7 +1052,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1060
1052
  or self.estimator.__class__ == DecisionTreeClassifier,
1061
1053
  "ExtraTrees only supported starting from oneDAL version 2023.1",
1062
1054
  ),
1063
- (sample_weight is not None, "sample_weight is not supported."),
1055
+ (sample_weight is None, "sample_weight is not supported."),
1064
1056
  ]
1065
1057
  )
1066
1058
 
@@ -1096,7 +1088,6 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1096
1088
 
1097
1089
  return patching_status
1098
1090
 
1099
- @run_with_n_jobs
1100
1091
  def _onedal_predict(self, X, queue=None):
1101
1092
  X = check_array(
1102
1093
  X, dtype=[np.float64, np.float32], force_all_finite=False
@@ -1138,7 +1129,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1138
1129
  predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
1139
1130
 
1140
1131
 
1141
- @control_n_jobs
1132
+ @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
1142
1133
  class RandomForestClassifier(ForestClassifier):
1143
1134
  __doc__ = sklearn_RandomForestClassifier.__doc__
1144
1135
  _onedal_factory = onedal_RandomForestClassifier
@@ -1348,7 +1339,7 @@ class RandomForestClassifier(ForestClassifier):
1348
1339
  self.min_bin_size = min_bin_size
1349
1340
 
1350
1341
 
1351
- @control_n_jobs
1342
+ @control_n_jobs(decorated_methods=["fit", "predict"])
1352
1343
  class RandomForestRegressor(ForestRegressor):
1353
1344
  __doc__ = sklearn_RandomForestRegressor.__doc__
1354
1345
  _onedal_factory = onedal_RandomForestRegressor
@@ -1549,7 +1540,7 @@ class RandomForestRegressor(ForestRegressor):
1549
1540
  self.min_bin_size = min_bin_size
1550
1541
 
1551
1542
 
1552
- @control_n_jobs
1543
+ @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
1553
1544
  class ExtraTreesClassifier(ForestClassifier):
1554
1545
  __doc__ = sklearn_ExtraTreesClassifier.__doc__
1555
1546
  _onedal_factory = onedal_ExtraTreesClassifier
@@ -1759,7 +1750,7 @@ class ExtraTreesClassifier(ForestClassifier):
1759
1750
  self.min_bin_size = min_bin_size
1760
1751
 
1761
1752
 
1762
- @control_n_jobs
1753
+ @control_n_jobs(decorated_methods=["fit", "predict"])
1763
1754
  class ExtraTreesRegressor(ForestRegressor):
1764
1755
  __doc__ = sklearn_ExtraTreesRegressor.__doc__
1765
1756
  _onedal_factory = onedal_ExtraTreesRegressor
@@ -45,11 +45,7 @@ def test_sklearnex_import_rf_classifier(dataframe, queue):
45
45
  assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
46
46
 
47
47
 
48
- # TODO:
49
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
50
- @pytest.mark.parametrize(
51
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
52
- )
48
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
53
49
  def test_sklearnex_import_rf_regression(dataframe, queue):
54
50
  from sklearnex.ensemble import RandomForestRegressor
55
51
 
@@ -59,17 +55,17 @@ def test_sklearnex_import_rf_regression(dataframe, queue):
59
55
  rf = RandomForestRegressor(max_depth=2, random_state=0).fit(X, y)
60
56
  assert "sklearnex" in rf.__module__
61
57
  pred = _as_numpy(rf.predict([[0, 0, 0, 0]]))
62
- if daal_check_version((2024, "P", 0)):
63
- assert_allclose([-6.971], pred, atol=1e-2)
58
+
59
+ if queue is not None and queue.sycl_device.is_gpu:
60
+ assert_allclose([-0.011208], pred, atol=1e-2)
64
61
  else:
65
- assert_allclose([-6.839], pred, atol=1e-2)
62
+ if daal_check_version((2024, "P", 0)):
63
+ assert_allclose([-6.971], pred, atol=1e-2)
64
+ else:
65
+ assert_allclose([-6.839], pred, atol=1e-2)
66
66
 
67
67
 
68
- # TODO:
69
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
70
- @pytest.mark.parametrize(
71
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
72
- )
68
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
73
69
  def test_sklearnex_import_et_classifier(dataframe, queue):
74
70
  from sklearnex.ensemble import ExtraTreesClassifier
75
71
 
@@ -90,11 +86,7 @@ def test_sklearnex_import_et_classifier(dataframe, queue):
90
86
  assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
91
87
 
92
88
 
93
- # TODO:
94
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
95
- @pytest.mark.parametrize(
96
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
97
- )
89
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
98
90
  def test_sklearnex_import_et_regression(dataframe, queue):
99
91
  from sklearnex.ensemble import ExtraTreesRegressor
100
92
 
@@ -114,4 +106,8 @@ def test_sklearnex_import_et_regression(dataframe, queue):
114
106
  ]
115
107
  )
116
108
  )
117
- assert_allclose([0.445], pred, atol=1e-2)
109
+
110
+ if queue is not None and queue.sycl_device.is_gpu:
111
+ assert_allclose([1.909769], pred, atol=1e-2)
112
+ else:
113
+ assert_allclose([0.445], pred, atol=1e-2)
@@ -16,14 +16,13 @@
16
16
 
17
17
  from .coordinate_descent import ElasticNet, Lasso
18
18
  from .linear import LinearRegression
19
- from .logistic_path import LogisticRegression, logistic_regression_path
19
+ from .logistic_regression import LogisticRegression
20
20
  from .ridge import Ridge
21
21
 
22
22
  __all__ = [
23
23
  "Ridge",
24
24
  "LinearRegression",
25
25
  "LogisticRegression",
26
- "logistic_regression_path",
27
26
  "ElasticNet",
28
27
  "Lasso",
29
28
  ]
@@ -65,13 +65,8 @@ if daal_check_version((2023, "P", 100)):
65
65
  import numpy as np
66
66
  from sklearn.linear_model import LinearRegression as sklearn_LinearRegression
67
67
 
68
- from daal4py.sklearn._utils import (
69
- control_n_jobs,
70
- get_dtype,
71
- make2d,
72
- run_with_n_jobs,
73
- sklearn_check_version,
74
- )
68
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
69
+ from daal4py.sklearn._utils import get_dtype, make2d, sklearn_check_version
75
70
 
76
71
  from .._device_offload import dispatch, wrap_output_data
77
72
  from .._utils import (
@@ -93,7 +88,7 @@ if daal_check_version((2023, "P", 100)):
93
88
  from onedal.utils import _num_features, _num_samples
94
89
 
95
90
  @register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
96
- @control_n_jobs
91
+ @control_n_jobs(decorated_methods=["fit", "predict"])
97
92
  class LinearRegression(sklearn_LinearRegression, BaseLinearRegression):
98
93
  __doc__ = sklearn_LinearRegression.__doc__
99
94
  intercept_, coef_ = None, None
@@ -330,7 +325,6 @@ if daal_check_version((2023, "P", 100)):
330
325
  onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
331
326
  self._onedal_estimator = onedal_LinearRegression(**onedal_params)
332
327
 
333
- @run_with_n_jobs
334
328
  def _onedal_fit(self, X, y, sample_weight, queue=None):
335
329
  assert sample_weight is None
336
330
 
@@ -369,7 +363,6 @@ if daal_check_version((2023, "P", 100)):
369
363
  del self._onedal_estimator
370
364
  super().fit(X, y)
371
365
 
372
- @run_with_n_jobs
373
366
  def _onedal_predict(self, X, queue=None):
374
367
  X = self._validate_data(X, accept_sparse=False, reset=False)
375
368
  if not hasattr(self, "_onedal_estimator"):
@@ -1,5 +1,5 @@
1
1
  # ===============================================================================
2
- # Copyright 2023 Intel Corporation
2
+ # Copyright 2024 Intel Corporation
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -17,14 +17,11 @@
17
17
  import logging
18
18
  from abc import ABC
19
19
 
20
- import sklearn.linear_model._logistic as logistic_module
21
-
22
20
  from daal4py.sklearn._utils import daal_check_version
23
21
  from daal4py.sklearn.linear_model.logistic_path import (
24
- LogisticRegression,
25
- daal4py_predict,
26
- logistic_regression_path,
22
+ LogisticRegression as LogisticRegression_daal4py,
27
23
  )
24
+ from daal4py.sklearn.linear_model.logistic_path import daal4py_fit, daal4py_predict
28
25
 
29
26
 
30
27
  class BaseLogisticRegression(ABC):
@@ -43,14 +40,18 @@ if daal_check_version((2024, "P", 1)):
43
40
  from sklearn.linear_model import LogisticRegression as sklearn_LogisticRegression
44
41
  from sklearn.utils.validation import check_X_y
45
42
 
43
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
46
44
  from daal4py.sklearn._utils import sklearn_check_version
47
45
  from onedal.linear_model import LogisticRegression as onedal_LogisticRegression
48
46
  from onedal.utils import _num_features, _num_samples
49
47
 
50
- from ..._device_offload import dispatch, wrap_output_data
51
- from ..._utils import PatchingConditionsChain, get_patch_message
52
- from ...utils.validation import _assert_all_finite
48
+ from .._device_offload import dispatch, wrap_output_data
49
+ from .._utils import PatchingConditionsChain, get_patch_message
50
+ from ..utils.validation import _assert_all_finite
53
51
 
52
+ @control_n_jobs(
53
+ decorated_methods=["fit", "predict", "predict_proba", "predict_log_proba"]
54
+ )
54
55
  class LogisticRegression(sklearn_LogisticRegression, BaseLogisticRegression):
55
56
  __doc__ = sklearn_LogisticRegression.__doc__
56
57
  intercept_, coef_, n_iter_ = None, None, None
@@ -97,6 +98,8 @@ if daal_check_version((2024, "P", 1)):
97
98
  l1_ratio=l1_ratio,
98
99
  )
99
100
 
101
+ _onedal_cpu_fit = daal4py_fit
102
+
100
103
  def fit(self, X, y, sample_weight=None):
101
104
  if sklearn_check_version("1.0"):
102
105
  self._check_feature_names(X, reset=True)
@@ -160,10 +163,8 @@ if daal_check_version((2024, "P", 1)):
160
163
  def _test_type_and_finiteness(self, X_in):
161
164
  X = np.asarray(X_in)
162
165
 
163
- dtype = X.dtype
164
- if "complex" in str(type(dtype)):
166
+ if np.iscomplexobj(X):
165
167
  return False
166
-
167
168
  try:
168
169
  _assert_all_finite(X)
169
170
  except BaseException:
@@ -184,7 +185,10 @@ if daal_check_version((2024, "P", 1)):
184
185
  [
185
186
  (self.penalty == "l2", "Only l2 penalty is supported."),
186
187
  (self.dual == False, "dual=True is not supported."),
187
- (self.intercept_scaling == 1, "Intercept scaling is not supported."),
188
+ (
189
+ self.intercept_scaling == 1,
190
+ "Intercept scaling is not supported.",
191
+ ),
188
192
  (self.class_weight is None, "Class weight is not supported"),
189
193
  (self.solver == "newton-cg", "Only newton-cg solver is supported."),
190
194
  (
@@ -229,7 +233,10 @@ if daal_check_version((2024, "P", 1)):
229
233
  (n_samples > 0, "Number of samples is less than 1."),
230
234
  (not issparse(*data), "Sparse input is not supported."),
231
235
  (not model_is_sparse, "Sparse coefficients are not supported."),
232
- (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
236
+ (
237
+ hasattr(self, "_onedal_estimator"),
238
+ "oneDAL model was not trained.",
239
+ ),
233
240
  ]
234
241
  )
235
242
  if not dal_ready:
@@ -268,15 +275,6 @@ if daal_check_version((2024, "P", 1)):
268
275
  }
269
276
  self._onedal_estimator = onedal_LogisticRegression(**onedal_params)
270
277
 
271
- def _onedal_cpu_fit(self, X, y, sample_weight):
272
- which, what = logistic_module, "_logistic_regression_path"
273
- replacer = logistic_regression_path
274
- descriptor = getattr(which, what, None)
275
- setattr(which, what, replacer)
276
- clf = super().fit(X, y, sample_weight)
277
- setattr(which, what, descriptor)
278
- return clf
279
-
280
278
  def _onedal_fit(self, X, y, sample_weight, queue=None):
281
279
  if queue is None or queue.sycl_device.is_cpu:
282
280
  return self._onedal_cpu_fit(X, y, sample_weight)
@@ -313,38 +311,32 @@ if daal_check_version((2024, "P", 1)):
313
311
  return daal4py_predict(self, X, "computeClassLabels")
314
312
 
315
313
  X = self._validate_data(X, accept_sparse=False, reset=False)
316
- if not hasattr(self, "_onedal_estimator"):
317
- self._initialize_onedal_estimator()
318
- self._onedal_estimator.coef_ = self.coef_
319
- self._onedal_estimator.intercept_ = self.intercept_
320
- self._onedal_estimator.classes_ = self.classes_
321
-
314
+ assert hasattr(self, "_onedal_estimator")
322
315
  return self._onedal_estimator.predict(X, queue=queue)
323
316
 
324
317
  def _onedal_predict_proba(self, X, queue=None):
325
318
  if queue is None or queue.sycl_device.is_cpu:
326
319
  return daal4py_predict(self, X, "computeClassProbabilities")
327
- X = self._validate_data(X, accept_sparse=False, reset=False)
328
- if not hasattr(self, "_onedal_estimator"):
329
- self._initialize_onedal_estimator()
330
- self._onedal_estimator.coef_ = self.coef_
331
- self._onedal_estimator.intercept_ = self.intercept_
332
320
 
321
+ X = self._validate_data(X, accept_sparse=False, reset=False)
322
+ assert hasattr(self, "_onedal_estimator")
333
323
  return self._onedal_estimator.predict_proba(X, queue=queue)
334
324
 
335
325
  def _onedal_predict_log_proba(self, X, queue=None):
336
326
  if queue is None or queue.sycl_device.is_cpu:
337
327
  return daal4py_predict(self, X, "computeClassLogProbabilities")
338
- X = self._validate_data(X, accept_sparse=False, reset=False)
339
- if not hasattr(self, "_onedal_estimator"):
340
- self._initialize_onedal_estimator()
341
- self._onedal_estimator.coef_ = self.coef_
342
- self._onedal_estimator.intercept_ = self.intercept_
343
328
 
329
+ X = self._validate_data(X, accept_sparse=False, reset=False)
330
+ assert hasattr(self, "_onedal_estimator")
344
331
  return self._onedal_estimator.predict_log_proba(X, queue=queue)
345
332
 
333
+ fit.__doc__ = sklearn_LogisticRegression.fit.__doc__
334
+ predict.__doc__ = sklearn_LogisticRegression.predict.__doc__
335
+ predict_proba.__doc__ = sklearn_LogisticRegression.predict_proba.__doc__
336
+ predict_log_proba.__doc__ = sklearn_LogisticRegression.predict_log_proba.__doc__
337
+
346
338
  else:
347
- from daal4py.sklearn.linear_model import LogisticRegression
339
+ LogisticRegression = LogisticRegression_daal4py
348
340
 
349
341
  logging.warning(
350
342
  "Sklearnex LogisticRegression requires oneDAL version >= 2024.0.1 "
@@ -0,0 +1,91 @@
1
+ # ===============================================================================
2
+ # Copyright 2021 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import pytest
18
+ from sklearn.datasets import load_breast_cancer, load_iris
19
+ from sklearn.metrics import accuracy_score
20
+ from sklearn.model_selection import train_test_split
21
+
22
+ from daal4py.sklearn._utils import daal_check_version
23
+ from onedal.tests.utils._dataframes_support import (
24
+ _as_numpy,
25
+ _convert_to_dataframe,
26
+ get_dataframes_and_queues,
27
+ )
28
+
29
+
30
+ def prepare_input(X, y, dataframe, queue):
31
+ X_train, X_test, y_train, y_test = train_test_split(
32
+ X, y, train_size=0.8, random_state=42
33
+ )
34
+ X_train = _convert_to_dataframe(X_train, sycl_queue=queue, target_df=dataframe)
35
+ y_train = _convert_to_dataframe(y_train, sycl_queue=queue, target_df=dataframe)
36
+ X_test = _convert_to_dataframe(X_test, sycl_queue=queue, target_df=dataframe)
37
+ return X_train, X_test, y_train, y_test
38
+
39
+
40
+ @pytest.mark.parametrize(
41
+ "dataframe,queue",
42
+ get_dataframes_and_queues(device_filter_="cpu"),
43
+ )
44
+ def test_sklearnex_multiclass_classification(dataframe, queue):
45
+ from sklearnex.linear_model import LogisticRegression
46
+
47
+ X, y = load_iris(return_X_y=True)
48
+ X_train, X_test, y_train, y_test = prepare_input(X, y, dataframe, queue)
49
+
50
+ logreg = LogisticRegression(fit_intercept=True, solver="lbfgs", max_iter=200).fit(
51
+ X_train, y_train
52
+ )
53
+
54
+ if daal_check_version((2024, "P", 1)):
55
+ assert "sklearnex" in logreg.__module__
56
+ else:
57
+ assert "daal4py" in logreg.__module__
58
+
59
+ y_pred = _as_numpy(logreg.predict(X_test))
60
+ assert accuracy_score(y_test, y_pred) > 0.99
61
+
62
+
63
+ @pytest.mark.parametrize(
64
+ "dataframe,queue",
65
+ get_dataframes_and_queues(),
66
+ )
67
+ def test_sklearnex_binary_classification(dataframe, queue):
68
+ from sklearnex.linear_model import LogisticRegression
69
+
70
+ X, y = load_breast_cancer(return_X_y=True)
71
+ X_train, X_test, y_train, y_test = prepare_input(X, y, dataframe, queue)
72
+
73
+ logreg = LogisticRegression(fit_intercept=True, solver="newton-cg", max_iter=100).fit(
74
+ X_train, y_train
75
+ )
76
+
77
+ if daal_check_version((2024, "P", 1)):
78
+ assert "sklearnex" in logreg.__module__
79
+ else:
80
+ assert "daal4py" in logreg.__module__
81
+ if (
82
+ dataframe != "numpy"
83
+ and queue is not None
84
+ and queue.sycl_device.is_gpu
85
+ and daal_check_version((2024, "P", 1))
86
+ ):
87
+ # fit was done on gpu
88
+ assert hasattr(logreg, "_onedal_estimator")
89
+
90
+ y_pred = _as_numpy(logreg.predict(X_test))
91
+ assert accuracy_score(y_test, y_pred) > 0.95
@@ -14,10 +14,10 @@
14
14
  # limitations under the License.
15
15
  # ===============================================================================
16
16
 
17
+ from ._lof import LocalOutlierFactor
17
18
  from .knn_classification import KNeighborsClassifier
18
19
  from .knn_regression import KNeighborsRegressor
19
20
  from .knn_unsupervised import NearestNeighbors
20
- from .lof import LocalOutlierFactor
21
21
 
22
22
  __all__ = [
23
23
  "KNeighborsClassifier",