scikit-learn-intelex 2023.2.1__py39-none-win_amd64.whl → 2024.0.1__py39-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic. Click here for more details.

Files changed (109)
  1. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +2 -2
  2. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +16 -12
  3. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +2 -2
  4. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +90 -56
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +3 -3
  7. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +2 -2
  8. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +4 -4
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +2 -2
  11. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +12 -6
  12. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +5 -5
  13. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +3 -3
  14. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +2 -2
  15. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +5 -4
  16. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +102 -72
  17. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +12 -4
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  20. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +31 -16
  21. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +21 -14
  22. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +10 -10
  23. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +2 -2
  24. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +173 -83
  25. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +3 -3
  26. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +2 -2
  27. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +23 -7
  28. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +4 -3
  29. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +3 -3
  30. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +2 -2
  31. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +4 -3
  32. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +5 -5
  33. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
  34. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
  35. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +8 -6
  36. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +3 -3
  37. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +2 -2
  38. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +6 -3
  39. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +9 -5
  40. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +100 -77
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  43. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +116 -58
  44. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +118 -56
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  46. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +18 -20
  47. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +3 -3
  48. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +7 -7
  49. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +104 -73
  50. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +4 -1
  51. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +128 -100
  52. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +18 -16
  53. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd}/__init__.py +24 -22
  54. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +3 -3
  55. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +2 -2
  56. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +11 -5
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  58. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +2 -2
  59. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +3 -3
  60. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +2 -2
  61. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +3 -3
  62. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +16 -14
  63. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -3
  64. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +2 -2
  65. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +3 -3
  66. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +3 -3
  67. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +11 -8
  68. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +56 -56
  69. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +110 -55
  70. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +65 -31
  71. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +136 -78
  72. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +65 -31
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  75. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +9 -8
  76. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +63 -69
  77. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +55 -53
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  79. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +8 -7
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  81. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +39 -39
  82. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -3
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  84. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +2 -2
  85. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
  86. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  87. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/_utils.py +0 -82
  88. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -18
  89. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
  90. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
  91. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -46
  92. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -228
  93. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -213
  94. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -57
  95. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -18
  96. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -28
  97. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py +0 -1261
  98. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1155
  99. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py +0 -67
  100. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
  101. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -23
  102. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -63
  103. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -159
  104. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -383
  105. scikit_learn_intelex-2023.2.1.dist-info/RECORD +0 -95
  106. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  107. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
  108. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
  109. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1947 @@
