scikit-learn-intelex 2024.2.0__py39-none-win_amd64.whl → 2024.4.0__py39-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (112)
  1. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/__init__.py +9 -7
  2. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/_device_offload.py +31 -4
  3. {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/spmd → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex}/basic_statistics/__init__.py +2 -1
  4. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/basic_statistics/incremental_basic_statistics.py +288 -0
  5. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/basic_statistics/tests/test_incremental_basic_statistics.py +386 -0
  6. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/dbscan.py +3 -1
  7. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/conftest.py +63 -0
  8. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +335 -0
  9. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +22 -8
  10. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/dispatcher.py +74 -43
  11. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/ensemble/_forest.py +78 -89
  12. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +15 -19
  13. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +316 -0
  14. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_regression.py +63 -11
  15. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +40 -5
  16. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +0 -2
  17. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/_lof.py +74 -20
  18. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +4 -1
  19. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +44 -131
  20. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +198 -221
  21. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +146 -0
  22. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -5
  23. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/__init__.py +1 -1
  24. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +5 -73
  25. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/covariance.py +6 -5
  26. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/tests/test_covariance.py +18 -5
  27. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +4 -12
  28. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/_common.py +4 -7
  29. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +70 -50
  30. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +6 -52
  31. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/svc.py +70 -51
  32. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/svr.py +3 -49
  33. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/_utils.py +164 -0
  34. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +8 -3
  35. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +268 -0
  36. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_n_jobs_support.py +8 -2
  37. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_parallel.py +6 -8
  38. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +371 -0
  39. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +2 -1
  40. scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/utils/_namespace.py +97 -0
  41. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/METADATA +2 -2
  42. scikit_learn_intelex-2024.4.0.dist-info/RECORD +101 -0
  43. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/pca.py +0 -17
  44. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +0 -27
  45. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/linear_model/linear.py +0 -381
  46. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -308
  47. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +0 -19
  48. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +0 -374
  49. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -170
  50. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +0 -240
  51. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/test_patching.py +0 -136
  52. scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +0 -118
  53. scikit_learn_intelex-2024.2.0.dist-info/RECORD +0 -101
  54. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/__main__.py +0 -0
  55. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/_config.py +0 -0
  56. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/_utils.py +0 -0
  57. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +0 -0
  58. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +0 -0
  59. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +0 -0
  60. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +0 -0
  61. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +0 -0
  62. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/covariance/__init__.py +0 -0
  63. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/covariance/incremental_covariance.py +0 -0
  64. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/covariance/tests/test_incremental_covariance.py +0 -0
  65. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +0 -0
  66. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  67. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -0
  68. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +0 -0
  69. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +0 -0
  70. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +0 -0
  71. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +0 -0
  72. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +0 -0
  73. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +0 -0
  74. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +0 -0
  75. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +0 -0
  76. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +0 -0
  77. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +0 -0
  78. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +0 -0
  79. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +0 -0
  80. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +0 -0
  81. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +0 -0
  82. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +0 -0
  83. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +0 -0
  84. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +0 -0
  85. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +0 -0
  86. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +0 -0
  87. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/preview/covariance/__init__.py +0 -0
  88. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -0
  89. {scikit_learn_intelex-2024.2.0.data/data/Lib/site-packages/sklearnex → scikit_learn_intelex-2024.4.0.data/data/Lib/site-packages/sklearnex/spmd}/basic_statistics/__init__.py +0 -0
  90. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +0 -0
  91. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +0 -0
  92. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +0 -0
  93. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +0 -0
  94. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/__init__.py +0 -0
  95. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/covariance/covariance.py +0 -0
  96. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +0 -0
  97. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +0 -0
  98. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +0 -0
  99. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +0 -0
  100. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +0 -0
  101. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/logistic_regression.py +0 -0
  102. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +0 -0
  103. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +0 -0
  104. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +0 -0
  105. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -0
  106. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +0 -0
  107. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -0
  108. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/utils/parallel.py +0 -0
  109. {scikit_learn_intelex-2024.2.0.data → scikit_learn_intelex-2024.4.0.data}/data/Lib/site-packages/sklearnex/utils/validation.py +0 -0
  110. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/LICENSE.txt +0 -0
  111. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/WHEEL +0 -0
  112. {scikit_learn_intelex-2024.2.0.dist-info → scikit_learn_intelex-2024.4.0.dist-info}/top_level.txt +0 -0
