scikit-learn-intelex: scikit_learn_intelex-2024.1.0-py312-none-win_amd64.whl → scikit_learn_intelex-2024.2.0-py312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. (The original page linked to further details; that link was not preserved in this copy.)

Files changed (107):
  1. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -3
  2. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/covariance/__init__.py +19 -0
  3. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +130 -0
  4. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +143 -0
  5. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +19 -18
  6. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +5 -10
  7. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +1 -2
  8. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/linear_model/linear.py +3 -10
  9. {scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex}/linear_model/logistic_regression.py +19 -38
  10. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +93 -0
  11. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +1 -1
  12. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/neighbors/_lof.py +167 -0
  13. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +6 -9
  14. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +6 -8
  15. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +5 -7
  16. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +12 -11
  17. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  18. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +3 -8
  19. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +46 -12
  20. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +3 -5
  21. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +1 -0
  22. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +19 -0
  23. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +21 -0
  24. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +2 -1
  25. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +21 -0
  26. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +5 -6
  27. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +3 -4
  28. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +5 -6
  29. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +3 -4
  30. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +1 -4
  31. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +33 -20
  32. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +93 -0
  33. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +19 -5
  34. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.2.0.dist-info}/METADATA +2 -2
  35. scikit_learn_intelex-2024.2.0.dist-info/RECORD +101 -0
  36. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -28
  37. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/neighbors/lof.py +0 -436
  38. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/__init__.py +0 -19
  39. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_logistic_regression.py +0 -59
  40. scikit_learn_intelex-2024.1.0.data/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +0 -31
  41. scikit_learn_intelex-2024.1.0.dist-info/RECORD +0 -97
  42. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/__init__.py +0 -0
  43. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  44. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  45. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +0 -0
  46. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  47. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +0 -0
  48. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  49. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  50. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  51. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  52. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  53. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  54. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -0
  55. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -0
  56. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  57. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  58. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -0
  59. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  60. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  61. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  62. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  63. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  64. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +0 -0
  65. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  66. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  67. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  68. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  69. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  70. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  71. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  72. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  73. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  74. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  75. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +0 -0
  76. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  77. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  78. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  79. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +0 -0
  80. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -0
  81. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -0
  82. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +0 -0
  83. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  84. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  86. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  87. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  88. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  89. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +0 -0
  91. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  92. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  93. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  94. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +0 -0
  96. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  97. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -0
  98. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  99. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +0 -0
  100. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  101. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -0
  102. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  104. {scikit_learn_intelex-2024.1.0.data → scikit_learn_intelex-2024.2.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  105. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.2.0.dist-info}/LICENSE.txt +0 -0
  106. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.2.0.dist-info}/WHEEL +0 -0
  107. {scikit_learn_intelex-2024.1.0.dist-info → scikit_learn_intelex-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,8 @@ from scipy import sparse as sp
22
22
  from sklearn.cluster import DBSCAN as sklearn_DBSCAN
23
23
  from sklearn.utils.validation import _check_sample_weight
24
24
 
25
- from daal4py.sklearn._utils import control_n_jobs, run_with_n_jobs, sklearn_check_version
25
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
26
+ from daal4py.sklearn._utils import sklearn_check_version
26
27
  from onedal.cluster import DBSCAN as onedal_DBSCAN
27
28
 
28
29
  from .._device_offload import dispatch, wrap_output_data
@@ -45,7 +46,7 @@ class BaseDBSCAN(ABC):
45
46
  self.n_features_in_ = self._onedal_estimator.n_features_in_
46
47
 
47
48
 
48
- @control_n_jobs
49
+ @control_n_jobs(decorated_methods=["fit"])
49
50
  class DBSCAN(sklearn_DBSCAN, BaseDBSCAN):
50
51
  __doc__ = sklearn_DBSCAN.__doc__
51
52
 
@@ -83,7 +84,6 @@ class DBSCAN(sklearn_DBSCAN, BaseDBSCAN):
83
84
  self.p = p
84
85
  self.n_jobs = n_jobs
85
86
 
86
- @run_with_n_jobs
87
87
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
88
88
  onedal_params = {
89
89
  "eps": self.eps,
@@ -0,0 +1,19 @@
1
+ # ===============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ from .incremental_covariance import IncrementalEmpiricalCovariance
18
+
19
+ __all__ = ["IncrementalEmpiricalCovariance"]
@@ -0,0 +1,130 @@
1
+ # ===============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ from sklearn.utils import check_array, gen_batches
19
+
20
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
21
+ from onedal._device_offload import support_usm_ndarray
22
+ from onedal.covariance import (
23
+ IncrementalEmpiricalCovariance as onedal_IncrementalEmpiricalCovariance,
24
+ )
25
+
26
+
27
+ @control_n_jobs(decorated_methods=["partial_fit"])
28
+ class IncrementalEmpiricalCovariance:
29
+ """
30
+ Incremental estimator for covariance.
31
+ Allows to compute empirical covariance estimated by maximum
32
+ likelihood method if data are splitted into batches.
33
+
34
+ Parameters
35
+ ----------
36
+ batch_size : int, default=None
37
+ The number of samples to use for each batch. Only used when calling
38
+ ``fit``. If ``batch_size`` is ``None``, then ``batch_size``
39
+ is inferred from the data and set to ``5 * n_features``, to provide a
40
+ balance between approximation accuracy and memory consumption.
41
+
42
+ Attributes
43
+ ----------
44
+ location_ : ndarray of shape (n_features,)
45
+ Estimated location, i.e. the estimated mean.
46
+
47
+ covariance_ : ndarray of shape (n_features, n_features)
48
+ Estimated covariance matrix
49
+ """
50
+
51
+ _onedal_incremental_covariance = staticmethod(onedal_IncrementalEmpiricalCovariance)
52
+
53
+ def __init__(self, batch_size=None):
54
+ self._need_to_finalize = False # If True then finalize compute should
55
+ # be called to obtain covariance_ or location_ from partial compute data
56
+ self.batch_size = batch_size
57
+
58
+ def _onedal_finalize_fit(self):
59
+ assert hasattr(self, "_onedal_estimator")
60
+ self._onedal_estimator.finalize_fit()
61
+ self._need_to_finalize = False
62
+
63
+ def _onedal_partial_fit(self, X, queue):
64
+ onedal_params = {
65
+ "method": "dense",
66
+ "bias": True,
67
+ }
68
+ if not hasattr(self, "_onedal_estimator"):
69
+ self._onedal_estimator = self._onedal_incremental_covariance(**onedal_params)
70
+ self._onedal_estimator.partial_fit(X, queue)
71
+ self._need_to_finalize = True
72
+
73
+ @property
74
+ def covariance_(self):
75
+ if self._need_to_finalize:
76
+ self._onedal_finalize_fit()
77
+ return self._onedal_estimator.covariance_
78
+
79
+ @property
80
+ def location_(self):
81
+ if self._need_to_finalize:
82
+ self._onedal_finalize_fit()
83
+ return self._onedal_estimator.location_
84
+
85
+ @support_usm_ndarray()
86
+ def partial_fit(self, X, queue=None):
87
+ """
88
+ Incremental fit with X. All of X is processed as a single batch.
89
+
90
+ Parameters
91
+ ----------
92
+ X : array-like of shape (n_samples, n_features)
93
+ Training data, where `n_samples` is the number of samples and
94
+ `n_features` is the number of features.
95
+
96
+ Returns
97
+ -------
98
+ self : object
99
+ Returns the instance itself.
100
+ """
101
+ X = check_array(X, dtype=[np.float64, np.float32])
102
+ self._onedal_partial_fit(X, queue)
103
+ return self
104
+
105
+ def fit(self, X, queue=None):
106
+ """
107
+ Fit the model with X, using minibatches of size batch_size.
108
+
109
+ Parameters
110
+ ----------
111
+ X : array-like of shape (n_samples, n_features)
112
+ Training data, where `n_samples` is the number of samples and
113
+ `n_features` is the number of features.
114
+
115
+ Returns
116
+ -------
117
+ self : object
118
+ Returns the instance itself.
119
+ """
120
+ n_samples, n_features = X.shape
121
+ if self.batch_size is None:
122
+ batch_size_ = 5 * n_features
123
+ else:
124
+ batch_size_ = self.batch_size
125
+ for batch in gen_batches(n_samples, batch_size_):
126
+ X_batch = X[batch]
127
+ self.partial_fit(X_batch, queue=queue)
128
+
129
+ self._onedal_finalize_fit()
130
+ return self
@@ -0,0 +1,143 @@
1
+ # ===============================================================================
2
+ # Copyright 2024 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import numpy as np
18
+ import pytest
19
+ from numpy.testing import assert_allclose
20
+
21
+ from onedal.tests.utils._dataframes_support import (
22
+ _convert_to_dataframe,
23
+ get_dataframes_and_queues,
24
+ )
25
+
26
+
27
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
28
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
29
+ def test_sklearnex_partial_fit_on_gold_data(dataframe, queue, dtype):
30
+ from sklearnex.covariance import IncrementalEmpiricalCovariance
31
+
32
+ X = np.array([[0, 1], [0, 1]])
33
+ X = X.astype(dtype)
34
+ X_split = np.array_split(X, 2)
35
+ inccov = IncrementalEmpiricalCovariance()
36
+
37
+ for i in range(2):
38
+ X_split_df = _convert_to_dataframe(
39
+ X_split[i], sycl_queue=queue, target_df=dataframe
40
+ )
41
+ result = inccov.partial_fit(X_split_df)
42
+
43
+ expected_covariance = np.array([[0, 0], [0, 0]])
44
+ expected_means = np.array([0, 1])
45
+
46
+ assert_allclose(expected_covariance, result.covariance_)
47
+ assert_allclose(expected_means, result.location_)
48
+
49
+ X = np.array([[1, 2], [3, 6]])
50
+ X = X.astype(dtype)
51
+ X_split = np.array_split(X, 2)
52
+ inccov = IncrementalEmpiricalCovariance()
53
+
54
+ for i in range(2):
55
+ X_split_df = _convert_to_dataframe(
56
+ X_split[i], sycl_queue=queue, target_df=dataframe
57
+ )
58
+ result = inccov.partial_fit(X_split_df)
59
+
60
+ expected_covariance = np.array([[1, 2], [2, 4]])
61
+ expected_means = np.array([2, 4])
62
+
63
+ assert_allclose(expected_covariance, result.covariance_)
64
+ assert_allclose(expected_means, result.location_)
65
+
66
+
67
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
68
+ @pytest.mark.parametrize("batch_size", [2, 4])
69
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
70
+ def test_sklearnex_fit_on_gold_data(dataframe, queue, batch_size, dtype):
71
+ from sklearnex.covariance import IncrementalEmpiricalCovariance
72
+
73
+ X = np.array([[0, 1, 2, 3], [0, -1, -2, -3], [0, 1, 2, 3], [0, 1, 2, 3]])
74
+ X = X.astype(dtype)
75
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
76
+ inccov = IncrementalEmpiricalCovariance(batch_size=batch_size)
77
+
78
+ result = inccov.fit(X_df)
79
+
80
+ expected_covariance = np.array(
81
+ [[0, 0, 0, 0], [0, 0.75, 1.5, 2.25], [0, 1.5, 3, 4.5], [0, 2.25, 4.5, 6.75]]
82
+ )
83
+ expected_means = np.array([0, 0.5, 1, 1.5])
84
+
85
+ assert_allclose(expected_covariance, result.covariance_)
86
+ assert_allclose(expected_means, result.location_)
87
+
88
+
89
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
90
+ @pytest.mark.parametrize("num_batches", [2, 4, 6, 8, 10])
91
+ @pytest.mark.parametrize("row_count", [100, 1000, 2000])
92
+ @pytest.mark.parametrize("column_count", [10, 100, 200])
93
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
94
+ def test_sklearnex_partial_fit_on_random_data(
95
+ dataframe, queue, num_batches, row_count, column_count, dtype
96
+ ):
97
+ from sklearnex.covariance import IncrementalEmpiricalCovariance
98
+
99
+ seed = 77
100
+ gen = np.random.default_rng(seed)
101
+ X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
102
+ X = X.astype(dtype)
103
+ X_split = np.array_split(X, num_batches)
104
+ inccov = IncrementalEmpiricalCovariance()
105
+
106
+ for i in range(num_batches):
107
+ X_split_df = _convert_to_dataframe(
108
+ X_split[i], sycl_queue=queue, target_df=dataframe
109
+ )
110
+ result = inccov.partial_fit(X_split_df)
111
+
112
+ expected_covariance = np.cov(X.T, bias=1)
113
+ expected_means = np.mean(X, axis=0)
114
+
115
+ assert_allclose(expected_covariance, result.covariance_, atol=1e-6)
116
+ assert_allclose(expected_means, result.location_, atol=1e-6)
117
+
118
+
119
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
120
+ @pytest.mark.parametrize("num_batches", [2, 4, 6, 8, 10])
121
+ @pytest.mark.parametrize("row_count", [100, 1000, 2000])
122
+ @pytest.mark.parametrize("column_count", [10, 100, 200])
123
+ @pytest.mark.parametrize("dtype", [np.float32, np.float64])
124
+ def test_sklearnex_fit_on_random_data(
125
+ dataframe, queue, num_batches, row_count, column_count, dtype
126
+ ):
127
+ from sklearnex.covariance import IncrementalEmpiricalCovariance
128
+
129
+ seed = 77
130
+ gen = np.random.default_rng(seed)
131
+ X = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
132
+ X = X.astype(dtype)
133
+ X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
134
+ batch_size = row_count // num_batches
135
+ inccov = IncrementalEmpiricalCovariance(batch_size=batch_size)
136
+
137
+ result = inccov.fit(X_df)
138
+
139
+ expected_covariance = np.cov(X.T, bias=1)
140
+ expected_means = np.mean(X, axis=0)
141
+
142
+ assert_allclose(expected_covariance, result.covariance_, atol=1e-6)
143
+ assert_allclose(expected_means, result.location_, atol=1e-6)
@@ -69,6 +69,7 @@ def get_patch_map():
69
69
  from .ensemble import RandomForestClassifier as RandomForestClassifier_sklearnex
70
70
  from .ensemble import RandomForestRegressor as RandomForestRegressor_sklearnex
71
71
  from .linear_model import LinearRegression as LinearRegression_sklearnex
72
+ from .linear_model import LogisticRegression as LogisticRegression_sklearnex
72
73
  from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex
73
74
  from .neighbors import KNeighborsRegressor as KNeighborsRegressor_sklearnex
74
75
  from .neighbors import LocalOutlierFactor as LocalOutlierFactor_sklearnex
@@ -80,9 +81,6 @@ def get_patch_map():
80
81
  EmpiricalCovariance as EmpiricalCovariance_sklearnex,
81
82
  )
82
83
  from .preview.decomposition import PCA as PCA_sklearnex
83
- from .preview.linear_model import (
84
- LogisticRegression as LogisticRegression_sklearnex,
85
- )
86
84
  from .svm import SVC as SVC_sklearnex
87
85
  from .svm import SVR as SVR_sklearnex
88
86
  from .svm import NuSVC as NuSVC_sklearnex
@@ -119,21 +117,6 @@ def get_patch_map():
119
117
  ]