1
+ #!/usr/bin/env python
2
+ # ==============================================================================
3
+ # Copyright 2021 Intel Corporation
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ==============================================================================
17
+
18
+ import numbers
19
+ import warnings
20
+ from abc import ABC
21
+
22
+ import numpy as np
23
+ from scipy import sparse as sp
24
+ from sklearn.base import clone
25
+ from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifier
26
+ from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
27
+ from sklearn.ensemble import RandomForestClassifier as sklearn_RandomForestClassifier
28
+ from sklearn.ensemble import RandomForestRegressor as sklearn_RandomForestRegressor
29
+ from sklearn.ensemble._forest import _get_n_samples_bootstrap
30
+ from sklearn.exceptions import DataConversionWarning
31
+ from sklearn.tree import (
32
+ DecisionTreeClassifier,
33
+ DecisionTreeRegressor,
34
+ ExtraTreeClassifier,
35
+ ExtraTreeRegressor,
36
+ )
37
+ from sklearn.tree._tree import Tree
38
+ from sklearn.utils import check_random_state, deprecated
39
+ from sklearn.utils.validation import (
40
+ check_array,
41
+ check_consistent_length,
42
+ check_is_fitted,
43
+ check_X_y,
44
+ )
45
+
46
+ from daal4py.sklearn._utils import (
47
+ check_tree_nodes,
48
+ daal_check_version,
49
+ sklearn_check_version,
50
+ )
51
+ from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
52
+ from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
53
+ from onedal.ensemble import RandomForestClassifier as onedal_RandomForestClassifier
54
+ from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
55
+
56
+ # try catch needed for changes in structures observed in Scikit-learn around v0.22
57
+ try:
58
+ from sklearn.ensemble._forest import ForestClassifier as sklearn_ForestClassifier
59
+ from sklearn.ensemble._forest import ForestRegressor as sklearn_ForestRegressor
60
+ except ModuleNotFoundError:
61
+ from sklearn.ensemble.forest import ForestClassifier as sklearn_ForestClassifier
62
+ from sklearn.ensemble.forest import ForestRegressor as sklearn_ForestRegressor
63
+
64
+ from onedal.primitives import get_tree_state_cls, get_tree_state_reg
65
+ from onedal.utils import _num_features, _num_samples
66
+
67
+ from .._config import get_config
68
+ from .._device_offload import dispatch, wrap_output_data
69
+ from .._utils import PatchingConditionsChain
70
+
71
+ if sklearn_check_version("1.2"):
72
+ from sklearn.utils._param_validation import Interval
73
+ if sklearn_check_version("1.4"):
74
+ from daal4py.sklearn.utils import _assert_all_finite
75
+
76
+
77
class BaseForest(ABC):
    """Shared machinery for sklearnex forest estimators.

    Mixed into sklearn's ForestClassifier/ForestRegressor subclasses to route
    fitting through a oneDAL backend (``_onedal_factory``) while keeping the
    stock scikit-learn attribute surface (``estimators_``, ``n_features_``,
    ``base_estimator``) available, lazily reconstructed from the oneDAL model.
    """

    # Set by concrete subclasses to the onedal estimator class used for fit.
    _onedal_factory = None

    def _onedal_fit(self, X, y, sample_weight=None, queue=None):
        """Fit via the oneDAL backend and mirror results onto sklearn attributes.

        Validates inputs, expands class weights into sample weights, runs the
        onedal estimator, then copies fitted attributes back. ``queue`` is the
        SYCL queue (or None) chosen by the dispatcher.
        """
        # Input validation path differs by sklearn version; both accept only
        # dense float32/float64 and single-output y, non-finite allowed here
        # (finiteness is gated earlier in _onedal_fit_ready for sklearn>=1.4).
        if sklearn_check_version("0.24"):
            X, y = self._validate_data(
                X,
                y,
                multi_output=False,
                accept_sparse=False,
                dtype=[np.float64, np.float32],
                force_all_finite=False,
            )
        else:
            X, y = check_X_y(
                X,
                y,
                accept_sparse=False,
                dtype=[np.float64, np.float32],
                multi_output=False,
                force_all_finite=False,
            )

        if sample_weight is not None:
            sample_weight = self.check_sample_weight(sample_weight, X)

        # Same warning stock sklearn emits for a (n_samples, 1) column vector.
        if y.ndim == 2 and y.shape[1] == 1:
            warnings.warn(
                "A column-vector y was passed when a 1d array was"
                " expected. Please change the shape of y to "
                "(n_samples,), for example using ravel().",
                DataConversionWarning,
                stacklevel=2,
            )

        if y.ndim == 1:
            # reshape is necessary to preserve the data contiguity against vs
            # [:, np.newaxis] that does not.
            y = np.reshape(y, (-1, 1))

        # Provided by the sklearn forest base class; returns per-sample class
        # weights when class_weight is set (None for regressors).
        y, expanded_class_weight = self._validate_y_class_weight(y)

        self.n_features_in_ = X.shape[1]

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight
        # onedal fit expects weights wrapped in a list — TODO confirm against
        # the onedal estimator signature.
        if sample_weight is not None:
            sample_weight = [sample_weight]

        onedal_params = {
            "n_estimators": self.n_estimators,
            "criterion": self.criterion,
            "max_depth": self.max_depth,
            "min_samples_split": self.min_samples_split,
            "min_samples_leaf": self.min_samples_leaf,
            "min_weight_fraction_leaf": self.min_weight_fraction_leaf,
            "max_features": self.max_features,
            "max_leaf_nodes": self.max_leaf_nodes,
            "min_impurity_decrease": self.min_impurity_decrease,
            "bootstrap": self.bootstrap,
            "oob_score": self.oob_score,
            "n_jobs": self.n_jobs,
            "random_state": self.random_state,
            "verbose": self.verbose,
            "warm_start": self.warm_start,
            # _err is a subclass constant naming the oneDAL OOB error metric.
            "error_metric_mode": self._err if self.oob_score else "none",
            "variable_importance_mode": "mdi",
            "class_weight": self.class_weight,
            "max_bins": self.max_bins,
            "min_bin_size": self.min_bin_size,
            "max_samples": self.max_samples,
        }

        # min_impurity_split was removed from sklearn estimators in 1.0.
        if not sklearn_check_version("1.0"):
            onedal_params["min_impurity_split"] = self.min_impurity_split
        else:
            onedal_params["min_impurity_split"] = None

        # Lazy evaluation of estimators_: per-tree sklearn estimators are only
        # materialized on first access (see the estimators_ property).
        self._cached_estimators_ = None

        # Compute
        self._onedal_estimator = self._onedal_factory(**onedal_params)
        self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue)

        self._save_attributes()

        # Decapsulate classes_ attributes: for single-output classifiers,
        # unwrap the per-output lists into scalars, matching sklearn.
        if hasattr(self, "classes_") and self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        return self

    def _fit_proba(self, X, y, sample_weight=None, queue=None):
        # NOTE(review): the clone built here is discarded and the dict returned
        # by get_config() is mutated without being re-applied — as visible in
        # this view the method has no lasting effect; confirm against the full
        # file whether the body continues elsewhere.
        params = self.get_params()
        self.__class__(**params)

        # We use stock metaestimators below, so the only way
        # to pass a queue is using config_context.
        cfg = get_config()
        cfg["target_offload"] = queue

    def _save_attributes(self):
        """Copy OOB results from the fitted onedal estimator onto self."""
        if self.oob_score:
            self.oob_score_ = self._onedal_estimator.oob_score_
            if hasattr(self._onedal_estimator, "oob_prediction_"):
                self.oob_prediction_ = self._onedal_estimator.oob_prediction_
            if hasattr(self._onedal_estimator, "oob_decision_function_"):
                self.oob_decision_function_ = (
                    self._onedal_estimator.oob_decision_function_
                )

        self._validate_estimator()
        return self

    # TODO:
    # move to onedal module.
    def _check_parameters(self):
        """Validate hyperparameters for sklearn versions without _validate_params.

        Raises ValueError on out-of-range values; mirrors stock sklearn checks
        and adds the sklearnex-only max_bins / min_bin_size parameters.
        """
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        else:  # float
            if not 0.0 < self.min_samples_leaf <= 0.5:
                raise ValueError(
                    "min_samples_leaf must be at least 1 "
                    "or in (0, 0.5], got %s" % self.min_samples_leaf
                )
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the integer %s" % self.min_samples_split
                )
        else:  # float
            if not 0.0 < self.min_samples_split <= 1.0:
                raise ValueError(
                    "min_samples_split must be an integer "
                    "greater than 1 or a float in (0.0, 1.0]; "
                    "got the float %s" % self.min_samples_split
                )
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must in [0, 0.5]")
        # min_impurity_split only exists as an attribute on old sklearn.
        if hasattr(self, "min_impurity_split"):
            warnings.warn(
                "The min_impurity_split parameter is deprecated. "
                "Its default value has changed from 1e-7 to 0 in "
                "version 0.23, and it will be removed in 0.25. "
                "Use the min_impurity_decrease parameter instead.",
                FutureWarning,
            )

            if getattr(self, "min_impurity_split") < 0.0:
                raise ValueError(
                    "min_impurity_split must be greater than " "or equal to 0"
                )
        if self.min_impurity_decrease < 0.0:
            raise ValueError(
                "min_impurity_decrease must be greater than " "or equal to 0"
            )
        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" % self.max_leaf_nodes
                )
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None " "or larger than 1").format(
                        self.max_leaf_nodes
                    )
                )
        if isinstance(self.max_bins, numbers.Integral):
            if not 2 <= self.max_bins:
                raise ValueError("max_bins must be at least 2, got %s" % self.max_bins)
        else:
            raise ValueError(
                "max_bins must be integral number but was " "%r" % self.max_bins
            )
        if isinstance(self.min_bin_size, numbers.Integral):
            if not 1 <= self.min_bin_size:
                raise ValueError(
                    "min_bin_size must be at least 1, got %s" % self.min_bin_size
                )
        else:
            raise ValueError(
                "min_bin_size must be integral number but was " "%r" % self.min_bin_size
            )

    def check_sample_weight(self, sample_weight, X, dtype=None):
        """Normalize sample_weight to a 1D float array of length n_samples.

        Accepts None (all-ones), a scalar (broadcast), or an array-like
        (validated). Unlike sklearn's _check_sample_weight, non-finite values
        are allowed here (force_all_finite=False).
        """
        n_samples = _num_samples(X)

        if dtype is not None and dtype not in [np.float32, np.float64]:
            dtype = np.float64

        if sample_weight is None:
            sample_weight = np.ones(n_samples, dtype=dtype)
        elif isinstance(sample_weight, numbers.Number):
            sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
        else:
            if dtype is None:
                dtype = [np.float64, np.float32]
            sample_weight = check_array(
                sample_weight,
                accept_sparse=False,
                ensure_2d=False,
                dtype=dtype,
                order="C",
                force_all_finite=False,
            )
            if sample_weight.ndim != 1:
                raise ValueError("Sample weights must be 1D array or scalar")

            if sample_weight.shape != (n_samples,):
                raise ValueError(
                    "sample_weight.shape == {}, expected {}!".format(
                        sample_weight.shape, (n_samples,)
                    )
                )
        return sample_weight

    @property
    def estimators_(self):
        # Lazily build the per-tree sklearn estimators from the oneDAL model
        # on first access; raise AttributeError (as sklearn would) when unfit.
        if hasattr(self, "_cached_estimators_"):
            if self._cached_estimators_ is None:
                self._estimators_()
            return self._cached_estimators_
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'estimators_'"
            )

    @estimators_.setter
    def estimators_(self, estimators):
        # Needed to allow for proper sklearn operation in fallback mode
        self._cached_estimators_ = estimators

    def _estimators_(self):
        """Reconstruct sklearn per-tree estimators from the fitted oneDAL model.

        Populates ``self._cached_estimators_``; each tree's state is pulled via
        ``self._get_tree_state`` and installed into an sklearn ``Tree``.
        """
        # _estimators_ should only be called if _onedal_estimator exists
        check_is_fitted(self, "_onedal_estimator")
        if hasattr(self, "n_classes_"):
            n_classes_ = (
                self.n_classes_
                if isinstance(self.n_classes_, int)
                else self.n_classes_[0]
            )
        else:
            # Regressors carry no n_classes_; Tree still needs a class count.
            n_classes_ = 1

        # convert model to estimators
        params = {
            "criterion": self._onedal_estimator.criterion,
            "max_depth": self._onedal_estimator.max_depth,
            "min_samples_split": self._onedal_estimator.min_samples_split,
            "min_samples_leaf": self._onedal_estimator.min_samples_leaf,
            "min_weight_fraction_leaf": self._onedal_estimator.min_weight_fraction_leaf,
            "max_features": self._onedal_estimator.max_features,
            "max_leaf_nodes": self._onedal_estimator.max_leaf_nodes,
            "min_impurity_decrease": self._onedal_estimator.min_impurity_decrease,
            "random_state": None,
        }
        if not sklearn_check_version("1.0"):
            params["min_impurity_split"] = self._onedal_estimator.min_impurity_split
        est = self.estimator.__class__(**params)
        # we need to set est.tree_ field with Trees constructed from Intel(R)
        # oneAPI Data Analytics Library solution
        estimators_ = []

        random_state_checked = check_random_state(self.random_state)

        for i in range(self._onedal_estimator.n_estimators):
            est_i = clone(est)
            # Give each clone its own derived random_state, as sklearn does.
            est_i.set_params(
                random_state=random_state_checked.randint(np.iinfo(np.int32).max)
            )
            if sklearn_check_version("1.0"):
                est_i.n_features_in_ = self.n_features_in_
            else:
                est_i.n_features_ = self.n_features_in_
            est_i.n_outputs_ = self.n_outputs_
            est_i.n_classes_ = n_classes_
            tree_i_state_class = self._get_tree_state(
                self._onedal_estimator._onedal_model, i, n_classes_
            )
            tree_i_state_dict = {
                "max_depth": tree_i_state_class.max_depth,
                "node_count": tree_i_state_class.node_count,
                "nodes": check_tree_nodes(tree_i_state_class.node_ar),
                "values": tree_i_state_class.value_ar,
            }
            est_i.tree_ = Tree(
                self.n_features_in_,
                np.array([n_classes_], dtype=np.intp),
                self.n_outputs_,
            )
            est_i.tree_.__setstate__(tree_i_state_dict)
            estimators_.append(est_i)

        self._cached_estimators_ = estimators_

    if sklearn_check_version("1.0"):

        @deprecated(
            "Attribute `n_features_` was deprecated in version 1.0 and will be "
            "removed in 1.2. Use `n_features_in_` instead."
        )
        @property
        def n_features_(self):
            # Deprecated alias kept for sklearn 1.0/1.1 compatibility.
            return self.n_features_in_

    if not sklearn_check_version("1.2"):

        @property
        def base_estimator(self):
            # Pre-1.2 sklearn name for the template estimator.
            return self.estimator

        @base_estimator.setter
        def base_estimator(self, estimator):
            self.estimator = estimator