@@ -25,8 +25,11 @@ from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifie
25
25
  from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
26
26
  from sklearn.ensemble import RandomForestClassifier as sklearn_RandomForestClassifier
27
27
  from sklearn.ensemble import RandomForestRegressor as sklearn_RandomForestRegressor
28
+ from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
29
+ from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
28
30
  from sklearn.ensemble._forest import _get_n_samples_bootstrap
29
31
  from sklearn.exceptions import DataConversionWarning
32
+ from sklearn.metrics import accuracy_score
30
33
  from sklearn.tree import (
31
34
  DecisionTreeClassifier,
32
35
  DecisionTreeRegressor,
@@ -35,12 +38,7 @@ from sklearn.tree import (
35
38
  )
36
39
  from sklearn.tree._tree import Tree
37
40
  from sklearn.utils import check_random_state, deprecated
38
- from sklearn.utils.validation import (
39
- check_array,
40
- check_consistent_length,
41
- check_is_fitted,
42
- check_X_y,
43
- )
41
+ from sklearn.utils.validation import check_array, check_is_fitted
44
42
 
45
43
  from daal4py.sklearn._n_jobs_support import control_n_jobs
46
44
  from daal4py.sklearn._utils import (
@@ -52,19 +50,10 @@ from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
52
50
  from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
53
51
  from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
54
52
  from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
55
-
56
- # try catch needed for changes in structures observed in Scikit-learn around v0.22
57
- try:
58
- from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
59
- from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
60
- except ModuleNotFoundError:
61
- from sklearn.ensemble.forest import ForestClassifier as sklearn_ForestClassifier
62
- from sklearn.ensemble.forest import ForestRegressor as sklearn_ForestRegressor
63
-
64
53
  from onedal.primitives import get_tree_state_cls, get_tree_state_reg
65
54
  from onedal.utils import _num_features, _num_samples
55
+ from sklearnex.utils import get_namespace
66
56
 
67
- from .._config import get_config
68
57
  from .._device_offload import dispatch, wrap_output_data
69
58
  from .._utils import PatchingConditionsChain
70
59
 
@@ -78,24 +67,14 @@ class BaseForest(ABC):
78
67
  _onedal_factory = None
79
68
 
80
69
  def _onedal_fit(self, X, y, sample_weight=None, queue=None):
81
- if sklearn_check_version("0.24"):
82
- X, y = self._validate_data(
83
- X,
84
- y,
85
- multi_output=False,
86
- accept_sparse=False,
87
- dtype=[np.float64, np.float32],
88
- force_all_finite=False,
89
- )
90
- else:
91
- X, y = check_X_y(
92
- X,
93
- y,
94
- accept_sparse=False,
95
- dtype=[np.float64, np.float32],
96
- multi_output=False,
97
- force_all_finite=False,
98
- )
70
+ X, y = self._validate_data(
71
+ X,
72
+ y,
73
+ multi_output=False,
74
+ accept_sparse=False,
75
+ dtype=[np.float64, np.float32],
76
+ force_all_finite=False,
77
+ )
99
78
 
100
79
  if sample_weight is not None:
101
80
  sample_weight = self.check_sample_weight(sample_weight, X)
@@ -173,15 +152,6 @@ class BaseForest(ABC):
173
152
 
174
153
  return self
175
154
 
176
- def _fit_proba(self, X, y, sample_weight=None, queue=None):
177
- params = self.get_params()
178
- self.__class__(**params)
179
-
180
- # We use stock metaestimators below, so the only way
181
- # to pass a queue is using config_context.
182
- cfg = get_config()
183
- cfg["target_offload"] = queue
184
-
185
155
  def _save_attributes(self):
186
156
  if self.oob_score:
187
157
  self.oob_score_ = self._onedal_estimator.oob_score_
@@ -204,8 +174,6 @@ class BaseForest(ABC):
204
174
  self._validate_estimator()
205
175
  return self
206
176
 
207
- # TODO:
208
- # move to onedal modul.
209
177
  def _check_parameters(self):
210
178
  if isinstance(self.min_samples_leaf, numbers.Integral):
211
179
  if not 1 <= self.min_samples_leaf:
@@ -453,14 +421,12 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
453
421
 
454
422
  # The estimator is checked against the class attribute for conformance.
455
423
  # This should only trigger if the user uses this class directly.
456
- if (
457
- self.estimator.__class__ == DecisionTreeClassifier
458
- and self._onedal_factory != onedal_RandomForestClassifier
424
+ if self.estimator.__class__ == DecisionTreeClassifier and not issubclass(
425
+ self._onedal_factory, onedal_RandomForestClassifier
459
426
  ):
460
427
  self._onedal_factory = onedal_RandomForestClassifier
461
- elif (
462
- self.estimator.__class__ == ExtraTreeClassifier
463
- and self._onedal_factory != onedal_ExtraTreesClassifier
428
+ elif self.estimator.__class__ == ExtraTreeClassifier and not issubclass(
429
+ self._onedal_factory, onedal_ExtraTreesClassifier
464
430
  ):
465
431
  self._onedal_factory = onedal_ExtraTreesClassifier
466
432
 
@@ -552,18 +518,14 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
552
518
  )
553
519
 
554
520
  if patching_status.get_status():
555
- if sklearn_check_version("0.24"):
556
- X, y = self._validate_data(
557
- X,
558
- y,
559
- multi_output=True,
560
- accept_sparse=True,
561
- dtype=[np.float64, np.float32],
562
- force_all_finite=False,
563
- )
564
- else:
565
- X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
566
- y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
521
+ X, y = self._validate_data(
522
+ X,
523
+ y,
524
+ multi_output=True,
525
+ accept_sparse=True,
526
+ dtype=[np.float64, np.float32],
527
+ force_all_finite=False,
528
+ )
567
529
 
568
530
  if y.ndim == 2 and y.shape[1] == 1:
569
531
  warnings.warn(
@@ -657,9 +619,38 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
657
619
  X,
658
620
  )
659
621
 
622
+ def predict_log_proba(self, X):
623
+ xp, _ = get_namespace(X)
624
+ proba = self.predict_proba(X)
625
+
626
+ if self.n_outputs_ == 1:
627
+ return xp.log(proba)
628
+
629
+ else:
630
+ for k in range(self.n_outputs_):
631
+ proba[k] = xp.log(proba[k])
632
+
633
+ return proba
634
+
635
+ @wrap_output_data
636
+ def score(self, X, y, sample_weight=None):
637
+ return dispatch(
638
+ self,
639
+ "score",
640
+ {
641
+ "onedal": self.__class__._onedal_score,
642
+ "sklearn": sklearn_ForestClassifier.score,
643
+ },
644
+ X,
645
+ y,
646
+ sample_weight=sample_weight,
647
+ )
648
+
660
649
  fit.__doc__ = sklearn_ForestClassifier.fit.__doc__
661
650
  predict.__doc__ = sklearn_ForestClassifier.predict.__doc__
662
651
  predict_proba.__doc__ = sklearn_ForestClassifier.predict_proba.__doc__
652
+ predict_log_proba.__doc__ = sklearn_ForestClassifier.predict_log_proba.__doc__
653
+ score.__doc__ = sklearn_ForestClassifier.score.__doc__
663
654
 
664
655
  def _onedal_cpu_supported(self, method_name, *data):
665
656
  class_name = self.__class__.__name__
@@ -686,7 +677,7 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
686
677
  ]
687
678
  )
688
679
 
689
- elif method_name in ["predict", "predict_proba"]:
680
+ elif method_name in ["predict", "predict_proba", "score"]:
690
681
  X = data[0]
691
682
 
692
683
  patching_status.and_conditions(
@@ -747,11 +738,11 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
747
738
  or self.estimator.__class__ == DecisionTreeClassifier,
748
739
  "ExtraTrees only supported starting from oneDAL version 2023.1",
749
740
  ),
750
- (sample_weight is not None, "sample_weight is not supported."),
741
+ (sample_weight is None, "sample_weight is not supported."),
751
742
  ]