120
118
  ]
121
119
 
122
- # LogisticRegression
123
- mapping.pop("logisticregression")
124
- mapping.pop("log_reg")
125
- mapping["log_reg"] = [
126
- [
127
- (
128
- linear_model_module,
129
- "LogisticRegression",
130
- LogisticRegression_sklearnex,
131
- ),
132
- None,
133
- ]
134
- ]
135
- mapping["logisticregression"] = mapping["log_reg"]
136
-
137
120
  # DBSCAN
138
121
  mapping.pop("dbscan")
139
122
  mapping["dbscan"] = [[(cluster_module, "DBSCAN", DBSCAN_sklearnex), None]]
@@ -161,6 +144,24 @@ def get_patch_map():
161
144
  ]
162
145
  mapping["linearregression"] = mapping["linear"]
163
146
 
147
+ # Logistic Regression
148
+
149
+ mapping.pop("logisticregression")
150
+ mapping.pop("log_reg")
151
+ mapping.pop("logistic")
152
+ mapping.pop("_logistic_regression_path")
153
+ mapping["log_reg"] = [
154
+ [
155
+ (
156
+ linear_model_module,
157
+ "LogisticRegression",
158
+ LogisticRegression_sklearnex,
159
+ ),
160
+ None,
161
+ ]
162
+ ]
163
+ mapping["logisticregression"] = mapping["log_reg"]
164
+
164
165
  # kNN