403
+
404
+
405
+ class ForestClassifier(sklearn_ForestClassifier, BaseForest):
406
+ # Surprisingly, even though scikit-learn warns against using
407
+ # their ForestClassifier directly, it actually has a more stable
408
+ # API than the user-facing objects (over time). If they change it
409
+ # significantly at some point then this may need to be versioned.
410
+
411
+ _err = "out_of_bag_error_accuracy|out_of_bag_error_decision_function"
412
+ _get_tree_state = staticmethod(get_tree_state_cls)
413
+
414
+ def __init__(
415
+ self,
416
+ estimator,
417
+ n_estimators=100,
418
+ *,
419
+ estimator_params=tuple(),
420
+ bootstrap=False,
421
+ oob_score=False,
422
+ n_jobs=None,
423
+ random_state=None,
424
+ verbose=0,
425
+ warm_start=False,
426
+ class_weight=None,
427
+ max_samples=None,
428
+ ):
429
+ super().__init__(
430
+ estimator,
431
+ n_estimators=n_estimators,
432
+ estimator_params=estimator_params,
433
+ bootstrap=bootstrap,
434
+ oob_score=oob_score,
435
+ n_jobs=n_jobs,
436
+ random_state=random_state,
437
+ verbose=verbose,
438
+ warm_start=warm_start,
439
+ class_weight=class_weight,
440
+ max_samples=max_samples,
441
+ )
442
+
443
+ # The estimator is checked against the class attribute for conformance.
444
+ # This should only trigger if the user uses this class directly.
445
+ if (
446
+ self.estimator.__class__ == DecisionTreeClassifier
447
+ and self._onedal_factory != onedal_RandomForestClassifier
448
+ ):
449
+ self._onedal_factory = onedal_RandomForestClassifier
450
+ elif (
451
+ self.estimator.__class__ == ExtraTreeClassifier
452
+ and self._onedal_factory != onedal_ExtraTreesClassifier
453
+ ):
454
+ self._onedal_factory = onedal_ExtraTreesClassifier
455
+
456
+ if self._onedal_factory is None:
457
+ raise TypeError(f" oneDAL estimator has not been set.")
458
+
459
+ def _estimators_(self):
460
+ super()._estimators_()
461
+ classes_ = self.classes_[0]
462
+ for est in self._cached_estimators_:
463
+ est.classes_ = classes_
464
+
465
    def fit(self, X, y, sample_weight=None):
        """Fit the forest, dispatching to oneDAL or stock sklearn.

        The dispatcher picks ``_onedal_fit`` when the patching conditions and
        the configured device/queue allow it, otherwise falls back to
        ``sklearn_ForestClassifier.fit``. Returns self.
        """
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_ForestClassifier.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self
478
+
479
    def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
        """Validate fit inputs/parameters and accumulate patching conditions.

        Runs sklearn parameter validation, records every reason the oneDAL
        backend could not be used on ``patching_status``, and (when patching
        is still possible) validates/reshapes ``X`` and ``y``.

        Returns the (possibly updated) ``patching_status, X, y,
        sample_weight`` tuple.

        Raises
        ------
        ValueError
            For sparse ``y``, ``oob_score`` without bootstrap, or
            ``max_samples`` without bootstrap.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        # Each tuple is (condition-that-must-hold, reason-if-it-does-not).
        patching_status.and_conditions(
            [
                (
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion == "gini",
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'gini' criterion is supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    self.n_estimators <= 6024,
                    "More than 6024 estimators is not supported.",
                ),
            ]
        )

        if self.bootstrap:
            patching_status.and_conditions(
                [
                    (
                        self.class_weight != "balanced_subsample",
                        "'balanced_subsample' for class_weight is not supported",
                    )
                ]
            )

        # sklearn >= 1.4 adds monotonic_cst; also pre-screen for non-finite
        # values since oneDAL cannot handle them.
        if patching_status.get_status() and sklearn_check_version("1.4"):
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            patching_status.and_conditions(
                [
                    (input_is_finite, "Non-finite input is not supported."),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if patching_status.get_status():
            if sklearn_check_version("0.24"):
                X, y = self._validate_data(
                    X,
                    y,
                    multi_output=True,
                    accept_sparse=True,
                    dtype=[np.float64, np.float32],
                    force_all_finite=False,
                )
            else:
                X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
                y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            if y.ndim == 1:
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]

            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    ),
                    (
                        y.dtype in [np.float32, np.float64, np.int32, np.int64],
                        f"Datatype ({y.dtype}) for y is not supported.",
                    ),
                ]
            )
            # TODO: Fix to support integers as input

            # sklearn helper used purely for its validation of max_samples.
            _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

            if not self.bootstrap and self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )

        # Older oneDAL ignores random_state; warn so results are explainable.
        if (
            patching_status.get_status()
            and (self.random_state is not None)
            and (not daal_check_version((2024, "P", 0)))
        ):
            warnings.warn(
                "Setting 'random_state' value is not supported. "
                "State set by oneDAL to default value (777).",
                RuntimeWarning,
            )

        return patching_status, X, y, sample_weight
+
607
    @wrap_output_data
    def predict(self, X):
        """Predict class labels for X, using oneDAL when supported."""
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_ForestClassifier.predict,
            },
            X,
        )
619
    @wrap_output_data
    def predict_proba(self, X):
        """Predict class probabilities for X, using oneDAL when supported.

        Performs an explicit feature-count check against ``n_features_in_``
        before dispatching.
        """
        # TODO:
        # _check_proba()
        # self._check_proba()
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        if hasattr(self, "n_features_in_"):
            try:
                num_features = _num_features(X)
            except TypeError:
                # _num_features rejects some input kinds with TypeError;
                # fall back to the sample count in that case.
                num_features = _num_samples(X)
            if num_features != self.n_features_in_:
                raise ValueError(
                    (
                        f"X has {num_features} features, "
                        f"but {self.__class__.__name__} is expecting "
                        f"{self.n_features_in_} features as input"
                    )
                )
        return dispatch(
            self,
            "predict_proba",
            {
                "onedal": self.__class__._onedal_predict_proba,
                "sklearn": sklearn_ForestClassifier.predict_proba,
            },
            X,
        )
649
+ fit.__doc__ = sklearn_ForestClassifier.fit.__doc__
650
+ predict.__doc__ = sklearn_ForestClassifier.predict.__doc__
651
+ predict_proba.__doc__ = sklearn_ForestClassifier.predict_proba.__doc__
652
+
653
    def _onedal_cpu_supported(self, method_name, *data):
        """Build the CPU patching-conditions chain for ``method_name``.

        Returns a ``PatchingConditionsChain`` whose status tells the
        dispatcher whether the oneDAL CPU implementation may be used.

        Raises
        ------
        RuntimeError
            If ``method_name`` is not ``fit``/``predict``/``predict_proba``.
        """
        class_name = self.__class__.__name__
        patching_status = PatchingConditionsChain(
            f"sklearn.ensemble.{class_name}.{method_name}"
        )

        if method_name == "fit":
            patching_status, X, y, sample_weight = self._onedal_fit_ready(
                patching_status, *data
            )

            patching_status.and_conditions(
                [
                    (
                        daal_check_version((2023, "P", 200))
                        or self.estimator.__class__ == DecisionTreeClassifier,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                    (
                        not sp.issparse(sample_weight),
                        "sample_weight is sparse. " "Sparse input is not supported.",
                    ),
                ]
            )

        elif method_name in ["predict", "predict_proba"]:
            X = data[0]

            patching_status.and_conditions(
                [
                    (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
                    (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                    (self.warm_start is False, "Warm start is not supported."),
                    (
                        # NOTE(review): this gate checks 2023.1 while the
                        # message and the fit branch above say 2023.2 —
                        # confirm which version is intended.
                        daal_check_version((2023, "P", 100))
                        or self.estimator.__class__ == DecisionTreeClassifier,
                        "ExtraTrees only supported starting from oneDAL version 2023.2",
                    ),
                ]
            )

            if method_name == "predict_proba":
                patching_status.and_conditions(
                    [
                        (
                            daal_check_version((2021, "P", 400)),
                            "oneDAL version is lower than 2021.4.",
                        )
                    ]
                )

            if hasattr(self, "n_outputs_"):
                patching_status.and_conditions(
                    [
                        (
                            self.n_outputs_ == 1,
                            f"Number of outputs ({self.n_outputs_}) is not 1.",
                        ),
                    ]
                )

        else:
            raise RuntimeError(
                f"Unknown method {method_name} in {self.__class__.__name__}"
            )

        return patching_status
+
721
+ def _onedal_gpu_supported(self, method_name, *data):
722
+ class_name = self.__class__.__name__
723
+ patching_status = PatchingConditionsChain(
724
+ f"sklearn.ensemble.{class_name}.{method_name}"
725
+ )
726
+
727
+ if method_name == "fit":
728
+ patching_status, X, y, sample_weight = self._onedal_fit_ready(
729
+ patching_status, *data
730
+ )
731
+
732
+ patching_status.and_conditions(
733
+ [
734
+ (
735
+ daal_check_version((2023, "P", 100))
736
+ or self.estimator.__class__ == DecisionTreeClassifier,
737
+ "ExtraTrees only supported starting from oneDAL version 2023.1",
738
+ ),
739
+ (sample_weight is not None, "sample_weight is not supported."),
740
+ ]
741
+ )
742
+
743
+ elif method_name in ["predict", "predict_proba"]:
744
+ X = data[0]
745
+
746
+ patching_status.and_conditions(
747
+ [
748
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained"),
749
+ (
750
+ not sp.issparse(X),
751
+ "X is sparse. Sparse input is not supported.",
752
+ ),
753
+ (self.warm_start is False, "Warm start is not supported."),
754
+ (
755
+ daal_check_version((2023, "P", 100)),
756
+ "ExtraTrees supported starting from oneDAL version 2023.1",
757
+ ),
758
+ ]
759
+ )
760
+ if hasattr(self, "n_outputs_"):
761
+ patching_status.and_conditions(
762
+ [
763
+ (
764
+ self.n_outputs_ == 1,
765
+ f"Number of outputs ({self.n_outputs_}) is not 1.",
766
+ ),
767
+ ]
768
+ )
769
+
770
+ else:
771
+ raise RuntimeError(
772
+ f"Unknown method {method_name} in {self.__class__.__name__}"
773
+ )
774
+
775
+ return patching_status
776
+
777
+ def _onedal_predict(self, X, queue=None):
778
+ X = check_array(
779
+ X,
780
+ dtype=[np.float64, np.float32],
781
+ force_all_finite=False,
782
+ ) # Warning, order of dtype matters
783
+ check_is_fitted(self, "_onedal_estimator")
784
+
785
+ if sklearn_check_version("1.0"):
786
+ self._check_feature_names(X, reset=False)
787
+
788
+ res = self._onedal_estimator.predict(X, queue=queue)
789
+ return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))
790
+
791
    def _onedal_predict_proba(self, X, queue=None):
        """Compute class probabilities with the fitted oneDAL model."""
        X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
        check_is_fitted(self, "_onedal_estimator")

        if sklearn_check_version("0.23"):
            self._check_n_features(X, reset=False)
        if sklearn_check_version("1.0"):
            self._check_feature_names(X, reset=False)
        return self._onedal_estimator.predict_proba(X, queue=queue)
+
801
+
802
+ class ForestRegressor(sklearn_ForestRegressor, BaseForest):
803
+ _err = "out_of_bag_error_r2|out_of_bag_error_prediction"
804
+ _get_tree_state = staticmethod(get_tree_state_reg)
805
+
806
+ def __init__(
807
+ self,
808
+ estimator,
809
+ n_estimators=100,
810
+ *,
811
+ estimator_params=tuple(),
812
+ bootstrap=False,
813
+ oob_score=False,
814
+ n_jobs=None,
815
+ random_state=None,
816
+ verbose=0,
817
+ warm_start=False,
818
+ max_samples=None,
819
+ ):
820
+ super().__init__(
821
+ estimator,
822
+ n_estimators=n_estimators,
823
+ estimator_params=estimator_params,
824
+ bootstrap=bootstrap,
825
+ oob_score=oob_score,
826
+ n_jobs=n_jobs,
827
+ random_state=random_state,
828
+ verbose=verbose,
829
+ warm_start=warm_start,
830
+ max_samples=max_samples,
831
+ )
832
+
833
+ # The splitter is checked against the class attribute for conformance
834
+ # This should only trigger if the user uses this class directly.
835
+ if (
836
+ self.estimator.__class__ == DecisionTreeRegressor
837
+ and self._onedal_factory != onedal_RandomForestRegressor
838
+ ):
839
+ self._onedal_factory = onedal_RandomForestRegressor
840
+ elif (
841
+ self.estimator.__class__ == ExtraTreeRegressor
842
+ and self._onedal_factory != onedal_ExtraTreesRegressor
843
+ ):
844
+ self._onedal_factory = onedal_ExtraTreesRegressor
845
+
846
+ if self._onedal_factory is None:
847
+ raise TypeError(f" oneDAL estimator has not been set.")
848
+
849
    def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
        """Validate fit inputs/parameters and accumulate patching conditions.

        Regressor counterpart of the classifier's ``_onedal_fit_ready``:
        records every reason the oneDAL backend could not be used on
        ``patching_status`` and, when patching is still possible, validates
        and reshapes ``X`` and ``y``.

        Returns the (possibly updated) ``patching_status, X, y,
        sample_weight`` tuple.

        Raises
        ------
        ValueError
            For sparse ``y``, ``oob_score`` without bootstrap, or
            ``max_samples`` without bootstrap.
        """
        if sp.issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        if sklearn_check_version("1.2"):
            self._validate_params()
        else:
            self._check_parameters()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available" " if bootstrap=True")

        # Mirror sklearn's own deprecation warning for the old criterion name.
        if sklearn_check_version("1.0") and self.criterion == "mse":
            warnings.warn(
                "Criterion 'mse' was deprecated in v1.0 and will be "
                "removed in version 1.2. Use `criterion='squared_error'` "
                "which is equivalent.",
                FutureWarning,
            )

        # Each tuple is (condition-that-must-hold, reason-if-it-does-not).
        patching_status.and_conditions(
            [
                (
                    self.oob_score
                    and daal_check_version((2021, "P", 500))
                    or not self.oob_score,
                    "OOB score is only supported starting from 2021.5 version of oneDAL.",
                ),
                (self.warm_start is False, "Warm start is not supported."),
                (
                    self.criterion in ["mse", "squared_error"],
                    f"'{self.criterion}' criterion is not supported. "
                    "Only 'mse' and 'squared_error' criteria are supported.",
                ),
                (
                    self.ccp_alpha == 0.0,
                    f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported.",
                ),
                (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
                (
                    self.n_estimators <= 6024,
                    "More than 6024 estimators is not supported.",
                ),
            ]
        )

        if patching_status.get_status() and sklearn_check_version("1.4"):
            try:
                _assert_all_finite(X)
                input_is_finite = True
            except ValueError:
                input_is_finite = False
            patching_status.and_conditions(
                [
                    (input_is_finite, "Non-finite input is not supported."),
                    (
                        self.monotonic_cst is None,
                        "Monotonicity constraints are not supported.",
                    ),
                ]
            )

        if patching_status.get_status():
            if sklearn_check_version("0.24"):
                X, y = self._validate_data(
                    X,
                    y,
                    multi_output=True,
                    accept_sparse=True,
                    dtype=[np.float64, np.float32],
                    force_all_finite=False,
                )
            else:
                X = check_array(X, dtype=[np.float64, np.float32], force_all_finite=False)
                y = check_array(y, ensure_2d=False, dtype=X.dtype, force_all_finite=False)

            if y.ndim == 2 and y.shape[1] == 1:
                warnings.warn(
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel().",
                    DataConversionWarning,
                    stacklevel=2,
                )

            if y.ndim == 1:
                # reshape is necessary to preserve the data contiguity against vs
                # [:, np.newaxis] that does not.
                y = np.reshape(y, (-1, 1))

            self.n_outputs_ = y.shape[1]

            patching_status.and_conditions(
                [
                    (
                        self.n_outputs_ == 1,
                        f"Number of outputs ({self.n_outputs_}) is not 1.",
                    )
                ]
            )

            # Sklearn function used for doing checks on max_samples attribute
            _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)

            if not self.bootstrap and self.max_samples is not None:
                raise ValueError(
                    "`max_sample` cannot be set if `bootstrap=False`. "
                    "Either switch to `bootstrap=True` or set "
                    "`max_sample=None`."
                )

        # Older oneDAL ignores random_state; warn so results are explainable.
        if (
            patching_status.get_status()
            and (self.random_state is not None)
            and (not daal_check_version((2024, "P", 0)))
        ):
            warnings.warn(
                "Setting 'random_state' value is not supported. "
                "State set by oneDAL to default value (777).",
                RuntimeWarning,
            )

        return patching_status, X, y, sample_weight
973
+ def _onedal_cpu_supported(self, method_name, *data):
974
+ class_name = self.__class__.__name__
975
+ patching_status = PatchingConditionsChain(
976
+ f"sklearn.ensemble.{class_name}.{method_name}"
977
+ )
978
+
979
+ if method_name == "fit":
980
+ patching_status, X, y, sample_weight = self._onedal_fit_ready(
981
+ patching_status, *data
982
+ )
983
+
984
+ patching_status.and_conditions(
985
+ [
986
+ (
987
+ daal_check_version((2023, "P", 200))
988
+ or self.estimator.__class__ == DecisionTreeClassifier,
989
+ "ExtraTrees only supported starting from oneDAL version 2023.2",
990
+ ),
991
+ (
992
+ not sp.issparse(sample_weight),
993
+ "sample_weight is sparse. " "Sparse input is not supported.",
994
+ ),
995
+ ]
996
+ )
997
+
998
+ elif method_name == "predict":
999
+ X = data[0]
1000
+
1001
+ patching_status.and_conditions(
1002
+ [
1003
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
1004
+ (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
1005
+ (self.warm_start is False, "Warm start is not supported."),
1006
+ (
1007
+ daal_check_version((2023, "P", 200))
1008
+ or self.estimator.__class__ == DecisionTreeClassifier,
1009
+ "ExtraTrees only supported starting from oneDAL version 2023.2",
1010
+ ),
1011
+ ]
1012
+ )
1013
+ if hasattr(self, "n_outputs_"):
1014
+ patching_status.and_conditions(
1015
+ [
1016
+ (
1017
+ self.n_outputs_ == 1,
1018
+ f"Number of outputs ({self.n_outputs_}) is not 1.",
1019
+ ),
1020
+ ]
1021
+ )
1022
+
1023
+ else:
1024
+ raise RuntimeError(
1025
+ f"Unknown method {method_name} in {self.__class__.__name__}"
1026
+ )
1027
+
1028
+ return patching_status
1029
+
1030
+ def _onedal_gpu_supported(self, method_name, *data):
1031
+ class_name = self.__class__.__name__
1032
+ patching_status = PatchingConditionsChain(
1033
+ f"sklearn.ensemble.{class_name}.{method_name}"
1034
+ )
1035
+
1036
+ if method_name == "fit":
1037
+ patching_status, X, y, sample_weight = self._onedal_fit_ready(
1038
+ patching_status, *data
1039
+ )
1040
+
1041
+ patching_status.and_conditions(
1042
+ [
1043
+ (
1044
+ daal_check_version((2023, "P", 100))
1045
+ or self.estimator.__class__ == DecisionTreeClassifier,
1046
+ "ExtraTrees only supported starting from oneDAL version 2023.1",
1047
+ ),
1048
+ (sample_weight is not None, "sample_weight is not supported."),
1049
+ ]
1050
+ )
1051
+
1052
+ elif method_name == "predict":
1053
+ X = data[0]
1054
+
1055
+ patching_status.and_conditions(
1056
+ [
1057
+ (hasattr(self, "_onedal_estimator"), "oneDAL model was not trained."),
1058
+ (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
1059
+ (self.warm_start is False, "Warm start is not supported."),
1060
+ (
1061
+ daal_check_version((2023, "P", 100))
1062
+ or self.estimator.__class__ == DecisionTreeClassifier,
1063
+ "ExtraTrees only supported starting from oneDAL version 2023.1",
1064
+ ),
1065
+ ]
1066
+ )
1067
+ if hasattr(self, "n_outputs_"):
1068
+ patching_status.and_conditions(
1069
+ [
1070
+ (
1071
+ self.n_outputs_ == 1,
1072
+ f"Number of outputs ({self.n_outputs_}) is not 1.",
1073
+ ),
1074
+ ]
1075
+ )
1076
+
1077
+ else:
1078
+ raise RuntimeError(
1079
+ f"Unknown method {method_name} in {self.__class__.__name__}"
1080
+ )
1081
+
1082
+ return patching_status
1083
+
1084
+ def _onedal_predict(self, X, queue=None):
1085
+ X = check_array(
1086
+ X, dtype=[np.float64, np.float32], force_all_finite=False
1087
+ ) # Warning, order of dtype matters
1088
+ check_is_fitted(self, "_onedal_estimator")
1089
+
1090
+ if sklearn_check_version("1.0"):
1091
+ self._check_feature_names(X, reset=False)
1092
+
1093
+ return self._onedal_estimator.predict(X, queue=queue)
1094
+
1095
    def fit(self, X, y, sample_weight=None):
        """Fit the forest regressor, dispatching to oneDAL when the patching
        conditions hold and falling back to stock sklearn otherwise."""
        dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,
                "sklearn": sklearn_ForestRegressor.fit,
            },
            X,
            y,
            sample_weight,
        )
        return self
1109
    @wrap_output_data
    def predict(self, X):
        """Predict regression targets for X, using oneDAL when supported."""
        return dispatch(
            self,
            "predict",
            {
                "onedal": self.__class__._onedal_predict,
                "sklearn": sklearn_ForestRegressor.predict,
            },
            X,
        )
1121
+ fit.__doc__ = sklearn_ForestRegressor.fit.__doc__
1122
+ predict.__doc__ = sklearn_ForestRegressor.predict.__doc__
1123
+
1124
+
1125
+ class RandomForestClassifier(ForestClassifier):
1126
+ __doc__ = sklearn_RandomForestClassifier.__doc__
1127
+ _onedal_factory = onedal_RandomForestClassifier
1128
+
1129
+ if sklearn_check_version("1.2"):
1130
+ _parameter_constraints: dict = {
1131
+ **sklearn_RandomForestClassifier._parameter_constraints,
1132
+ "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
1133
+ "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
1134
+ }
1135
+
1136
+ if sklearn_check_version("1.4"):
1137
+
1138
+ def __init__(
1139
+ self,
1140
+ n_estimators=100,
1141
+ *,
1142
+ criterion="gini",
1143
+ max_depth=None,
1144
+ min_samples_split=2,
1145
+ min_samples_leaf=1,
1146
+ min_weight_fraction_leaf=0.0,
1147
+ max_features="sqrt",
1148
+ max_leaf_nodes=None,
1149
+ min_impurity_decrease=0.0,
1150
+ bootstrap=True,
1151
+ oob_score=False,
1152
+ n_jobs=None,
1153
+ random_state=None,
1154
+ verbose=0,
1155
+ warm_start=False,
1156
+ class_weight=None,
1157
+ ccp_alpha=0.0,
1158
+ max_samples=None,
1159
+ monotonic_cst=None,
1160
+ max_bins=256,
1161
+ min_bin_size=1,
1162
+ ):
1163
+ super().__init__(
1164
+ DecisionTreeClassifier(),
1165
+ n_estimators,
1166
+ estimator_params=(
1167
+ "criterion",
1168
+ "max_depth",
1169
+ "min_samples_split",
1170
+ "min_samples_leaf",
1171
+ "min_weight_fraction_leaf",
1172
+ "max_features",
1173
+ "max_leaf_nodes",
1174
+ "min_impurity_decrease",
1175
+ "random_state",
1176
+ "ccp_alpha",
1177
+ "monotonic_cst",
1178
+ ),
1179
+ bootstrap=bootstrap,
1180
+ oob_score=oob_score,
1181
+ n_jobs=n_jobs,
1182
+ random_state=random_state,
1183
+ verbose=verbose,
1184
+ warm_start=warm_start,
1185
+ class_weight=class_weight,
1186
+ max_samples=max_samples,
1187
+ )
1188
+
1189
+ self.criterion = criterion
1190
+ self.max_depth = max_depth
1191
+ self.min_samples_split = min_samples_split
1192
+ self.min_samples_leaf = min_samples_leaf
1193
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1194
+ self.max_features = max_features
1195
+ self.max_leaf_nodes = max_leaf_nodes
1196
+ self.min_impurity_decrease = min_impurity_decrease
1197
+ self.ccp_alpha = ccp_alpha
1198
+ self.max_bins = max_bins
1199
+ self.min_bin_size = min_bin_size
1200
+ self.monotonic_cst = monotonic_cst
1201
+
1202
+ elif sklearn_check_version("1.0"):
1203
+
1204
+ def __init__(
1205
+ self,
1206
+ n_estimators=100,
1207
+ *,
1208
+ criterion="gini",
1209
+ max_depth=None,
1210
+ min_samples_split=2,
1211
+ min_samples_leaf=1,
1212
+ min_weight_fraction_leaf=0.0,
1213
+ max_features="sqrt" if sklearn_check_version("1.1") else "auto",
1214
+ max_leaf_nodes=None,
1215
+ min_impurity_decrease=0.0,
1216
+ bootstrap=True,
1217
+ oob_score=False,
1218
+ n_jobs=None,
1219
+ random_state=None,
1220
+ verbose=0,
1221
+ warm_start=False,
1222
+ class_weight=None,
1223
+ ccp_alpha=0.0,
1224
+ max_samples=None,
1225
+ max_bins=256,
1226
+ min_bin_size=1,
1227
+ ):
1228
+ super().__init__(
1229
+ DecisionTreeClassifier(),
1230
+ n_estimators,
1231
+ estimator_params=(
1232
+ "criterion",
1233
+ "max_depth",
1234
+ "min_samples_split",
1235
+ "min_samples_leaf",
1236
+ "min_weight_fraction_leaf",
1237
+ "max_features",
1238
+ "max_leaf_nodes",
1239
+ "min_impurity_decrease",
1240
+ "random_state",
1241
+ "ccp_alpha",
1242
+ ),
1243
+ bootstrap=bootstrap,
1244
+ oob_score=oob_score,
1245
+ n_jobs=n_jobs,
1246
+ random_state=random_state,
1247
+ verbose=verbose,
1248
+ warm_start=warm_start,
1249
+ class_weight=class_weight,
1250
+ max_samples=max_samples,
1251
+ )
1252
+
1253
+ self.criterion = criterion
1254
+ self.max_depth = max_depth
1255
+ self.min_samples_split = min_samples_split
1256
+ self.min_samples_leaf = min_samples_leaf
1257
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1258
+ self.max_features = max_features
1259
+ self.max_leaf_nodes = max_leaf_nodes
1260
+ self.min_impurity_decrease = min_impurity_decrease
1261
+ self.ccp_alpha = ccp_alpha
1262
+ self.max_bins = max_bins
1263
+ self.min_bin_size = min_bin_size
1264
+
1265
+ else:
1266
+
1267
+ def __init__(
1268
+ self,
1269
+ n_estimators=100,
1270
+ *,
1271
+ criterion="gini",
1272
+ max_depth=None,
1273
+ min_samples_split=2,
1274
+ min_samples_leaf=1,
1275
+ min_weight_fraction_leaf=0.0,
1276
+ max_features="auto",
1277
+ max_leaf_nodes=None,
1278
+ min_impurity_decrease=0.0,
1279
+ min_impurity_split=None,
1280
+ bootstrap=True,
1281
+ oob_score=False,
1282
+ n_jobs=None,
1283
+ random_state=None,
1284
+ verbose=0,
1285
+ warm_start=False,
1286
+ class_weight=None,
1287
+ ccp_alpha=0.0,
1288
+ max_samples=None,
1289
+ max_bins=256,
1290
+ min_bin_size=1,
1291
+ ):
1292
+ super().__init__(
1293
+ DecisionTreeClassifier(),
1294
+ n_estimators,
1295
+ estimator_params=(
1296
+ "criterion",
1297
+ "max_depth",
1298
+ "min_samples_split",
1299
+ "min_samples_leaf",
1300
+ "min_weight_fraction_leaf",
1301
+ "max_features",
1302
+ "max_leaf_nodes",
1303
+ "min_impurity_decrease",
1304
+ "min_impurity_split",
1305
+ "random_state",
1306
+ "ccp_alpha",
1307
+ ),
1308
+ bootstrap=bootstrap,
1309
+ oob_score=oob_score,
1310
+ n_jobs=n_jobs,
1311
+ random_state=random_state,
1312
+ verbose=verbose,
1313
+ warm_start=warm_start,
1314
+ class_weight=class_weight,
1315
+ max_samples=max_samples,
1316
+ )
1317
+
1318
+ self.criterion = criterion
1319
+ self.max_depth = max_depth
1320
+ self.min_samples_split = min_samples_split
1321
+ self.min_samples_leaf = min_samples_leaf
1322
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1323
+ self.max_features = max_features
1324
+ self.max_leaf_nodes = max_leaf_nodes
1325
+ self.min_impurity_decrease = min_impurity_decrease
1326
+ self.min_impurity_split = min_impurity_split
1327
+ self.ccp_alpha = ccp_alpha
1328
+ self.max_bins = max_bins
1329
+ self.min_bin_size = min_bin_size
1330
+ self.max_bins = max_bins
1331
+ self.min_bin_size = min_bin_size
1332
+
1333
+
1334
+ class RandomForestRegressor(ForestRegressor):
1335
+ __doc__ = sklearn_RandomForestRegressor.__doc__
1336
+ _onedal_factory = onedal_RandomForestRegressor
1337
+
1338
+ if sklearn_check_version("1.2"):
1339
+ _parameter_constraints: dict = {
1340
+ **sklearn_RandomForestRegressor._parameter_constraints,
1341
+ "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
1342
+ "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
1343
+ }
1344
+
1345
+ if sklearn_check_version("1.4"):
1346
+
1347
+ def __init__(
1348
+ self,
1349
+ n_estimators=100,
1350
+ *,
1351
+ criterion="squared_error",
1352
+ max_depth=None,
1353
+ min_samples_split=2,
1354
+ min_samples_leaf=1,
1355
+ min_weight_fraction_leaf=0.0,
1356
+ max_features=1.0,
1357
+ max_leaf_nodes=None,
1358
+ min_impurity_decrease=0.0,
1359
+ bootstrap=True,
1360
+ oob_score=False,
1361
+ n_jobs=None,
1362
+ random_state=None,
1363
+ verbose=0,
1364
+ warm_start=False,
1365
+ ccp_alpha=0.0,
1366
+ max_samples=None,
1367
+ monotonic_cst=None,
1368
+ max_bins=256,
1369
+ min_bin_size=1,
1370
+ ):
1371
+ super().__init__(
1372
+ DecisionTreeRegressor(),
1373
+ n_estimators=n_estimators,
1374
+ estimator_params=(
1375
+ "criterion",
1376
+ "max_depth",
1377
+ "min_samples_split",
1378
+ "min_samples_leaf",
1379
+ "min_weight_fraction_leaf",
1380
+ "max_features",
1381
+ "max_leaf_nodes",
1382
+ "min_impurity_decrease",
1383
+ "random_state",
1384
+ "ccp_alpha",
1385
+ "monotonic_cst",
1386
+ ),
1387
+ bootstrap=bootstrap,
1388
+ oob_score=oob_score,
1389
+ n_jobs=n_jobs,
1390
+ random_state=random_state,
1391
+ verbose=verbose,
1392
+ warm_start=warm_start,
1393
+ max_samples=max_samples,
1394
+ )
1395
+
1396
+ self.criterion = criterion
1397
+ self.max_depth = max_depth
1398
+ self.min_samples_split = min_samples_split
1399
+ self.min_samples_leaf = min_samples_leaf
1400
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1401
+ self.max_features = max_features
1402
+ self.max_leaf_nodes = max_leaf_nodes
1403
+ self.min_impurity_decrease = min_impurity_decrease
1404
+ self.ccp_alpha = ccp_alpha
1405
+ self.max_bins = max_bins
1406
+ self.min_bin_size = min_bin_size
1407
+ self.monotonic_cst = monotonic_cst
1408
+
1409
+ elif sklearn_check_version("1.0"):
1410
+
1411
+ def __init__(
1412
+ self,
1413
+ n_estimators=100,
1414
+ *,
1415
+ criterion="squared_error",
1416
+ max_depth=None,
1417
+ min_samples_split=2,
1418
+ min_samples_leaf=1,
1419
+ min_weight_fraction_leaf=0.0,
1420
+ max_features=1.0 if sklearn_check_version("1.1") else "auto",
1421
+ max_leaf_nodes=None,
1422
+ min_impurity_decrease=0.0,
1423
+ bootstrap=True,
1424
+ oob_score=False,
1425
+ n_jobs=None,
1426
+ random_state=None,
1427
+ verbose=0,
1428
+ warm_start=False,
1429
+ ccp_alpha=0.0,
1430
+ max_samples=None,
1431
+ max_bins=256,
1432
+ min_bin_size=1,
1433
+ ):
1434
+ super().__init__(
1435
+ DecisionTreeRegressor(),
1436
+ n_estimators=n_estimators,
1437
+ estimator_params=(
1438
+ "criterion",
1439
+ "max_depth",
1440
+ "min_samples_split",
1441
+ "min_samples_leaf",
1442
+ "min_weight_fraction_leaf",
1443
+ "max_features",
1444
+ "max_leaf_nodes",
1445
+ "min_impurity_decrease",
1446
+ "random_state",
1447
+ "ccp_alpha",
1448
+ ),
1449
+ bootstrap=bootstrap,
1450
+ oob_score=oob_score,
1451
+ n_jobs=n_jobs,
1452
+ random_state=random_state,
1453
+ verbose=verbose,
1454
+ warm_start=warm_start,
1455
+ max_samples=max_samples,
1456
+ )
1457
+
1458
+ self.criterion = criterion
1459
+ self.max_depth = max_depth
1460
+ self.min_samples_split = min_samples_split
1461
+ self.min_samples_leaf = min_samples_leaf
1462
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1463
+ self.max_features = max_features
1464
+ self.max_leaf_nodes = max_leaf_nodes
1465
+ self.min_impurity_decrease = min_impurity_decrease
1466
+ self.ccp_alpha = ccp_alpha
1467
+ self.max_bins = max_bins
1468
+ self.min_bin_size = min_bin_size
1469
+
1470
+ else:
1471
+
1472
+ def __init__(
1473
+ self,
1474
+ n_estimators=100,
1475
+ *,
1476
+ criterion="mse",
1477
+ max_depth=None,
1478
+ min_samples_split=2,
1479
+ min_samples_leaf=1,
1480
+ min_weight_fraction_leaf=0.0,
1481
+ max_features="auto",
1482
+ max_leaf_nodes=None,
1483
+ min_impurity_decrease=0.0,
1484
+ min_impurity_split=None,
1485
+ bootstrap=True,
1486
+ oob_score=False,
1487
+ n_jobs=None,
1488
+ random_state=None,
1489
+ verbose=0,
1490
+ warm_start=False,
1491
+ ccp_alpha=0.0,
1492
+ max_samples=None,
1493
+ max_bins=256,
1494
+ min_bin_size=1,
1495
+ ):
1496
+ super().__init__(
1497
+ DecisionTreeRegressor(),
1498
+ n_estimators=n_estimators,
1499
+ estimator_params=(
1500
+ "criterion",
1501
+ "max_depth",
1502
+ "min_samples_split",
1503
+ "min_samples_leaf",
1504
+ "min_weight_fraction_leaf",
1505
+ "max_features",
1506
+ "max_leaf_nodes",
1507
+ "min_impurity_decrease",
1508
+ "min_impurity_split" "random_state",
1509
+ "ccp_alpha",
1510
+ ),
1511
+ bootstrap=bootstrap,
1512
+ oob_score=oob_score,
1513
+ n_jobs=n_jobs,
1514
+ random_state=random_state,
1515
+ verbose=verbose,
1516
+ warm_start=warm_start,
1517
+ max_samples=max_samples,
1518
+ )
1519
+
1520
+ self.criterion = criterion
1521
+ self.max_depth = max_depth
1522
+ self.min_samples_split = min_samples_split
1523
+ self.min_samples_leaf = min_samples_leaf
1524
+ self.min_weight_fraction_leaf = min_weight_fraction_leaf
1525
+ self.max_features = max_features
1526
+ self.max_leaf_nodes = max_leaf_nodes
1527
+ self.min_impurity_decrease = min_impurity_decrease
1528
+ self.min_impurity_split = min_impurity_split
1529
+ self.ccp_alpha = ccp_alpha
1530
+ self.max_bins = max_bins
1531
+ self.min_bin_size = min_bin_size
1532
+
1533
+
1534
class ExtraTreesClassifier(ForestClassifier):
    __doc__ = sklearn_ExtraTreesClassifier.__doc__
    # oneDAL backend used by the accelerated fit/predict paths.
    _onedal_factory = onedal_ExtraTreesClassifier

    if sklearn_check_version("1.2"):
        # Extend the stock sklearn constraints with the oneDAL-only
        # histogram-binning hyperparameters.
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesClassifier._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # sklearn's ExtraTreesClassifier signature changed across releases
    # (monotonic_cst added in 1.4, max_features default changed in 1.1,
    # min_impurity_split removed in 1.0), so the matching __init__ is
    # selected at import time. Each variant forwards the stock parameters
    # to the ForestClassifier base and additionally stores the oneDAL-only
    # max_bins / min_bin_size hyperparameters.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="sqrt",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Store hyperparameters as attributes (sklearn get_params contract).
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # Default changed from "auto" to "sqrt" in sklearn 1.1.
            max_features="sqrt" if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Store hyperparameters as attributes (sklearn get_params contract).
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:
        # sklearn < 1.0 still exposes the deprecated min_impurity_split.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="gini",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            class_weight=None,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeClassifier(),
                n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                class_weight=class_weight,
                max_samples=max_samples,
            )

            # Store hyperparameters as attributes (sklearn get_params contract).
            # The original code assigned max_bins/min_bin_size twice; the
            # redundant duplicate assignments have been removed.
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
+
1742
+
1743
class ExtraTreesRegressor(ForestRegressor):
    __doc__ = sklearn_ExtraTreesRegressor.__doc__
    # oneDAL backend used by the accelerated fit/predict paths.
    _onedal_factory = onedal_ExtraTreesRegressor

    if sklearn_check_version("1.2"):
        # Extend the stock sklearn constraints with the oneDAL-only
        # histogram-binning hyperparameters.
        _parameter_constraints: dict = {
            **sklearn_ExtraTreesRegressor._parameter_constraints,
            "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
            "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")],
        }

    # sklearn's ExtraTreesRegressor signature changed across releases
    # (monotonic_cst added in 1.4, criterion/max_features defaults renamed
    # in 1.0/1.1, min_impurity_split removed in 1.0), so the matching
    # __init__ is selected at import time. Each variant forwards the stock
    # parameters to the ForestRegressor base and additionally stores the
    # oneDAL-only max_bins / min_bin_size hyperparameters.
    if sklearn_check_version("1.4"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features=1.0,
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            monotonic_cst=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                    "monotonic_cst",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store hyperparameters as attributes (sklearn get_params contract).
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
            self.monotonic_cst = monotonic_cst

    elif sklearn_check_version("1.0"):

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="squared_error",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            # Default changed from "auto" to 1.0 in sklearn 1.1.
            max_features=1.0 if sklearn_check_version("1.1") else "auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store hyperparameters as attributes (sklearn get_params contract).
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size

    else:
        # sklearn < 1.0: criterion is still called "mse" and the deprecated
        # min_impurity_split parameter is still exposed.

        def __init__(
            self,
            n_estimators=100,
            *,
            criterion="mse",
            max_depth=None,
            min_samples_split=2,
            min_samples_leaf=1,
            min_weight_fraction_leaf=0.0,
            max_features="auto",
            max_leaf_nodes=None,
            min_impurity_decrease=0.0,
            min_impurity_split=None,
            bootstrap=False,
            oob_score=False,
            n_jobs=None,
            random_state=None,
            verbose=0,
            warm_start=False,
            ccp_alpha=0.0,
            max_samples=None,
            max_bins=256,
            min_bin_size=1,
        ):
            super().__init__(
                ExtraTreeRegressor(),
                n_estimators=n_estimators,
                estimator_params=(
                    "criterion",
                    "max_depth",
                    "min_samples_split",
                    "min_samples_leaf",
                    "min_weight_fraction_leaf",
                    "max_features",
                    "max_leaf_nodes",
                    "min_impurity_decrease",
                    # BUG FIX: the original wrote
                    #   "min_impurity_split" "random_state",
                    # (missing comma), which implicit string concatenation
                    # collapsed into the single bogus parameter name
                    # "min_impurity_splitrandom_state", so neither setting
                    # was forwarded to the per-tree estimators.
                    "min_impurity_split",
                    "random_state",
                    "ccp_alpha",
                ),
                bootstrap=bootstrap,
                oob_score=oob_score,
                n_jobs=n_jobs,
                random_state=random_state,
                verbose=verbose,
                warm_start=warm_start,
                max_samples=max_samples,
            )

            # Store hyperparameters as attributes (sklearn get_params contract).
            self.criterion = criterion
            self.max_depth = max_depth
            self.min_samples_split = min_samples_split
            self.min_samples_leaf = min_samples_leaf
            self.min_weight_fraction_leaf = min_weight_fraction_leaf
            self.max_features = max_features
            self.max_leaf_nodes = max_leaf_nodes
            self.min_impurity_decrease = min_impurity_decrease
            self.min_impurity_split = min_impurity_split
            self.ccp_alpha = ccp_alpha
            self.max_bins = max_bins
            self.min_bin_size = min_bin_size
+
1942
+
1943
# Allow for isinstance calls without inheritance changes using ABCMeta:
# abc's register() makes each sklearnex estimator a *virtual* subclass of
# its stock scikit-learn counterpart, so isinstance()/issubclass() checks
# written against sklearn classes keep succeeding without altering the
# actual inheritance hierarchy.
sklearn_RandomForestClassifier.register(RandomForestClassifier)
sklearn_RandomForestRegressor.register(RandomForestRegressor)
sklearn_ExtraTreesClassifier.register(ExtraTreesClassifier)
sklearn_ExtraTreesRegressor.register(ExtraTreesRegressor)