752
743
  )
753
744
 
754
- elif method_name in ["predict", "predict_proba"]:
745
+ elif method_name in ["predict", "predict_proba", "score"]:
755
746
  X = data[0]
756
747
 
757
748
  patching_status.and_conditions(
@@ -803,12 +794,16 @@ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
803
794
  X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
804
795
  check_is_fitted(self, "_onedal_estimator")
805
796
 
806
- if sklearn_check_version("0.23"):
807
- self._check_n_features(X, reset=False)
797
+ self._check_n_features(X, reset=False)
808
798
  if sklearn_check_version("1.0"):
809
799
  self._check_feature_names(X, reset=False)
810
800
  return self._onedal_estimator.predict_proba(X, queue=queue)
811
801
 
802
+ def _onedal_score(self, X, y, sample_weight=None, queue=None):
803
+ return accuracy_score(
804
+ y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight
805
+ )
806
+
812
807
 
813
808
  class ForestRegressor(sklearn_ForestRegressor, BaseForest):
814
809
  _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
@@ -843,14 +838,12 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
843
838
 
844
839
  # The splitter is checked against the class attribute for conformance
845
840
  # This should only trigger if the user uses this class directly.
846
- if (
847
- self.estimator.__class__ == DecisionTreeRegressor
848
- and self._onedal_factory != onedal_RandomForestRegressor
841
+ if self.estimator.__class__ == DecisionTreeRegressor and not issubclass(
842
+ self._onedal_factory, onedal_RandomForestRegressor
849
843
  ):
850
844
  self._onedal_factory = onedal_RandomForestRegressor
851
- elif (
852
- self.estimator.__class__ == ExtraTreeRegressor
853
- and self._onedal_factory != onedal_ExtraTreesRegressor
845
+ elif self.estimator.__class__ == ExtraTreeRegressor and not issubclass(
846
+ self._onedal_factory, onedal_ExtraTreesRegressor
854
847
  ):
855
848
  self._onedal_factory = onedal_ExtraTreesRegressor
856
849
 
@@ -920,18 +913,14 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
920
913
  )
921
914
 
922
915
  if patching_status.get_status():
923
- if sklearn_check_version("0.24"):
924
- X, y = self._validate_data(
925
- X,
926
- y,
927
- multi_output=True,
928
- accept_sparse=True,
929
- dtype=[np.float64, np.float32],
930
- force_all_finite=False,
931
- )
932
- else:
933
- X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
934
- y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)
916
+ X, y = self._validate_data(
917
+ X,
918
+ y,
919
+ multi_output=True,
920
+ accept_sparse=True,
921
+ dtype=[np.float64, np.float32],
922
+ force_all_finite=False,
923
+ )
935
924
 
936
925
  if y.ndim == 2 and y.shape[1] == 1:
937
926
  warnings.warn(
@@ -1056,7 +1045,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1056
1045
  or self.estimator.__class__ == DecisionTreeClassifier,
1057
1046
  "ExtraTrees only supported starting from oneDAL version 2023.1",
1058
1047
  ),
1059
- (sample_weight is not None, "sample_weight is not supported."),
1048
+ (sample_weight is None, "sample_weight is not supported."),
1060
1049
  ]