165
166
  mapping.pop("knn_classifier")
166
167
  mapping.pop("kneighborsclassifier")
@@ -42,11 +42,10 @@ from sklearn.utils.validation import (
42
42
  check_X_y,
43
43
  )
44
44
 
45
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
45
46
  from daal4py.sklearn._utils import (
46
47
  check_tree_nodes,
47
- control_n_jobs,
48
48
  daal_check_version,
49
- run_with_n_jobs,
50
49
  sklearn_check_version,
51
50
  )
52
51
  from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
@@ -78,7 +77,6 @@ if sklearn_check_version("1.4"):
78
77
  class BaseForest(ABC):
79
78
  _onedal_factory = None
80
79
 
81
- @run_with_n_jobs
82
80
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
83
81
  if sklearn_check_version("0.24"):
84
82
  X, y = self._validate_data(
@@ -787,7 +785,6 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
787
785
 
788
786
  return patching_status
789
787
 
790
- @run_with_n_jobs
791
788
  def _onedal_predict(self, X, queue=None):
792
789
  X = check_array(
793
790
  X,
@@ -802,7 +799,6 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
802
799
  res = self._onedal_estimator.predict(X, queue=queue)
803
800
  return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
804
801
 
805
- @run_with_n_jobs
806
802
  def _onedal_predict_proba(self, X, queue=None):
807
803
  X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
808
804
  check_is_fitted(self, "_onedal_estimator")
@@ -1096,7 +1092,6 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1096
1092
 
1097
1093
  return patching_status
1098
1094
 
1099
- @run_with_n_jobs
1100
1095
  def _onedal_predict(self, X, queue=None):
1101
1096
  X = check_array(
1102
1097
  X, dtype=[np.float64, np.float32], force_all_finite=False
@@ -1138,7 +1133,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1138
1133
  predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
1139
1134
 
1140
1135
 
1141
- @control_n_jobs
1136
+ @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
1142
1137
  class RandomForestClassifier(ForestClassifier):
1143
1138
  __doc__ = sklearn_RandomForestClassifier.__doc__
1144
1139
  _onedal_factory = onedal_RandomForestClassifier
@@ -1348,7 +1343,7 @@ class RandomForestClassifier(ForestClassifier):
1348
1343
  self.min_bin_size = min_bin_size
1349
1344
 
1350
1345
 
1351
- @control_n_jobs
1346
+ @control_n_jobs(decorated_methods=["fit", "predict"])
1352
1347
  class RandomForestRegressor(ForestRegressor):
1353
1348
  __doc__ = sklearn_RandomForestRegressor.__doc__
1354
1349
  _onedal_factory = onedal_RandomForestRegressor
@@ -1549,7 +1544,7 @@ class RandomForestRegressor(ForestRegressor):
1549
1544
  self.min_bin_size = min_bin_size
1550
1545
 
1551
1546
 
1552
- @control_n_jobs
1547
+ @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
1553
1548
  class ExtraTreesClassifier(ForestClassifier):
1554
1549
  __doc__ = sklearn_ExtraTreesClassifier.__doc__
1555
1550
  _onedal_factory = onedal_ExtraTreesClassifier
@@ -1759,7 +1754,7 @@ class ExtraTreesClassifier(ForestClassifier):
1759
1754
  self.min_bin_size = min_bin_size
1760
1755
 
1761
1756
 
1762
- @control_n_jobs
1757
+ @control_n_jobs(decorated_methods=["fit", "predict"])
1763
1758
  class ExtraTreesRegressor(ForestRegressor):
1764
1759
  __doc__ = sklearn_ExtraTreesRegressor.__doc__
1765
1760
  _onedal_factory = onedal_ExtraTreesRegressor
@@ -16,14 +16,13 @@
16
16
 
17
17
  from .coordinate_descent import ElasticNet, Lasso
18
18
  from .linear import LinearRegression
19
- from .logistic_path import LogisticRegression, logistic_regression_path
19
+ from .logistic_regression import LogisticRegression
20
20
  from .ridge import Ridge
21
21
 
22
22
  __all__ = [
23
23
  "Ridge",
24
24
  "LinearRegression",
25
25
  "LogisticRegression",
26
- "logistic_regression_path",
27
26
  "ElasticNet",
28
27
  "Lasso",
29
28
  ]
@@ -65,13 +65,8 @@ if daal_check_version((2023, "P", 100)):
65
65
  import numpy as np
66
66
  from sklearn.linear_model import LinearRegression as sklearn_LinearRegression
67
67
 
68
- from daal4py.sklearn._utils import (
69
- control_n_jobs,
70
- get_dtype,
71
- make2d,
72
- run_with_n_jobs,
73
- sklearn_check_version,
74
- )
68
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
69
+ from daal4py.sklearn._utils import get_dtype, make2d, sklearn_check_version
75
70
 
76
71
  from .._device_offload import dispatch, wrap_output_data
77
72
  from .._utils import (
@@ -93,7 +88,7 @@ if daal_check_version((2023, "P", 100)):
93
88
  from onedal.utils import _num_features, _num_samples
94
89
 
95
90
  @register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
96
- @control_n_jobs
91
+ @control_n_jobs(decorated_methods=["fit", "predict"])
97
92
  class LinearRegression(sklearn_LinearRegression, BaseLinearRegression):
98
93
  __doc__ = sklearn_LinearRegression.__doc__
99
94
  intercept_, coef_ = None, None
@@ -330,7 +325,6 @@ if daal_check_version((2023, "P", 100)):
330
325
  onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
331
326
  self._onedal_estimator = onedal_LinearRegression(**onedal_params)
332
327
 
333
- @run_with_n_jobs
334
328
  def _onedal_fit(self, X, y, sample_weight, queue=None):
335
329
  assert sample_weight is None
336
330
 
@@ -369,7 +363,6 @@ if daal_check_version((2023, "P", 100)):
369
363
  del self._onedal_estimator
370
364
  super().fit(X, y)
371
365
 
372
- @run_with_n_jobs
373
366
  def _onedal_predict(self, X, queue=None):
374
367
  X = self._validate_data(X, accept_sparse=False, reset=False)
375
368
  if not hasattr(self, "_onedal_estimator"):
@@ -1,5 +1,5 @@
1
1
  # ===============================================================================
2
- # Copyright 2023 Intel Corporation
2
+ # Copyright 2024 Intel Corporation
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -17,14 +17,11 @@
17
17
  import logging
18
18
  from abc import ABC
19
19
 
20
- import sklearn.linear_model._logistic as logistic_module
21
-
22
20
  from daal4py.sklearn._utils import daal_check_version
23
21
  from daal4py.sklearn.linear_model.logistic_path import (
24
- LogisticRegression,
25
- daal4py_predict,
26
- logistic_regression_path,
22
+ LogisticRegression as LogisticRegression_daal4py,
27
23
  )
24
+ from daal4py.sklearn.linear_model.logistic_path import daal4py_fit, daal4py_predict
28
25
 
29
26
 
30
27
  class BaseLogisticRegression(ABC):
@@ -43,14 +40,18 @@ if daal_check_version((2024, "P", 1)):
43
40
  from sklearn.linear_model import LogisticRegression as sklearn_LogisticRegression
44
41
  from sklearn.utils.validation import check_X_y
45
42
 
43
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
46
44
  from daal4py.sklearn._utils import sklearn_check_version
47
45
  from onedal.linear_model import LogisticRegression as onedal_LogisticRegression
48
46
  from onedal.utils import _num_features, _num_samples
49
47
 
50
- from ..._device_offload import dispatch, wrap_output_data
51
- from ..._utils import PatchingConditionsChain, get_patch_message
52
- from ...utils.validation import _assert_all_finite
48
+ from .._device_offload import dispatch, wrap_output_data
49
+ from .._utils import PatchingConditionsChain, get_patch_message
50
+ from ..utils.validation import _assert_all_finite
53
51
 
52
+ @control_n_jobs(
53
+ decorated_methods=["fit", "predict", "predict_proba", "predict_log_proba"]
54
+ )
54
55
  class LogisticRegression(sklearn_LogisticRegression, BaseLogisticRegression):
55
56
  __doc__ = sklearn_LogisticRegression.__doc__
56
57
  intercept_, coef_, n_iter_ = None, None, None
@@ -97,6 +98,8 @@ if daal_check_version((2024, "P", 1)):
97
98
  l1_ratio=l1_ratio,
98
99
  )
99
100
 
101
+ _onedal_cpu_fit = daal4py_fit
102
+
100
103
  def fit(self, X, y, sample_weight=None):
101
104
  if sklearn_check_version("1.0"):
102
105
  self._check_feature_names(X, reset=True)
@@ -160,10 +163,8 @@ if daal_check_version((2024, "P", 1)):
160
163
  def _test_type_and_finiteness(self, X_in):
161
164
  X = np.asarray(X_in)
162
165
 
163
- dtype = X.dtype
164
- if "complex" in str(type(dtype)):
166
+ if np.iscomplexobj(X):
165
167
  return False
166
-
167
168
  try:
168
169
  _assert_all_finite(X)
169
170
  except BaseException:
@@ -268,15 +269,6 @@ if daal_check_version((2024, "P", 1)):
268
269
  }
269
270
  self._onedal_estimator = onedal_LogisticRegression(**onedal_params)
270
271
 
271
- def _onedal_cpu_fit(self, X, y, sample_weight):
272
- which, what = logistic_module, "_logistic_regression_path"
273
- replacer = logistic_regression_path
274
- descriptor = getattr(which, what, None)
275
- setattr(which, what, replacer)
276
- clf = super().fit(X, y, sample_weight)
277
- setattr(which, what, descriptor)
278
- return clf
279
-
280
272
  def _onedal_fit(self, X, y, sample_weight, queue=None):
281
273
  if queue is None or queue.sycl_device.is_cpu:
282
274
  return self._onedal_cpu_fit(X, y, sample_weight)
@@ -313,38 +305,27 @@ if daal_check_version((2024, "P", 1)):
313
305
  return daal4py_predict(self, X, "computeClassLabels")
314
306
 
315
307
  X = self._validate_data(X, accept_sparse=False, reset=False)
316
- if not hasattr(self, "_onedal_estimator"):
317
- self._initialize_onedal_estimator()
318
- self._onedal_estimator.coef_ = self.coef_
319
- self._onedal_estimator.intercept_ = self.intercept_
320
- self._onedal_estimator.classes_ = self.classes_
321
-
308
+ assert hasattr(self, "_onedal_estimator")
322
309
  return self._onedal_estimator.predict(X, queue=queue)
323
310
 
324
311
  def _onedal_predict_proba(self, X, queue=None):
325
312
  if queue is None or queue.sycl_device.is_cpu:
326
313
  return daal4py_predict(self, X, "computeClassProbabilities")
327
- X = self._validate_data(X, accept_sparse=False, reset=False)
328
- if not hasattr(self, "_onedal_estimator"):
329
- self._initialize_onedal_estimator()
330
- self._onedal_estimator.coef_ = self.coef_
331
- self._onedal_estimator.intercept_ = self.intercept_
332
314
 
315
+ X = self._validate_data(X, accept_sparse=False, reset=False)
316
+ assert hasattr(self, "_onedal_estimator")
333
317
  return self._onedal_estimator.predict_proba(X, queue=queue)
334
318
 
335
319
  def _onedal_predict_log_proba(self, X, queue=None):
336
320
  if queue is None or queue.sycl_device.is_cpu:
337
321
  return daal4py_predict(self, X, "computeClassLogProbabilities")
338
- X = self._validate_data(X, accept_sparse=False, reset=False)
339
- if not hasattr(self, "_onedal_estimator"):
340
- self._initialize_onedal_estimator()
341
- self._onedal_estimator.coef_ = self.coef_
342
- self._onedal_estimator.intercept_ = self.intercept_
343
322
 
323
+ X = self._validate_data(X, accept_sparse=False, reset=False)
324
+ assert hasattr(self, "_onedal_estimator")
344
325
  return self._onedal_estimator.predict_log_proba(X, queue=queue)
345
326
 
346
327
  else:
347
- from daal4py.sklearn.linear_model import LogisticRegression
328
+ LogisticRegression = LogisticRegression_daal4py
348
329
 
349
330
  logging.warning(
350
331
  "Sklearnex LogisticRegression requires oneDAL version >= 2024.0.1 "