scikit-learn-intelex 2023.2.1__py38-none-win_amd64.whl → 2024.0.1__py38-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of scikit-learn-intelex might be problematic; see the registry's advisory for this release for more details.

Files changed (109)
  1. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__init__.py +2 -2
  2. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/__main__.py +16 -12
  3. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_config.py +2 -2
  4. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/_device_offload.py +90 -56
  5. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/_utils.py +95 -0
  6. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/__init__.py +3 -3
  7. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/basic_statistics/basic_statistics.py +2 -2
  8. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/__init__.py +4 -4
  9. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +187 -0
  10. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/k_means.py +2 -2
  11. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_dbscan.py +12 -6
  12. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/cluster/tests/test_kmeans.py +5 -5
  13. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/__init__.py +3 -3
  14. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/pca.py +2 -2
  15. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/decomposition/tests/test_pca.py +5 -4
  16. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/dispatcher.py +102 -72
  17. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/ensemble/__init__.py +12 -4
  18. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/_forest.py +1947 -0
  19. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +118 -0
  20. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/__main__.py +31 -16
  21. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/glob/dispatcher.py +21 -14
  22. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/__init__.py +10 -10
  23. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/coordinate_descent.py +2 -2
  24. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex}/linear_model/linear.py +173 -83
  25. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/logistic_path.py +3 -3
  26. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/ridge.py +2 -2
  27. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_linear.py +23 -7
  28. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/linear_model/tests/test_logreg.py +4 -3
  29. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/__init__.py +3 -3
  30. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/t_sne.py +2 -2
  31. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/manifold/tests/test_tsne.py +4 -3
  32. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/__init__.py +5 -5
  33. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/pairwise.py +2 -2
  34. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/ranking.py +2 -2
  35. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/metrics/tests/test_metrics.py +8 -6
  36. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/__init__.py +3 -3
  37. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/split.py +2 -2
  38. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/model_selection/tests/test_model_selection.py +6 -3
  39. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/__init__.py +9 -5
  40. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/common.py +100 -77
  41. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +331 -0
  42. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +307 -0
  43. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/knn_unsupervised.py +116 -58
  44. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/neighbors/lof.py +118 -56
  45. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +85 -0
  46. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview}/__init__.py +18 -20
  47. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/__init__.py +3 -3
  48. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/_common.py +7 -7
  49. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/cluster/k_means.py +104 -73
  50. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/linear_model/linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/__init__.py +4 -1
  51. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/preview/decomposition/pca.py +128 -100
  52. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/tests/test_preview_linear.py → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +18 -16
  53. {scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model → scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd}/__init__.py +24 -22
  54. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/__init__.py +3 -3
  55. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/basic_statistics/basic_statistics.py +2 -2
  56. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/__init__.py +11 -5
  57. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/spmd/cluster/dbscan.py +50 -0
  58. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/cluster/kmeans.py +2 -2
  59. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/__init__.py +3 -3
  60. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/decomposition/pca.py +2 -2
  61. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/__init__.py +3 -3
  62. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/ensemble/forest.py +16 -14
  63. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/__init__.py +3 -3
  64. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/linear_model/linear_model.py +2 -2
  65. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/__init__.py +3 -3
  66. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/spmd/neighbors/neighbors.py +3 -3
  67. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/__init__.py +11 -8
  68. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/_common.py +56 -56
  69. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvc.py +110 -55
  70. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/nusvr.py +65 -31
  71. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svc.py +136 -78
  72. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/svm/svr.py +65 -31
  73. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +102 -0
  74. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +170 -0
  75. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_config.py +9 -8
  76. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_memory_usage.py +63 -69
  77. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_monkeypatch.py +55 -53
  78. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_parallel.py +50 -0
  79. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/test_patching.py +8 -7
  80. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +428 -0
  81. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/tests/utils/_launch_algorithms.py +39 -39
  82. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/__init__.py +3 -3
  83. scikit_learn_intelex-2024.0.1.data/data/Lib/site-packages/sklearnex/utils/parallel.py +59 -0
  84. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/utils/validation.py +2 -2
  85. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/METADATA +34 -35
  86. scikit_learn_intelex-2024.0.1.dist-info/RECORD +90 -0
  87. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/_utils.py +0 -82
  88. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/cluster/dbscan.py +0 -18
  89. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/__init__.py +0 -20
  90. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/forest.py +0 -18
  91. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/ensemble/tests/test_forest.py +0 -46
  92. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_classification.py +0 -228
  93. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/knn_regression.py +0 -213
  94. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/neighbors/tests/test_neighbors.py +0 -57
  95. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/__init__.py +0 -18
  96. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/decomposition/tests/test_preview_pca.py +0 -28
  97. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/extra_trees.py +0 -1261
  98. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/forest.py +0 -1155
  99. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/ensemble/tests/test_preview_ensemble.py +0 -67
  100. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/preview/linear_model/_common.py +0 -66
  101. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/spmd/__init__.py +0 -23
  102. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/svm/tests/test_svm.py +0 -63
  103. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/_models_info.py +0 -159
  104. scikit_learn_intelex-2023.2.1.data/data/Lib/site-packages/sklearnex/tests/test_run_to_run_stability_tests.py +0 -383
  105. scikit_learn_intelex-2023.2.1.dist-info/RECORD +0 -95
  106. {scikit_learn_intelex-2023.2.1.data → scikit_learn_intelex-2024.0.1.data}/data/Lib/site-packages/sklearnex/doc/third-party-programs.txt +0 -0
  107. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/LICENSE.txt +0 -0
  108. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/WHEEL +0 -0
  109. {scikit_learn_intelex-2023.2.1.dist-info → scikit_learn_intelex-2024.0.1.dist-info}/top_level.txt +0 -0