1061
1050
  )
1062
1051
 
@@ -1133,7 +1122,7 @@ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
1133
1122
  predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
1134
1123
 
1135
1124
 
1136
- @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
1125
+ @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
1137
1126
  class RandomForestClassifier(ForestClassifier):
1138
1127
  __doc__ = sklearn_RandomForestClassifier.__doc__
1139
1128
  _onedal_factory = onedal_RandomForestClassifier
@@ -1544,7 +1533,7 @@ class RandomForestRegressor(ForestRegressor):
1544
1533
  self.min_bin_size = min_bin_size
1545
1534
 
1546
1535
 
1547
- @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba"])
1536
+ @control_n_jobs(decorated_methods=["fit", "predict", "predict_proba", "score"])
1548
1537
  class ExtraTreesClassifier(ForestClassifier):
1549
1538
  __doc__ = sklearn_ExtraTreesClassifier.__doc__
1550
1539
  _onedal_factory = onedal_ExtraTreesClassifier
@@ -45,11 +45,7 @@ def test_sklearnex_import_rf_classifier(dataframe, queue):
45
45
  assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
46
46
 
47
47
 
48
- # TODO:
49
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
50
- @pytest.mark.parametrize(
51
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
52
- )
48
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
53
49
  def test_sklearnex_import_rf_regression(dataframe, queue):
54
50
  from sklearnex.ensemble import RandomForestRegressor
55
51
 
@@ -59,17 +55,17 @@ def test_sklearnex_import_rf_regression(dataframe, queue):
59
55
  rf = RandomForestRegressor(max_depth=2, random_state=0).fit(X, y)
60
56
  assert "sklearnex" in rf.__module__
61
57
  pred = _as_numpy(rf.predict([[0, 0, 0, 0]]))
62
- if daal_check_version((2024, "P", 0)):
63
- assert_allclose([-6.971], pred, atol=1e-2)
58
+
59
+ if queue is not None and queue.sycl_device.is_gpu:
60
+ assert_allclose([-0.011208], pred, atol=1e-2)
64
61
  else:
65
- assert_allclose([-6.839], pred, atol=1e-2)
62
+ if daal_check_version((2024, "P", 0)):
63
+ assert_allclose([-6.971], pred, atol=1e-2)
64
+ else:
65
+ assert_allclose([-6.839], pred, atol=1e-2)
66
66
 
67
67
 
68
- # TODO:
69
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
70
- @pytest.mark.parametrize(
71
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
72
- )
68
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
73
69
  def test_sklearnex_import_et_classifier(dataframe, queue):
74
70
  from sklearnex.ensemble import ExtraTreesClassifier
75
71
 
@@ -90,11 +86,7 @@ def test_sklearnex_import_et_classifier(dataframe, queue):
90
86
  assert_allclose([1], _as_numpy(rf.predict([[0, 0, 0, 0]])))
91
87
 
92
88
 
93
- # TODO:
94
- # investigate failure for `dpnp.ndarrays` and `dpctl.tensors` on `GPU`
95
- @pytest.mark.parametrize(
96
- "dataframe,queue", get_dataframes_and_queues(device_filter_="cpu")
97
- )
89
+ @pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
98
90
  def test_sklearnex_import_et_regression(dataframe, queue):
99
91
  from sklearnex.ensemble import ExtraTreesRegressor
100
92
 
@@ -114,4 +106,8 @@ def test_sklearnex_import_et_regression(dataframe, queue):
114
106
  ]
115
107
  )
116
108
  )