@@ -1,1261 +0,0 @@
1
- #!/usr/bin/env python
2
- #===============================================================================
3
- # Copyright 2021 Intel Corporation
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- #===============================================================================
17
-
18
- from daal4py.sklearn._utils import (
19
- daal_check_version, sklearn_check_version,
20
- make2d, PatchingConditionsChain, check_tree_nodes
21
- )
22
-
23
- import numpy as np
24
-
25
- import numbers
26
-
27
- import warnings
28
-
29
- from abc import ABC
30
-
31
- from sklearn.exceptions import DataConversionWarning
32
-
33
- from ..._config import get_config
34
- from ..._device_offload import dispatch, wrap_output_data
35
-
36
- from sklearn.ensemble import ExtraTreesClassifier as sklearn_ExtraTreesClassifier
37
- from sklearn.ensemble import ExtraTreesRegressor as sklearn_ExtraTreesRegressor
38
-
39
- from sklearn.utils.validation import (
40
- check_is_fitted,
41
- check_consistent_length,
42
- check_array,
43
- check_X_y)
44
-
45
- from onedal.datatypes import _num_features, _num_samples
46
-
47
- from sklearn.utils import check_random_state, deprecated
48
-
49
- from sklearn.base import clone
50
-
51
- from sklearn.tree import ExtraTreeClassifier, ExtraTreeRegressor
52
- from sklearn.tree._tree import Tree
53
-
54
- from onedal.ensemble import ExtraTreesClassifier as onedal_ExtraTreesClassifier
55
- from onedal.ensemble import ExtraTreesRegressor as onedal_ExtraTreesRegressor
56
- from onedal.primitives import get_tree_state_cls, get_tree_state_reg
57
-
58
- from scipy import sparse as sp
59
-
60
- if sklearn_check_version('1.2'):
61
- from sklearn.utils._param_validation import Interval
62
-
63
-
64
class BaseTree(ABC):
    """Helpers shared by the oneDAL-backed ExtraTrees estimators.

    This mixin is combined with a scikit-learn forest estimator. Its
    methods read hyper-parameters such as ``oob_score``,
    ``min_samples_leaf``, ``max_bins`` and ``min_bin_size`` from ``self``.
    ``_save_attributes`` additionally reads ``self._onedal_estimator``.
    """

    def _fit_proba(self, X, y, sample_weight=None, queue=None):
        """Prepare a probability-enabled fit (appears unfinished).

        ``X``, ``y`` and ``sample_weight`` are accepted but not used in
        this body. Returns ``None``.

        NOTE(review): this body has no visible lasting effect. The
        estimator built from ``self.get_params()`` is discarded. The
        ``target_offload`` assignment only mutates the dict returned by
        ``get_config()``; whether that reaches the active configuration
        depends on ``get_config`` returning a live mapping rather than a
        copy -- confirm. The comment below suggests ``config_context``
        was the intended mechanism.
        """
        params = self.get_params()
        self.__class__(**params)

        # We use stock metaestimators below, so the only way
        # to pass a queue is using config_context.
        cfg = get_config()
        cfg['target_offload'] = queue

    def _save_attributes(self):
        """Copy fitted results from ``self._onedal_estimator`` onto ``self``.

        The oneDAL model is always copied. When ``self.oob_score`` is
        truthy, ``oob_score_`` is copied as well. ``oob_prediction_`` and
        ``oob_decision_function_`` are copied only if the oneDAL estimator
        exposes them.

        Returns
        -------
        self : object
        """
        self._onedal_model = self._onedal_estimator._onedal_model

        if self.oob_score:
            self.oob_score_ = self._onedal_estimator.oob_score_
            if hasattr(self._onedal_estimator, "oob_prediction_"):
                self.oob_prediction_ = self._onedal_estimator.oob_prediction_
            if hasattr(self._onedal_estimator, "oob_decision_function_"):
                self.oob_decision_function_ = \
                    self._onedal_estimator.oob_decision_function_
        return self

    def _onedal_classifier(self, **onedal_params):
        """Build a oneDAL ExtraTrees classifier from ``onedal_params``."""
        return onedal_ExtraTreesClassifier(**onedal_params)

    def _onedal_regressor(self, **onedal_params):
        """Build a oneDAL ExtraTrees regressor from ``onedal_params``."""
        return onedal_ExtraTreesRegressor(**onedal_params)

    # TODO: move to the onedal module.
    def _check_parameters(self):
        """Validate tree hyper-parameters by hand.

        This is a manual validator for scikit-learn versions where
        ``_validate_params`` is not used by the caller.

        Raises
        ------
        ValueError
            If any of the checked hyper-parameters is out of range or of
            the wrong kind.

        Warns
        -----
        FutureWarning
            If ``min_impurity_split`` is set (not ``None``).
        """
        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError("min_samples_leaf must be at least 1 "
                                 "or in (0, 0.5], got %s"
                                 % self.min_samples_leaf)
        else:  # float
            if not 0. < self.min_samples_leaf <= 0.5:
                raise ValueError("min_samples_leaf must be at least 1 "
                                 "or in (0, 0.5], got %s"
                                 % self.min_samples_leaf)
        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError("min_samples_split must be an integer "
                                 "greater than 1 or a float in (0.0, 1.0]; "
                                 "got the integer %s"
                                 % self.min_samples_split)
        else:  # float
            if not 0. < self.min_samples_split <= 1.:
                raise ValueError("min_samples_split must be an integer "
                                 "greater than 1 or a float in (0.0, 1.0]; "
                                 "got the float %s"
                                 % self.min_samples_split)
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            # Fixed grammar: the message previously read "must in".
            raise ValueError("min_weight_fraction_leaf must be in [0, 0.5]")
        # The attribute only exists on estimators built for old
        # scikit-learn versions, hence the getattr with a default.
        if getattr(self, "min_impurity_split", None) is not None:
            warnings.warn("The min_impurity_split parameter is deprecated. "
                          "Its default value has changed from 1e-7 to 0 in "
                          "version 0.23, and it will be removed in 0.25. "
                          "Use the min_impurity_decrease parameter instead.",
                          FutureWarning)

            # Safe to access directly: the guard above proved it exists.
            if self.min_impurity_split < 0.:
                raise ValueError("min_impurity_split must be greater than "
                                 "or equal to 0")
        if self.min_impurity_decrease < 0.:
            raise ValueError("min_impurity_decrease must be greater than "
                             "or equal to 0")
        if self.max_leaf_nodes is not None:
            if not isinstance(self.max_leaf_nodes, numbers.Integral):
                raise ValueError(
                    "max_leaf_nodes must be integral number but was "
                    "%r" %
                    self.max_leaf_nodes)
            if self.max_leaf_nodes < 2:
                raise ValueError(
                    ("max_leaf_nodes {0} must be either None "
                     "or larger than 1").format(
                        self.max_leaf_nodes))
        # oneDAL-specific binning parameters: both must be integers.
        if isinstance(self.max_bins, numbers.Integral):
            if not 2 <= self.max_bins:
                raise ValueError("max_bins must be at least 2, got %s"
                                 % self.max_bins)
        else:
            raise ValueError("max_bins must be integral number but was "
                             "%r" % self.max_bins)
        if isinstance(self.min_bin_size, numbers.Integral):
            if not 1 <= self.min_bin_size:
                raise ValueError("min_bin_size must be at least 1, got %s"
                                 % self.min_bin_size)
        else:
            raise ValueError("min_bin_size must be integral number but was "
                             "%r" % self.min_bin_size)

    def check_sample_weight(self, sample_weight, X, dtype=None):
        """Return ``sample_weight`` as a validated 1-D array.

        Parameters
        ----------
        sample_weight : None, scalar or array-like
            ``None`` becomes an array of ones. A scalar is broadcast to
            every sample. Anything else goes through ``check_array`` with
            a float dtype.
        X : array-like
            Only its sample count (``_num_samples(X)``) is used.
        dtype : dtype or None, default=None
            Requested dtype. Anything other than ``np.float32`` or
            ``np.float64`` is replaced by ``np.float64``.

        Returns
        -------
        sample_weight : ndarray of shape (n_samples,)

        Raises
        ------
        ValueError
            If an array-like ``sample_weight`` is not 1-D or its length
            differs from the number of samples in ``X``.
        """
        n_samples = _num_samples(X)

        if dtype is not None and dtype not in [np.float32, np.float64]:
            dtype = np.float64

        if sample_weight is None:
            sample_weight = np.ones(n_samples, dtype=dtype)
        elif isinstance(sample_weight, numbers.Number):
            sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
        else:
            if dtype is None:
                # Let check_array keep either float precision.
                dtype = [np.float64, np.float32]
            sample_weight = check_array(
                sample_weight,
                accept_sparse=False,
                ensure_2d=False,
                dtype=dtype,
                order="C")
            if sample_weight.ndim != 1:
                raise ValueError("Sample weights must be 1D array or scalar")

            if sample_weight.shape != (n_samples,):
                raise ValueError("sample_weight.shape == {}, expected {}!"
                                 .format(sample_weight.shape, (n_samples,)))
        return sample_weight