117
- assert_allclose([0.445], pred, atol=1e-2)
109
+
110
+ if queue is not None and queue.sycl_device.is_gpu:
111
+ assert_allclose([1.909769], pred, atol=1e-2)
112
+ else:
113
+ assert_allclose([0.445], pred, atol=1e-2)
@@ -0,0 +1,316 @@
1
+ # ===============================================================================
2
+ # Copyright 2023 Intel Corporation
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ===============================================================================
16
+
17
+ import logging
18
+ from abc import ABC
19
+
20
+ import numpy as np
21
+ from sklearn.exceptions import NotFittedError
22
+ from sklearn.linear_model import LinearRegression as sklearn_LinearRegression
23
+
24
+ from daal4py.sklearn._n_jobs_support import control_n_jobs
25
+ from daal4py.sklearn._utils import sklearn_check_version
26
+
27
+ from .._device_offload import dispatch, wrap_output_data
28
+ from .._utils import PatchingConditionsChain, get_patch_message, register_hyperparameters
29
+ from ..utils.validation import _assert_all_finite
30
+
31
+ if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
32
+ from sklearn.linear_model._base import _deprecate_normalize
33
+
34
+ from scipy.sparse import issparse
35
+ from sklearn.utils.validation import check_X_y
36
+
37
+ from onedal.common.hyperparameters import get_hyperparameters
38
+ from onedal.linear_model import LinearRegression as onedal_LinearRegression
39
+ from onedal.utils import _num_features, _num_samples
40
+
41
+
42
@register_hyperparameters({"fit": get_hyperparameters("linear_regression", "train")})
@control_n_jobs(decorated_methods=["fit", "predict"])
class LinearRegression(sklearn_LinearRegression):
    # Drop-in replacement for sklearn.linear_model.LinearRegression.
    # fit() and predict() hand both a oneDAL and a sklearn implementation to
    # `dispatch`; `_onedal_cpu_supported` / `_onedal_gpu_supported` report
    # whether the oneDAL path may be used for a given call.
    __doc__ = sklearn_LinearRegression.__doc__

    if sklearn_check_version("1.2"):
        _parameter_constraints: dict = {**sklearn_LinearRegression._parameter_constraints}

        def __init__(
            self,
            fit_intercept=True,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    else:
        # sklearn < 1.2 still accepts the (deprecated since 1.0) `normalize`.
        def __init__(
            self,
            fit_intercept=True,
            normalize="deprecated" if sklearn_check_version("1.0") else False,
            copy_X=True,
            n_jobs=None,
            positive=False,
        ):
            super().__init__(
                fit_intercept=fit_intercept,
                normalize=normalize,
                copy_X=copy_X,
                n_jobs=n_jobs,
                positive=positive,
            )

    def fit(self, X, y, sample_weight=None):
        # Public docstring is copied from sklearn at the end of the class body.
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=True)
        if sklearn_check_version("1.2"):
            self._validate_params()

        # It is necessary to properly update coefs for predict if we
        # fallback to sklearn in dispatch: a leftover estimator would hold
        # the coefficients of a previous fit.
        if hasattr(self, "_onedal_estimator"):
            del self._onedal_estimator

        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_LinearRegression.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self

    @wrap_output_data
    def predict(self, X):
        # Public docstring is copied from sklearn at the end of the class body.
        if not hasattr(self, "coef_"):
            msg = (
                "This %(name)s instance is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this estimator."
            )
            raise NotFittedError(msg % {"name": self.__class__.__name__})

        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_LinearRegression.predict,
            },
            X,
        )

    def _test_type_and_finiteness(self, X_in):
        # Return True when X_in is real-valued and all of its values pass
        # `_assert_all_finite`; False tells the caller to refuse the oneDAL path.
        X = X_in if isinstance(X_in, np.ndarray) else np.asarray(X_in)

        # Use dtype.kind rather than the dtype's class name: on NumPy >= 1.25
        # that class is spelled e.g. `Complex128DType` (capital "C"), so a
        # substring test for "complex" no longer detects complex input.
        if X.dtype.kind == "c":
            return False

        try:
            _assert_all_finite(X)
        except Exception:
            # Any validation failure just means "not supported by oneDAL";
            # do not swallow KeyboardInterrupt / SystemExit.
            return False
        return True

    def _onedal_fit_supported(self, method_name, *data):
        # Build the patching-condition chain deciding whether fit() may run on
        # oneDAL. `data` is the (X, y, sample_weight) triple passed to dispatch.
        assert method_name == "fit"
        assert len(data) == 3
        X, y, sample_weight = data

        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{class_name}.fit"
        )

        # `normalize` only exists on sklearn < 1.2; "deprecated" means unset.
        normalize_is_set = (
            hasattr(self, "normalize")
            and self.normalize
            and self.normalize != "deprecated"
        )
        positive_is_set = hasattr(self, "positive") and self.positive

        n_samples = _num_samples(X)
        n_features = _num_features(X, fallback_1d=True)

        # Check if equations are well defined: oneDAL needs at least as many
        # samples as unknowns (features, plus one for the intercept).
        is_good_for_onedal = n_samples >= (n_features + int(self.fit_intercept))

        dal_ready = patching_status.and_conditions(
            [
                (sample_weight is None, "Sample weight is not supported."),
                (
                    not issparse(X) and not issparse(y),
                    "Sparse input is not supported.",
                ),
                (not normalize_is_set, "Normalization is not supported."),
                (
                    not positive_is_set,
                    "Forced positive coefficients are not supported.",
                ),
                (
                    is_good_for_onedal,
                    "The shape of X (fitting) does not satisfy oneDAL requirements: "
                    "number of samples must be at least number of features "
                    "(plus 1 when fit_intercept=True).",
                ),
            ]
        )
        if not dal_ready:
            return patching_status

        if not patching_status.and_condition(
            self._test_type_and_finiteness(X), "Input X is not supported."
        ):
            return patching_status

        patching_status.and_condition(
            self._test_type_and_finiteness(y), "Input y is not supported."
        )

        return patching_status

    def _onedal_predict_supported(self, method_name, *data):
        # Build the patching-condition chain deciding whether predict() may run
        # on oneDAL. `data` holds exactly one element: X.
        assert method_name == "predict"
        assert len(data) == 1

        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.linear_model.{class_name}.predict"
        )

        n_samples = _num_samples(*data)
        model_is_sparse = issparse(self.coef_) or (
            self.fit_intercept and issparse(self.intercept_)
        )
        dal_ready = patching_status.and_conditions(
            [
                (n_samples > 0, "Number of samples is less than 1."),
                (not issparse(*data), "Sparse input is not supported."),
                (not model_is_sparse, "Sparse coefficients are not supported."),
            ]
        )
        if not dal_ready:
            return patching_status

        patching_status.and_condition(
            self._test_type_and_finiteness(*data), "Input X is not supported."
        )

        return patching_status

    def _onedal_supported(self, method_name, *data):
        # Route the support query to the method-specific checker.
        if method_name == "fit":
            return self._onedal_fit_supported(method_name, *data)
        if method_name == "predict":
            return self._onedal_predict_supported(method_name, *data)
        raise RuntimeError(f"Unknown method {method_name} in {self.__class__.__name__}")

    # The same support rules are reported for CPU and GPU devices.
    _onedal_gpu_supported = _onedal_supported
    _onedal_cpu_supported = _onedal_supported

    def _initialize_onedal_estimator(self):
        # Create a fresh, unfitted onedal LinearRegression with matching params.
        onedal_params = {"fit_intercept": self.fit_intercept, "copy_X": self.copy_X}
        self._onedal_estimator = onedal_LinearRegression(**onedal_params)

    def _onedal_fit(self, X, y, sample_weight, queue=None):
        # Fit on oneDAL; if oneDAL raises RuntimeError, log it and refit with
        # sklearn on the already-validated data.
        assert sample_weight is None

        # Finiteness was already screened in `_onedal_fit_supported`, hence
        # force_all_finite=False here.
        check_params = {
            "X": X,
            "y": y,
            "dtype": [np.float64, np.float32],
            "accept_sparse": ["csr", "csc", "coo"],
            "y_numeric": True,
            "multi_output": True,
            "force_all_finite": False,
        }
        if sklearn_check_version("1.2"):
            X, y = self._validate_data(**check_params)
        else:
            X, y = check_X_y(**check_params)

        if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
            self._normalize = _deprecate_normalize(
                self.normalize,
                default=False,
                estimator_name=self.__class__.__name__,
            )

        self._initialize_onedal_estimator()
        try:
            self._onedal_estimator.fit(X, y, queue=queue)
            self._save_attributes()

        except RuntimeError:
            logging.getLogger("sklearnex").info(
                f"{self.__class__.__name__}.fit "
                + get_patch_message("sklearn_after_onedal")
            )

            # Remove the failed estimator first so the coef_/intercept_
            # setters triggered by sklearn's fit have nothing stale to sync.
            del self._onedal_estimator
            super().fit(X, y)

    def _onedal_predict(self, X, queue=None):
        # Predict with oneDAL. After a sklearn-side fit there is no cached
        # estimator yet, so one is built from the stored coefficients.
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)

        X = self._validate_data(X, accept_sparse=False, reset=False)
        if not hasattr(self, "_onedal_estimator"):
            self._initialize_onedal_estimator()
            self._onedal_estimator.coef_ = self.coef_
            self._onedal_estimator.intercept_ = self.intercept_

        res = self._onedal_estimator.predict(X, queue=queue)
        return res

    def _sync_onedal_attribute(self, name, value):
        # Push a user-assigned coefficient into the cached oneDAL estimator (if
        # any) and drop its trained model so the next predict rebuilds it.
        estimator = getattr(self, "_onedal_estimator", None)
        if estimator is None:
            return
        setattr(estimator, name, value)
        # An estimator created inside `_onedal_predict` may never have stored
        # a model, so only delete it when present.
        if hasattr(estimator, "_onedal_model"):
            del estimator._onedal_model

    def get_coef_(self):
        # Getter backing the class-level `coef_` property.
        return self._coef

    def set_coef_(self, value):
        # Setter backing `coef_`: keeps the oneDAL estimator consistent.
        self._coef = value
        self._sync_onedal_attribute("coef_", value)

    def get_intercept_(self):
        # Getter backing the class-level `intercept_` property.
        return self._intercept

    def set_intercept_(self, value):
        # Setter backing `intercept_`: keeps the oneDAL estimator consistent.
        self._intercept = value
        self._sync_onedal_attribute("intercept_", value)

    # Properties must live on the class; assigning a property object to an
    # instance attribute has no descriptor effect.
    coef_ = property(get_coef_, set_coef_)
    intercept_ = property(get_intercept_, set_intercept_)

    def _save_attributes(self):
        # Copy fitted results from the oneDAL estimator onto this object.
        self.n_features_in_ = self._onedal_estimator.n_features_in_
        self._sparse = False
        # Write the private storage directly: going through the public
        # setters would discard the model oneDAL has just trained.
        self._coef = self._onedal_estimator.coef_
        self._intercept = self._onedal_estimator.intercept_

    fit.__doc__ = sklearn_LinearRegression.fit.__doc__
    predict.__doc__ = sklearn_LinearRegression.predict.__doc__