-
186
-
187
- class ExtraTreesClassifier(sklearn_ExtraTreesClassifier, BaseTree):
188
- __doc__ = sklearn_ExtraTreesClassifier.__doc__
189
-
190
- if sklearn_check_version('1.2'):
191
- _parameter_constraints: dict = {
192
- **sklearn_ExtraTreesClassifier._parameter_constraints,
193
- "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
194
- "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")]
195
- }
196
-
197
- if sklearn_check_version('1.0'):
198
- def __init__(
199
- self,
200
- n_estimators=100,
201
- criterion="gini",
202
- max_depth=None,
203
- min_samples_split=2,
204
- min_samples_leaf=1,
205
- min_weight_fraction_leaf=0.,
206
- max_features='sqrt' if sklearn_check_version('1.1') else 'auto',
207
- max_leaf_nodes=None,
208
- min_impurity_decrease=0.,
209
- bootstrap=False,
210
- oob_score=False,
211
- n_jobs=None,
212
- random_state=None,
213
- verbose=0,
214
- warm_start=False,
215
- class_weight=None,
216
- ccp_alpha=0.0,
217
- max_samples=None,
218
- max_bins=256,
219
- min_bin_size=1):
220
- super(ExtraTreesClassifier, self).__init__(
221
- n_estimators=n_estimators,
222
- criterion=criterion,
223
- max_depth=max_depth,
224
- min_samples_split=min_samples_split,
225
- min_samples_leaf=min_samples_leaf,
226
- min_weight_fraction_leaf=min_weight_fraction_leaf,
227
- max_features=max_features,
228
- max_leaf_nodes=max_leaf_nodes,
229
- min_impurity_decrease=min_impurity_decrease,
230
- bootstrap=bootstrap,
231
- oob_score=oob_score,
232
- n_jobs=n_jobs,
233
- random_state=random_state,
234
- verbose=verbose,
235
- warm_start=warm_start,
236
- class_weight=class_weight
237
- )
238
- self.warm_start = warm_start
239
- self.ccp_alpha = ccp_alpha
240
- self.max_samples = max_samples
241
- self.max_bins = max_bins
242
- self.min_bin_size = min_bin_size
243
-
244
- else:
245
- def __init__(self,
246
- n_estimators=100,
247
- criterion="gini",
248
- max_depth=None,
249
- min_samples_split=2,
250
- min_samples_leaf=1,
251
- min_weight_fraction_leaf=0.,
252
- max_features="auto",
253
- max_leaf_nodes=None,
254
- min_impurity_decrease=0.,
255
- min_impurity_split=None,
256
- bootstrap=False,
257
- oob_score=False,
258
- n_jobs=None,
259
- random_state=None,
260
- verbose=0,
261
- warm_start=False,
262
- class_weight=None,
263
- ccp_alpha=0.0,
264
- max_samples=None,
265
- max_bins=256,
266
- min_bin_size=1):
267
- super(ExtraTreesClassifier, self).__init__(
268
- n_estimators=n_estimators,
269
- criterion=criterion,
270
- max_depth=max_depth,
271
- min_samples_split=min_samples_split,
272
- min_samples_leaf=min_samples_leaf,
273
- min_weight_fraction_leaf=min_weight_fraction_leaf,
274
- max_features=max_features,
275
- max_leaf_nodes=max_leaf_nodes,
276
- min_impurity_decrease=min_impurity_decrease,
277
- min_impurity_split=min_impurity_split,
278
- bootstrap=bootstrap,
279
- oob_score=oob_score,
280
- n_jobs=n_jobs,
281
- random_state=random_state,
282
- verbose=verbose,
283
- warm_start=warm_start,
284
- class_weight=class_weight,
285
- ccp_alpha=ccp_alpha,
286
- max_samples=max_samples
287
- )
288
- self.warm_start = warm_start
289
- self.ccp_alpha = ccp_alpha
290
- self.max_samples = max_samples
291
- self.max_bins = max_bins
292
- self.min_bin_size = min_bin_size
293
-
294
- def fit(self, X, y, sample_weight=None):
295
- """
296
- Build a forest of trees from the training set (X, y).
297
-
298
- Parameters
299
- ----------
300
- X : {array-like, sparse matrix} of shape (n_samples, n_features)
301
- The training input samples. Internally, its dtype will be converted
302
- to ``dtype=np.float32``. If a sparse matrix is provided, it will be
303
- converted into a sparse ``csc_matrix``.
304
-
305
- y : array-like of shape (n_samples,) or (n_samples, n_outputs)
306
- The target values (class labels in classification, real numbers in
307
- regression).
308
-
309
- sample_weight : array-like of shape (n_samples,), default=None
310
- Sample weights. If None, then samples are equally weighted. Splits
311
- that would create child nodes with net zero or negative weight are
312
- ignored while searching for a split in each node. In the case of
313
- classification, splits are also ignored if they would result in any
314
- single class carrying a negative weight in either child node.
315
-
316
- Returns
317
- -------
318
- self : object
319
- """
320
- dispatch(self, 'fit', {
321
- 'onedal': self.__class__._onedal_fit,
322
- 'sklearn': sklearn_ExtraTreesClassifier.fit,
323
- }, X, y, sample_weight)
324
- return self
325
-
326
- def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
327
- if sp.issparse(y):
328
- raise ValueError(
329
- "sparse multilabel-indicator for y is not supported."
330
- )
331
-
332
- if sklearn_check_version("1.2"):
333
- self._validate_params()
334
- else:
335
- self._check_parameters()
336
-
337
- if not self.bootstrap and self.oob_score:
338
- raise ValueError("Out of bag estimation only available"
339
- " if bootstrap=True")
340
-
341
- ready = patching_status.and_conditions([
342
- (self.oob_score and daal_check_version((2021, 'P', 500)) or not
343
- self.oob_score,
344
- "OOB score is only supported starting from 2021.5 version of oneDAL."),
345
- (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
346
- (self.ccp_alpha == 0.0,
347
- f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported."),
348
- (self.criterion == "gini",
349
- f"'{self.criterion}' criterion is not supported. "
350
- "Only 'gini' criterion is supported."),
351
- (self.warm_start is False, "Warm start is not supported."),
352
- (self.n_estimators <= 6024, "More than 6024 estimators is not supported.")
353
- ])
354
-
355
- if ready:
356
- if sklearn_check_version("1.0"):
357
- self._check_feature_names(X, reset=True)
358
- X = check_array(X, dtype=[np.float32, np.float64])
359
- y = np.asarray(y)
360
- y = np.atleast_1d(y)
361
- if y.ndim == 2 and y.shape[1] == 1:
362
- warnings.warn(
363
- "A column-vector y was passed when a 1d array was"
364
- " expected. Please change the shape of y to "
365
- "(n_samples,), for example using ravel().",
366
- DataConversionWarning,
367
- stacklevel=2)
368
- check_consistent_length(X, y)
369
-
370
- y = make2d(y)
371
- self.n_outputs_ = y.shape[1]
372
- ready = patching_status.and_conditions([
373
- (self.n_outputs_ == 1,
374
- f"Number of outputs ({self.n_outputs_}) is not 1."),
375
- (y.dtype in [np.float32, np.float64, np.int32, np.int64],
376
- f"Datatype ({y.dtype}) for y is not supported.")
377
- ])
378
- # TODO: Fix to support integers as input
379
-
380
- n_samples = X.shape[0]
381
- if isinstance(self.max_samples, numbers.Integral):
382
- if not sklearn_check_version('1.2'):
383
- if not (1 <= self.max_samples <= n_samples):
384
- msg = "`max_samples` must be in range 1 to {} but got value {}"
385
- raise ValueError(msg.format(n_samples, self.max_samples))
386
- else:
387
- if self.max_samples > n_samples:
388
- msg = "`max_samples` must be <= n_samples={} but got value {}"
389
- raise ValueError(msg.format(n_samples, self.max_samples))
390
- elif isinstance(self.max_samples, numbers.Real):
391
- if sklearn_check_version('1.2'):
392
- pass
393
- elif sklearn_check_version('1.0'):
394
- if not (0 < float(self.max_samples) <= 1):
395
- msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
396
- raise ValueError(msg.format(self.max_samples))
397
- else:
398
- if not (0 < float(self.max_samples) < 1):
399
- msg = "`max_samples` must be in range (0, 1) but got value {}"
400
- raise ValueError(msg.format(self.max_samples))
401
- elif self.max_samples is not None:
402
- msg = "`max_samples` should be int or float, but got type '{}'"
403
- raise TypeError(msg.format(type(self.max_samples)))
404
-
405
- if not self.bootstrap and self.max_samples is not None:
406
- raise ValueError(
407
- "`max_sample` cannot be set if `bootstrap=False`. "
408
- "Either switch to `bootstrap=True` or set "
409
- "`max_sample=None`."
410
- )
411
-
412
- return ready, X, y, sample_weight
413
-
414
- @wrap_output_data
415
- def predict(self, X):
416
- """
417
- Predict class for X.
418
-
419
- The predicted class of an input sample is a vote by the trees in
420
- the forest, weighted by their probability estimates. That is,
421
- the predicted class is the one with highest mean probability
422
- estimate across the trees.
423
-
424
- Parameters
425
- ----------
426
- X : {array-like, sparse matrix} of shape (n_samples, n_features)
427
- The input samples. Internally, its dtype will be converted to
428
- ``dtype=np.float32``. If a sparse matrix is provided, it will be
429
- converted into a sparse ``csr_matrix``.
430
-
431
- Returns
432
- -------
433
- y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
434
- The predicted classes.
435
- """
436
- return dispatch(self, 'predict', {
437
- 'onedal': self.__class__._onedal_predict,
438
- 'sklearn': sklearn_ExtraTreesClassifier.predict,
439
- }, X)
440
-
441
- @wrap_output_data
442
- def predict_proba(self, X):
443
- """
444
- Predict class probabilities for X.
445
-
446
- The predicted class probabilities of an input sample are computed as
447
- the mean predicted class probabilities of the trees in the forest.
448
- The class probability of a single tree is the fraction of samples of
449
- the same class in a leaf.
450
-
451
- Parameters
452
- ----------
453
- X : {array-like, sparse matrix} of shape (n_samples, n_features)
454
- The input samples. Internally, its dtype will be converted to
455
- ``dtype=np.float32``. If a sparse matrix is provided, it will be
456
- converted into a sparse ``csr_matrix``.
457
-
458
- Returns
459
- -------
460
- p : ndarray of shape (n_samples, n_classes), or a list of n_outputs
461
- such arrays if n_outputs > 1.
462
- The class probabilities of the input samples. The order of the
463
- classes corresponds to that in the attribute :term:`classes_`.
464
- """
465
- # TODO:
466
- # _check_proba()
467
- # self._check_proba()
468
- if sklearn_check_version("1.0"):
469
- self._check_feature_names(X, reset=False)
470
- if hasattr(self, 'n_features_in_'):
471
- try:
472
- num_features = _num_features(X)
473
- except TypeError:
474
- num_features = _num_samples(X)
475
- if num_features != self.n_features_in_:
476
- raise ValueError(
477
- (f'X has {num_features} features, '
478
- f'but ExtraTreesClassifier is expecting '
479
- f'{self.n_features_in_} features as input'))
480
- return dispatch(self, 'predict_proba', {
481
- 'onedal': self.__class__._onedal_predict_proba,
482
- 'sklearn': sklearn_ExtraTreesClassifier.predict_proba,
483
- }, X)
484
-
485
- if sklearn_check_version('1.0'):
486
- @deprecated(
487
- "Attribute `n_features_` was deprecated in version 1.0 and will be "
488
- "removed in 1.2. Use `n_features_in_` instead.")
489
- @property
490
- def n_features_(self):
491
- return self.n_features_in_
492
-
493
- @property
494
- def _estimators_(self):
495
- if hasattr(self, '_cached_estimators_'):
496
- if self._cached_estimators_:
497
- return self._cached_estimators_
498
- if sklearn_check_version('0.22'):
499
- check_is_fitted(self)
500
- else:
501
- check_is_fitted(self, '_onedal_model')
502
- classes_ = self.classes_[0]
503
- n_classes_ = self.n_classes_[0]
504
- # convert model to estimators
505
- params = {
506
- 'criterion': self.criterion,
507
- 'max_depth': self.max_depth,
508
- 'min_samples_split': self.min_samples_split,
509
- 'min_samples_leaf': self.min_samples_leaf,
510
- 'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
511
- 'max_features': self.max_features,
512
- 'max_leaf_nodes': self.max_leaf_nodes,
513
- 'min_impurity_decrease': self.min_impurity_decrease,
514
- 'random_state': None,
515
- }
516
- if not sklearn_check_version('1.0'):
517
- params['min_impurity_split'] = self.min_impurity_split
518
- est = ExtraTreeClassifier(**params)
519
- # we need to set est.tree_ field with Trees constructed from Intel(R)
520
- # oneAPI Data Analytics Library solution
521
- estimators_ = []
522
- random_state_checked = check_random_state(self.random_state)
523
- for i in range(self.n_estimators):
524
- est_i = clone(est)
525
- est_i.set_params(
526
- random_state=random_state_checked.randint(
527
- np.iinfo(
528
- np.int32).max))
529
- if sklearn_check_version('1.0'):
530
- est_i.n_features_in_ = self.n_features_in_
531
- else:
532
- est_i.n_features_ = self.n_features_in_
533
- est_i.n_outputs_ = self.n_outputs_
534
- est_i.classes_ = classes_
535
- est_i.n_classes_ = n_classes_
536
- tree_i_state_class = get_tree_state_cls(
537
- self._onedal_model, i, n_classes_)
538
- tree_i_state_dict = {
539
- 'max_depth': tree_i_state_class.max_depth,
540
- 'node_count': tree_i_state_class.node_count,
541
- 'nodes': check_tree_nodes(tree_i_state_class.node_ar),
542
- 'values': tree_i_state_class.value_ar}
543
- est_i.tree_ = Tree(
544
- self.n_features_in_,
545
- np.array(
546
- [n_classes_],
547
- dtype=np.intp),
548
- self.n_outputs_)
549
- est_i.tree_.__setstate__(tree_i_state_dict)
550
- estimators_.append(est_i)
551
-
552
- self._cached_estimators_ = estimators_
553
- return estimators_
554
-
555
- def _onedal_cpu_supported(self, method_name, *data):
556
- class_name = self.__class__.__name__
557
- _patching_status = PatchingConditionsChain(
558
- f'sklearn.ensemble.{class_name}.{method_name}')
559
-
560
- if method_name == 'fit':
561
- ready, X, y, sample_weight = self._onedal_fit_ready(_patching_status, *data)
562
-
563
- dal_ready = ready and _patching_status.and_conditions([
564
- (daal_check_version((2023, 'P', 200)),
565
- "ExtraTrees only supported starting from oneDAL version 2023.2"),
566
- (not sp.issparse(sample_weight), "sample_weight is sparse. "
567
- "Sparse input is not supported."),
568
- ])
569
-
570
- dal_ready = dal_ready and not hasattr(self, 'estimators_')
571
-
572
- if dal_ready and (self.random_state is not None):
573
- warnings.warn("Setting 'random_state' value is not supported. "
574
- "State set by oneDAL to default value (777).",
575
- RuntimeWarning)
576
-
577
- elif method_name in ['predict',
578
- 'predict_proba']:
579
-
580
- X = data[0]
581
-
582
- dal_ready = _patching_status.and_conditions([
583
- (hasattr(self, '_onedal_model'), "oneDAL model was not trained."),
584
- (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
585
- (self.warm_start is False, "Warm start is not supported."),
586
- (daal_check_version((2023, 'P', 100)),
587
- "ExtraTrees only supported starting from oneDAL version 2023.2")
588
- ])
589
- if hasattr(self, 'n_outputs_'):
590
- dal_ready = dal_ready and _patching_status.and_conditions([
591
- (self.n_outputs_ == 1,
592
- f"Number of outputs ({self.n_outputs_}) is not 1."),
593
- ])
594
- else:
595
- dal_ready = False
596
-
597
- else:
598
- raise RuntimeError(
599
- f'Unknown method {method_name} in {self.__class__.__name__}')
600
-
601
- _patching_status.write_log()
602
- return dal_ready
603
-
604
- def _onedal_gpu_supported(self, method_name, *data):
605
- class_name = self.__class__.__name__
606
- _patching_status = PatchingConditionsChain(
607
- f'sklearn.ensemble.{class_name}.{method_name}')
608
-
609
- if method_name == 'fit':
610
- ready, X, y, sample_weight = self._onedal_fit_ready(_patching_status, *data)
611
-
612
- dal_ready = ready and _patching_status.and_conditions([
613
- (daal_check_version((2023, 'P', 100)),
614
- "ExtraTrees only supported starting from oneDAL version 2023.1"),
615
- (sample_weight is not None, "sample_weight is not supported.")
616
- ])
617
-
618
- dal_ready &= not hasattr(self, 'estimators_')
619
-
620
- if dal_ready and (self.random_state is not None):
621
- warnings.warn("Setting 'random_state' value is not supported. "
622
- "State set by oneDAL to default value (777).",
623
- RuntimeWarning)
624
-
625
- elif method_name in ['predict',
626
- 'predict_proba']:
627
-
628
- X = data[0]
629
-
630
- dal_ready = hasattr(self, '_onedal_model') and hasattr(self, 'n_outputs_')
631
-
632
- if dal_ready:
633
- dal_ready = _patching_status.and_conditions([
634
- (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
635
- (self.warm_start is False, "Warm start is not supported."),
636
- (daal_check_version((2023, 'P', 100)),
637
- "ExtraTrees only supported starting from oneDAL version 2023.1")
638
- ])
639
-
640
- if hasattr(self, 'n_outputs_'):
641
- dal_ready &= _patching_status.and_conditions([
642
- (self.n_outputs_ == 1,
643
- f"Number of outputs ({self.n_outputs_}) is not 1."),
644
- ])
645
-
646
- else:
647
- raise RuntimeError(
648
- f'Unknown method {method_name} in {self.__class__.__name__}')
649
-
650
- _patching_status.write_log()
651
- return dal_ready
652
-
653
- def _onedal_fit(self, X, y, sample_weight=None, queue=None):
654
- if sklearn_check_version('1.2'):
655
- X, y = self._validate_data(
656
- X, y, multi_output=False, accept_sparse=False,
657
- dtype=[np.float64, np.float32]
658
- )
659
- else:
660
- X, y = check_X_y(
661
- X, y, accept_sparse=False, dtype=[np.float64, np.float32],
662
- multi_output=False
663
- )
664
-
665
- if sample_weight is not None:
666
- sample_weight = self.check_sample_weight(sample_weight, X)
667
-
668
- y = np.atleast_1d(y)
669
- if y.ndim == 2 and y.shape[1] == 1:
670
- warnings.warn(
671
- "A column-vector y was passed when a 1d array was"
672
- " expected. Please change the shape of y to "
673
- "(n_samples,), for example using ravel().",
674
- DataConversionWarning,
675
- stacklevel=2,
676
- )
677
- if y.ndim == 1:
678
- # reshape is necessary to preserve the data contiguity against vs
679
- # [:, np.newaxis] that does not.
680
- y = np.reshape(y, (-1, 1))
681
-
682
- y, expanded_class_weight = self._validate_y_class_weight(y)
683
-
684
- n_classes_ = self.n_classes_[0]
685
- self.n_features_in_ = X.shape[1]
686
- if not sklearn_check_version('1.0'):
687
- self.n_features_ = self.n_features_in_
688
-
689
- if expanded_class_weight is not None:
690
- if sample_weight is not None:
691
- sample_weight = sample_weight * expanded_class_weight
692
- else:
693
- sample_weight = expanded_class_weight
694
- if sample_weight is not None:
695
- sample_weight = [sample_weight]
696
-
697
- if n_classes_ < 2:
698
- raise ValueError(
699
- "Training data only contain information about one class.")
700
-
701
- if self.oob_score:
702
- err = 'out_of_bag_error_accuracy|out_of_bag_error_decision_function'
703
- else:
704
- err = 'none'
705
-
706
- onedal_params = {
707
- 'n_estimators': self.n_estimators,
708
- 'criterion': self.criterion,
709
- 'max_depth': self.max_depth,
710
- 'min_samples_split': self.min_samples_split,
711
- 'min_samples_leaf': self.min_samples_leaf,
712
- 'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
713
- 'max_features': self.max_features,
714
- 'max_leaf_nodes': self.max_leaf_nodes,
715
- 'min_impurity_decrease': self.min_impurity_decrease,
716
- 'bootstrap': self.bootstrap,
717
- 'oob_score': self.oob_score,
718
- 'n_jobs': self.n_jobs,
719
- 'random_state': self.random_state,
720
- 'verbose': self.verbose,
721
- 'warm_start': self.warm_start,
722
- 'error_metric_mode': err,
723
- 'variable_importance_mode': 'mdi',
724
- 'class_weight': self.class_weight,
725
- 'max_bins': self.max_bins,
726
- 'min_bin_size': self.min_bin_size,
727
- 'max_samples': self.max_samples
728
- }
729
- if daal_check_version((2023, 'P', 101)):
730
- onedal_params['splitter_mode'] = "random"
731
- if not sklearn_check_version('1.0'):
732
- onedal_params['min_impurity_split'] = self.min_impurity_split
733
- else:
734
- onedal_params['min_impurity_split'] = None
735
- self._cached_estimators_ = None
736
-
737
- # Compute
738
- self._onedal_estimator = self._onedal_classifier(**onedal_params)
739
- self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
740
-
741
- self._save_attributes()
742
- if sklearn_check_version("1.2"):
743
- self._estimator = ExtraTreeClassifier()
744
- self.estimators_ = self._estimators_
745
- # Decapsulate classes_ attributes
746
- self.n_classes_ = self.n_classes_[0]
747
- self.classes_ = self.classes_[0]
748
- return self
749
-
750
- def _onedal_predict(self, X, queue=None):
751
- X = check_array(X, dtype=[np.float32, np.float64])
752
- check_is_fitted(self)
753
- if sklearn_check_version("1.0"):
754
- self._check_feature_names(X, reset=False)
755
-
756
- res = self._onedal_estimator.predict(X, queue=queue)
757
- return np.take(self.classes_,
758
- res.ravel().astype(np.int64, casting='unsafe'))
759
-
760
- def _onedal_predict_proba(self, X, queue=None):
761
- X = check_array(X, dtype=[np.float64, np.float32])
762
- check_is_fitted(self)
763
- if sklearn_check_version('0.23'):
764
- self._check_n_features(X, reset=False)
765
- if sklearn_check_version("1.0"):
766
- self._check_feature_names(X, reset=False)
767
- return self._onedal_estimator.predict_proba(X, queue=queue)
768
-
769
-
770
- class ExtraTreesRegressor(sklearn_ExtraTreesRegressor, BaseTree):
771
- __doc__ = sklearn_ExtraTreesRegressor.__doc__
772
-
773
- if sklearn_check_version('1.2'):
774
- _parameter_constraints: dict = {
775
- **sklearn_ExtraTreesRegressor._parameter_constraints,
776
- "max_bins": [Interval(numbers.Integral, 2, None, closed="left")],
777
- "min_bin_size": [Interval(numbers.Integral, 1, None, closed="left")]
778
- }
779
-
780
- if sklearn_check_version('1.0'):
781
- def __init__(
782
- self,
783
- n_estimators=100,
784
- *,
785
- criterion="squared_error",
786
- max_depth=None,
787
- min_samples_split=2,
788
- min_samples_leaf=1,
789
- min_weight_fraction_leaf=0.,
790
- max_features=1.0 if sklearn_check_version('1.1') else 'auto',
791
- max_leaf_nodes=None,
792
- min_impurity_decrease=0.,
793
- bootstrap=False,
794
- oob_score=False,
795
- n_jobs=None,
796
- random_state=None,
797
- verbose=0,
798
- warm_start=False,
799
- ccp_alpha=0.0,
800
- max_samples=None,
801
- max_bins=256,
802
- min_bin_size=1):
803
- super(ExtraTreesRegressor, self).__init__(
804
- n_estimators=n_estimators,
805
- criterion=criterion,
806
- max_depth=max_depth,
807
- min_samples_split=min_samples_split,
808
- min_samples_leaf=min_samples_leaf,
809
- min_weight_fraction_leaf=min_weight_fraction_leaf,
810
- max_features=max_features,
811
- max_leaf_nodes=max_leaf_nodes,
812
- min_impurity_decrease=min_impurity_decrease,
813
- bootstrap=bootstrap,
814
- oob_score=oob_score,
815
- n_jobs=n_jobs,
816
- random_state=random_state,
817
- verbose=verbose,
818
- warm_start=warm_start
819
- )
820
- self.warm_start = warm_start
821
- self.ccp_alpha = ccp_alpha
822
- self.max_samples = max_samples
823
- self.max_bins = max_bins
824
- self.min_bin_size = min_bin_size
825
- else:
826
- def __init__(self,
827
- n_estimators=100, *,
828
- criterion="mse",
829
- max_depth=None,
830
- min_samples_split=2,
831
- min_samples_leaf=1,
832
- min_weight_fraction_leaf=0.,
833
- max_features="auto",
834
- max_leaf_nodes=None,
835
- min_impurity_decrease=0.,
836
- min_impurity_split=None,
837
- bootstrap=False,
838
- oob_score=False,
839
- n_jobs=None,
840
- random_state=None,
841
- verbose=0,
842
- warm_start=False,
843
- ccp_alpha=0.0,
844
- max_samples=None,
845
- max_bins=256,
846
- min_bin_size=1
847
- ):
848
- super(ExtraTreesRegressor, self).__init__(
849
- n_estimators=n_estimators,
850
- criterion=criterion,
851
- max_depth=max_depth,
852
- min_samples_split=min_samples_split,
853
- min_samples_leaf=min_samples_leaf,
854
- min_weight_fraction_leaf=min_weight_fraction_leaf,
855
- max_features=max_features,
856
- max_leaf_nodes=max_leaf_nodes,
857
- min_impurity_decrease=min_impurity_decrease,
858
- min_impurity_split=min_impurity_split,
859
- bootstrap=bootstrap,
860
- oob_score=oob_score,
861
- n_jobs=n_jobs,
862
- random_state=random_state,
863
- verbose=verbose,
864
- warm_start=warm_start,
865
- ccp_alpha=ccp_alpha,
866
- max_samples=max_samples
867
- )
868
- self.warm_start = warm_start
869
- self.ccp_alpha = ccp_alpha
870
- self.max_samples = max_samples
871
- self.max_bins = max_bins
872
- self.min_bin_size = min_bin_size
873
-
874
- @property
875
- def _estimators_(self):
876
- if hasattr(self, '_cached_estimators_'):
877
- if self._cached_estimators_:
878
- return self._cached_estimators_
879
- if sklearn_check_version('0.22'):
880
- check_is_fitted(self)
881
- else:
882
- check_is_fitted(self, '_onedal_model')
883
- # convert model to estimators
884
- params = {
885
- 'criterion': self.criterion,
886
- 'max_depth': self.max_depth,
887
- 'min_samples_split': self.min_samples_split,
888
- 'min_samples_leaf': self.min_samples_leaf,
889
- 'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
890
- 'max_features': self.max_features,
891
- 'max_leaf_nodes': self.max_leaf_nodes,
892
- 'min_impurity_decrease': self.min_impurity_decrease,
893
- 'random_state': None,
894
- }
895
- if not sklearn_check_version('1.0'):
896
- params['min_impurity_split'] = self.min_impurity_split
897
- est = ExtraTreeRegressor(**params)
898
- # we need to set est.tree_ field with Trees constructed from Intel(R)
899
- # oneAPI Data Analytics Library solution
900
- estimators_ = []
901
- random_state_checked = check_random_state(self.random_state)
902
- for i in range(self.n_estimators):
903
- est_i = clone(est)
904
- est_i.set_params(
905
- random_state=random_state_checked.randint(
906
- np.iinfo(
907
- np.int32).max))
908
- if sklearn_check_version('1.0'):
909
- est_i.n_features_in_ = self.n_features_in_
910
- else:
911
- est_i.n_features_ = self.n_features_in_
912
- est_i.n_classes_ = 1
913
- est_i.n_outputs_ = self.n_outputs_
914
- tree_i_state_class = get_tree_state_reg(
915
- self._onedal_model, i)
916
- tree_i_state_dict = {
917
- 'max_depth': tree_i_state_class.max_depth,
918
- 'node_count': tree_i_state_class.node_count,
919
- 'nodes': check_tree_nodes(tree_i_state_class.node_ar),
920
- 'values': tree_i_state_class.value_ar}
921
-
922
- est_i.tree_ = Tree(
923
- self.n_features_in_, np.array(
924
- [1], dtype=np.intp), self.n_outputs_)
925
- est_i.tree_.__setstate__(tree_i_state_dict)
926
- estimators_.append(est_i)
927
-
928
- return estimators_
929
-
930
- def _onedal_fit_ready(self, patching_status, X, y, sample_weight):
931
- if sp.issparse(y):
932
- raise ValueError(
933
- "sparse multilabel-indicator for y is not supported."
934
- )
935
-
936
- if sklearn_check_version("1.2"):
937
- self._validate_params()
938
- else:
939
- self._check_parameters()
940
-
941
- if not self.bootstrap and self.oob_score:
942
- raise ValueError("Out of bag estimation only available"
943
- " if bootstrap=True")
944
-
945
- if sklearn_check_version('1.0') and self.criterion == "mse":
946
- warnings.warn(
947
- "Criterion 'mse' was deprecated in v1.0 and will be "
948
- "removed in version 1.2. Use `criterion='squared_error'` "
949
- "which is equivalent.",
950
- FutureWarning
951
- )
952
-
953
- ready = patching_status.and_conditions([
954
- (self.oob_score and daal_check_version((2021, 'P', 500)) or not
955
- self.oob_score,
956
- "OOB score is only supported starting from 2021.5 version of oneDAL."),
957
- (self.warm_start is False, "Warm start is not supported."),
958
- (self.criterion in ["mse", "squared_error"],
959
- f"'{self.criterion}' criterion is not supported. "
960
- "Only 'mse' and 'squared_error' criteria are supported."),
961
- (self.ccp_alpha == 0.0,
962
- f"Non-zero 'ccp_alpha' ({self.ccp_alpha}) is not supported."),
963
- (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
964
- (self.n_estimators <= 6024, "More than 6024 estimators is not supported.")
965
- ])
966
-
967
- if ready:
968
- if sklearn_check_version("1.0"):
969
- self._check_feature_names(X, reset=True)
970
- X = check_array(X, dtype=[np.float64, np.float32])
971
- y = np.asarray(y)
972
- y = np.atleast_1d(y)
973
-
974
- if y.ndim == 2 and y.shape[1] == 1:
975
- warnings.warn("A column-vector y was passed when a 1d array was"
976
- " expected. Please change the shape of y to "
977
- "(n_samples,), for example using ravel().",
978
- DataConversionWarning, stacklevel=2)
979
-
980
- y = check_array(y, ensure_2d=False, dtype=X.dtype)
981
- check_consistent_length(X, y)
982
-
983
- if y.ndim == 1:
984
- # reshape is necessary to preserve the data contiguity against vs
985
- # [:, np.newaxis] that does not.
986
- y = np.reshape(y, (-1, 1))
987
-
988
- self.n_outputs_ = y.shape[1]
989
- ready = patching_status.and_conditions([
990
- (self.n_outputs_ == 1,
991
- f"Number of outputs ({self.n_outputs_}) is not 1.")
992
- ])
993
-
994
- n_samples = X.shape[0]
995
- if isinstance(self.max_samples, numbers.Integral):
996
- if not sklearn_check_version('1.2'):
997
- if not (1 <= self.max_samples <= n_samples):
998
- msg = "`max_samples` must be in range 1 to {} but got value {}"
999
- raise ValueError(msg.format(n_samples, self.max_samples))
1000
- else:
1001
- if self.max_samples > n_samples:
1002
- msg = "`max_samples` must be <= n_samples={} but got value {}"
1003
- raise ValueError(msg.format(n_samples, self.max_samples))
1004
- elif isinstance(self.max_samples, numbers.Real):
1005
- if sklearn_check_version('1.2'):
1006
- pass
1007
- elif sklearn_check_version('1.0'):
1008
- if not (0 < float(self.max_samples) <= 1):
1009
- msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
1010
- raise ValueError(msg.format(self.max_samples))
1011
- else:
1012
- if not (0 < float(self.max_samples) < 1):
1013
- msg = "`max_samples` must be in range (0, 1) but got value {}"
1014
- raise ValueError(msg.format(self.max_samples))
1015
- elif self.max_samples is not None:
1016
- msg = "`max_samples` should be int or float, but got type '{}'"
1017
- raise TypeError(msg.format(type(self.max_samples)))
1018
-
1019
- if not self.bootstrap and self.max_samples is not None:
1020
- raise ValueError(
1021
- "`max_sample` cannot be set if `bootstrap=False`. "
1022
- "Either switch to `bootstrap=True` or set "
1023
- "`max_sample=None`."
1024
- )
1025
-
1026
- return ready, X, y, sample_weight
1027
-
1028
- def _onedal_cpu_supported(self, method_name, *data):
1029
- class_name = self.__class__.__name__
1030
- _patching_status = PatchingConditionsChain(
1031
- f'sklearn.ensemble.{class_name}.{method_name}')
1032
-
1033
- if method_name == 'fit':
1034
- ready, X, y, sample_weight = self._onedal_fit_ready(_patching_status, *data)
1035
-
1036
- dal_ready = ready and _patching_status.and_conditions([
1037
- (daal_check_version((2023, 'P', 200)),
1038
- "ExtraTrees only supported starting from oneDAL version 2023.2"),
1039
- (not sp.issparse(sample_weight), "sample_weight is sparse. "
1040
- "Sparse input is not supported."),
1041
- ])
1042
-
1043
- dal_ready &= not hasattr(self, 'estimators_')
1044
-
1045
- if dal_ready and (self.random_state is not None):
1046
- warnings.warn("Setting 'random_state' value is not supported. "
1047
- "State set by oneDAL to default value (777).",
1048
- RuntimeWarning)
1049
-
1050
- elif method_name in ['predict',
1051
- 'predict_proba']:
1052
-
1053
- X = data[0]
1054
-
1055
- dal_ready = _patching_status.and_conditions([
1056
- (hasattr(self, '_onedal_model'), "oneDAL model was not trained."),
1057
- (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
1058
- (self.warm_start is False, "Warm start is not supported."),
1059
- (daal_check_version((2023, 'P', 200)),
1060
- "ExtraTrees only supported starting from oneDAL version 2023.2")
1061
- ])
1062
- if hasattr(self, 'n_outputs_'):
1063
- dal_ready &= _patching_status.and_conditions([
1064
- (self.n_outputs_ == 1,
1065
- f"Number of outputs ({self.n_outputs_}) is not 1."),
1066
- ])
1067
- else:
1068
- dal_ready = False
1069
-
1070
- else:
1071
- raise RuntimeError(
1072
- f'Unknown method {method_name} in {self.__class__.__name__}')
1073
-
1074
- _patching_status.write_log()
1075
- return dal_ready
1076
-
1077
- def _onedal_gpu_supported(self, method_name, *data):
1078
- class_name = self.__class__.__name__
1079
- _patching_status = PatchingConditionsChain(
1080
- f'sklearn.ensemble.{class_name}.{method_name}')
1081
-
1082
- if method_name == 'fit':
1083
- ready, X, y, sample_weight = self._onedal_fit_ready(_patching_status, *data)
1084
-
1085
- dal_ready = ready and _patching_status.and_conditions([
1086
- (daal_check_version((2023, 'P', 100)),
1087
- "ExtraTrees only supported starting from oneDAL version 2023.1"),
1088
- (sample_weight is not None, "sample_weight is not supported."),
1089
- ])
1090
-
1091
- dal_ready &= not hasattr(self, 'estimators_')
1092
-
1093
- if dal_ready and (self.random_state is not None):
1094
- warnings.warn("Setting 'random_state' value is not supported. "
1095
- "State set by oneDAL to default value (777).",
1096
- RuntimeWarning)
1097
-
1098
- elif method_name in ['predict',
1099
- 'predict_proba']:
1100
-
1101
- X = data[0]
1102
-
1103
- dal_ready = _patching_status.and_conditions([
1104
- (hasattr(self, '_onedal_model'), "oneDAL model was not trained."),
1105
- (not sp.issparse(X), "X is sparse. Sparse input is not supported."),
1106
- (self.warm_start is False, "Warm start is not supported."),
1107
-
1108
- (daal_check_version((2023, 'P', 100)),
1109
- "ExtraTrees only supported starting from oneDAL version 2023.1")
1110
- ])
1111
- if hasattr(self, 'n_outputs_'):
1112
- dal_ready &= _patching_status.and_conditions([
1113
- (self.n_outputs_ == 1,
1114
- f"Number of outputs ({self.n_outputs_}) is not 1."),
1115
- ])
1116
-
1117
- else:
1118
- raise RuntimeError(
1119
- f'Unknown method {method_name} in {self.__class__.__name__}')
1120
-
1121
- _patching_status.write_log()
1122
- return dal_ready
1123
-
1124
- def _onedal_fit(self, X, y, sample_weight=None, queue=None):
1125
- if sp.issparse(y):
1126
- raise ValueError(
1127
- "sparse multilabel-indicator for y is not supported."
1128
- )
1129
- if sklearn_check_version("1.2"):
1130
- self._validate_params()
1131
- else:
1132
- self._check_parameters()
1133
- if sample_weight is not None:
1134
- sample_weight = self.check_sample_weight(sample_weight, X)
1135
- if sklearn_check_version("1.0"):
1136
- self._check_feature_names(X, reset=True)
1137
- X = check_array(X, dtype=[np.float64, np.float32])
1138
- y = np.atleast_1d(np.asarray(y))
1139
- if y.ndim == 2 and y.shape[1] == 1:
1140
- warnings.warn(
1141
- "A column-vector y was passed when a 1d array was"
1142
- " expected. Please change the shape of y to "
1143
- "(n_samples,), for example using ravel().",
1144
- DataConversionWarning,
1145
- stacklevel=2)
1146
- y = check_array(y, ensure_2d=False, dtype=X.dtype)
1147
- check_consistent_length(X, y)
1148
- self.n_features_in_ = X.shape[1]
1149
- if not sklearn_check_version('1.0'):
1150
- self.n_features_ = self.n_features_in_
1151
- rs_ = check_random_state(self.random_state)
1152
-
1153
- if self.oob_score:
1154
- err = 'out_of_bag_error_r2|out_of_bag_error_prediction'
1155
- else:
1156
- err = 'none'
1157
-
1158
- onedal_params = {
1159
- 'n_estimators': self.n_estimators,
1160
- 'criterion': self.criterion,
1161
- 'max_depth': self.max_depth,
1162
- 'min_samples_split': self.min_samples_split,
1163
- 'min_samples_leaf': self.min_samples_leaf,
1164
- 'min_weight_fraction_leaf': self.min_weight_fraction_leaf,
1165
- 'max_features': self.max_features,
1166
- 'max_leaf_nodes': self.max_leaf_nodes,
1167
- 'min_impurity_decrease': self.min_impurity_decrease,
1168
- 'bootstrap': self.bootstrap,
1169
- 'oob_score': self.oob_score,
1170
- 'n_jobs': self.n_jobs,
1171
- 'random_state': rs_,
1172
- 'verbose': self.verbose,
1173
- 'warm_start': self.warm_start,
1174
- 'error_metric_mode': err,
1175
- 'variable_importance_mode': 'mdi',
1176
- 'max_samples': self.max_samples
1177
- }
1178
- if daal_check_version((2023, 'P', 101)):
1179
- onedal_params['splitter_mode'] = "random"
1180
- self._cached_estimators_ = None
1181
- self._onedal_estimator = self._onedal_regressor(**onedal_params)
1182
- self._onedal_estimator.fit(X, y, sample_weight, queue=queue)
1183
-
1184
- self._save_attributes()
1185
- if sklearn_check_version("1.2"):
1186
- self._estimator = ExtraTreeRegressor()
1187
- self.estimators_ = self._estimators_
1188
- return self
1189
-
1190
- def _onedal_predict(self, X, queue=None):
1191
- if sklearn_check_version("1.0"):
1192
- self._check_feature_names(X, reset=False)
1193
- X = self._validate_X_predict(X)
1194
- return self._onedal_estimator.predict(X, queue=queue)
1195
-
1196
- def fit(self, X, y, sample_weight=None):
1197
- """
1198
- Build a forest of trees from the training set (X, y).
1199
-
1200
- Parameters
1201
- ----------
1202
- X : {array-like, sparse matrix} of shape (n_samples, n_features)
1203
- The training input samples. Internally, its dtype will be converted
1204
- to ``dtype=np.float32``. If a sparse matrix is provided, it will be
1205
- converted into a sparse ``csc_matrix``.
1206
-
1207
- y : array-like of shape (n_samples,) or (n_samples, n_outputs)
1208
- The target values (class labels in classification, real numbers in
1209
- regression).
1210
-
1211
- sample_weight : array-like of shape (n_samples,), default=None
1212
- Sample weights. If None, then samples are equally weighted. Splits
1213
- that would create child nodes with net zero or negative weight are
1214
- ignored while searching for a split in each node. In the case of
1215
- classification, splits are also ignored if they would result in any
1216
- single class carrying a negative weight in either child node.
1217
-
1218
- Returns
1219
- -------
1220
- self : object
1221
- """
1222
- dispatch(self, 'fit', {
1223
- 'onedal': self.__class__._onedal_fit,
1224
- 'sklearn': sklearn_ExtraTreesRegressor.fit,
1225
- }, X, y, sample_weight)
1226
- return self
1227
-
1228
- @wrap_output_data
1229
- def predict(self, X):
1230
- """
1231
- Predict class for X.
1232
-
1233
- The predicted class of an input sample is a vote by the trees in
1234
- the forest, weighted by their probability estimates. That is,
1235
- the predicted class is the one with highest mean probability
1236
- estimate across the trees.
1237
-
1238
- Parameters
1239
- ----------
1240
- X : {array-like, sparse matrix} of shape (n_samples, n_features)
1241
- The input samples. Internally, its dtype will be converted to
1242
- ``dtype=np.float32``. If a sparse matrix is provided, it will be
1243
- converted into a sparse ``csr_matrix``.
1244
-
1245
- Returns
1246
- -------
1247
- y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
1248
- The predicted classes.
1249
- """
1250
- return dispatch(self, 'predict', {
1251
- 'onedal': self.__class__._onedal_predict,
1252
- 'sklearn': sklearn_ExtraTreesRegressor.predict,
1253
- }, X)
1254
-
1255
- if sklearn_check_version('1.0'):
1256
- @deprecated(
1257
- "Attribute `n_features_` was deprecated in version 1.0 and will be "
1258
- "removed in 1.2. Use `n_features_in_` instead.")
1259
- @property
1260
- def n_features_(self):
1261
- return self.n_features_in_