ruletree-0.0.2.post1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ruletree-0.0.2.post1/LICENSE +23 -0
  2. ruletree-0.0.2.post1/MANIFEST.in +3 -0
  3. ruletree-0.0.2.post1/PKG-INFO +40 -0
  4. ruletree-0.0.2.post1/README.md +2 -0
  5. ruletree-0.0.2.post1/RuleTree/__init__.py +9 -0
  6. ruletree-0.0.2.post1/RuleTree/base/RuleTreeBase.py +7 -0
  7. ruletree-0.0.2.post1/RuleTree/base/RuleTreeBaseSplit.py +7 -0
  8. ruletree-0.0.2.post1/RuleTree/base/RuleTreeBaseStump.py +24 -0
  9. ruletree-0.0.2.post1/RuleTree/base/__init__.py +0 -0
  10. ruletree-0.0.2.post1/RuleTree/ensemble/RuleForestClassifier.py +233 -0
  11. ruletree-0.0.2.post1/RuleTree/ensemble/RuleForestRegressor.py +103 -0
  12. ruletree-0.0.2.post1/RuleTree/ensemble/RuleTreeAdaBoostClassifier.py +61 -0
  13. ruletree-0.0.2.post1/RuleTree/ensemble/RuleTreeAdaBoostRegressor.py +54 -0
  14. ruletree-0.0.2.post1/RuleTree/ensemble/__init__.py +0 -0
  15. ruletree-0.0.2.post1/RuleTree/stumps/__init__.py +0 -0
  16. ruletree-0.0.2.post1/RuleTree/stumps/classification/DecisionTreeStumpClassifier.py +236 -0
  17. ruletree-0.0.2.post1/RuleTree/stumps/classification/MultiplePivotTreeStumpClassifier.py +116 -0
  18. ruletree-0.0.2.post1/RuleTree/stumps/classification/ObliqueDecisionTreeStumpClassifier.py +135 -0
  19. ruletree-0.0.2.post1/RuleTree/stumps/classification/ObliquePivotTreeStumpClassifier.py +84 -0
  20. ruletree-0.0.2.post1/RuleTree/stumps/classification/PivotTreeStumpClassifier.py +163 -0
  21. ruletree-0.0.2.post1/RuleTree/stumps/classification/ShapeletTreeStumpClassifier.py +243 -0
  22. ruletree-0.0.2.post1/RuleTree/stumps/classification/__init__.py +0 -0
  23. ruletree-0.0.2.post1/RuleTree/stumps/regression/DecisionTreeStumpRegressor.py +144 -0
  24. ruletree-0.0.2.post1/RuleTree/stumps/regression/ObliqueDecisionTreeStumpRegressor.py +63 -0
  25. ruletree-0.0.2.post1/RuleTree/stumps/regression/__init__.py +0 -0
  26. ruletree-0.0.2.post1/RuleTree/stumps/splitters/MultiplePivotSplit.py +64 -0
  27. ruletree-0.0.2.post1/RuleTree/stumps/splitters/ObliqueBivariateSplit.py +117 -0
  28. ruletree-0.0.2.post1/RuleTree/stumps/splitters/ObliqueHouseHolderSplit.py +92 -0
  29. ruletree-0.0.2.post1/RuleTree/stumps/splitters/ObliquePivotSplit.py +27 -0
  30. ruletree-0.0.2.post1/RuleTree/stumps/splitters/PivotSplit.py +102 -0
  31. ruletree-0.0.2.post1/RuleTree/stumps/splitters/__init__.py +0 -0
  32. ruletree-0.0.2.post1/RuleTree/tree/RuleTree.py +654 -0
  33. ruletree-0.0.2.post1/RuleTree/tree/RuleTreeClassifier.py +317 -0
  34. ruletree-0.0.2.post1/RuleTree/tree/RuleTreeCluster.py +179 -0
  35. ruletree-0.0.2.post1/RuleTree/tree/RuleTreeNode.py +182 -0
  36. ruletree-0.0.2.post1/RuleTree/tree/RuleTreeRegressor.py +160 -0
  37. ruletree-0.0.2.post1/RuleTree/tree/__init__.py +0 -0
  38. ruletree-0.0.2.post1/RuleTree/utils/__init__.py +2 -0
  39. ruletree-0.0.2.post1/RuleTree/utils/bic_estimator.py +210 -0
  40. ruletree-0.0.2.post1/RuleTree/utils/data_utils.py +234 -0
  41. ruletree-0.0.2.post1/RuleTree/utils/define.py +17 -0
  42. ruletree-0.0.2.post1/RuleTree/utils/dict_ruletree_encoding.py +46 -0
  43. ruletree-0.0.2.post1/RuleTree/utils/light_famd/__init__.py +7 -0
  44. ruletree-0.0.2.post1/RuleTree/utils/light_famd/ca.py +109 -0
  45. ruletree-0.0.2.post1/RuleTree/utils/light_famd/famd.py +78 -0
  46. ruletree-0.0.2.post1/RuleTree/utils/light_famd/mca.py +46 -0
  47. ruletree-0.0.2.post1/RuleTree/utils/light_famd/mfa.py +199 -0
  48. ruletree-0.0.2.post1/RuleTree/utils/light_famd/one_hot.py +37 -0
  49. ruletree-0.0.2.post1/RuleTree/utils/light_famd/pca.py +164 -0
  50. ruletree-0.0.2.post1/RuleTree/utils/light_famd/svd.py +37 -0
  51. ruletree-0.0.2.post1/RuleTree/utils/light_famd/util.py +32 -0
  52. ruletree-0.0.2.post1/RuleTree/utils/shapelet_transform/Shapelets.py +213 -0
  53. ruletree-0.0.2.post1/RuleTree/utils/shapelet_transform/__init__.py +0 -0
  54. ruletree-0.0.2.post1/RuleTree/utils/shapelet_transform/matrix_to_vector_distances.py +42 -0
  55. ruletree-0.0.2.post1/RuleTree/utils/utils_decoding.py +81 -0
  56. ruletree-0.0.2.post1/requirements.txt +16 -0
  57. ruletree-0.0.2.post1/ruletree.egg-info/PKG-INFO +40 -0
  58. ruletree-0.0.2.post1/ruletree.egg-info/SOURCES.txt +67 -0
  59. ruletree-0.0.2.post1/ruletree.egg-info/dependency_links.txt +1 -0
  60. ruletree-0.0.2.post1/ruletree.egg-info/requires.txt +18 -0
  61. ruletree-0.0.2.post1/ruletree.egg-info/top_level.txt +1 -0
  62. ruletree-0.0.2.post1/setup.cfg +6 -0
  63. ruletree-0.0.2.post1/setup.py +62 -0
ruletree-0.0.2.post1/LICENSE
@@ -0,0 +1,23 @@
+ Copyright (c) 2016, PackageOwner
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
ruletree-0.0.2.post1/MANIFEST.in
@@ -0,0 +1,3 @@
+ include LICENSE
+ include requirements.txt
+ include README.md
ruletree-0.0.2.post1/PKG-INFO
@@ -0,0 +1,40 @@
+ Metadata-Version: 2.1
+ Name: ruletree
+ Version: 0.0.2.post1
+ Summary: Package description
+ Home-page: https://github.com/riccotti/RuleTree
+ Author: Cristiano Landi
+ Author-email: cristiano.landi@phd.unipi.it
+ License: BSD-Clause-2
+ Keywords: keyword1 keyword2 keyword3
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Topic :: Software Development :: Build Tools
+ Classifier: License :: OSI Approved :: BSD License
+ Classifier: Operating System :: POSIX :: Other
+ Classifier: Operating System :: MacOS
+ Classifier: Programming Language :: Python
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.12.0
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy<2.0.0
+ Requires-Dist: scikit-learn>=1.5.0
+ Requires-Dist: scikit-learn-extra
+ Requires-Dist: scipy
+ Requires-Dist: pandas
+ Requires-Dist: category_encoders
+ Requires-Dist: threadpoolctl
+ Requires-Dist: tqdm
+ Requires-Dist: progress-table
+ Requires-Dist: pygraphviz
+ Requires-Dist: graphviz>=0.20.3
+ Requires-Dist: numba
+ Requires-Dist: psutil
+ Requires-Dist: setuptools
+ Requires-Dist: matplotlib
+ Requires-Dist: tempfile312
+ Provides-Extra: flag
+
+ # RuleTree
+ TODO
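The metadata above pins numpy below 2.0, requires Python 3.12 or newer, and pulls in graph-drawing dependencies. Installing the exact version shown here is just pip install "ruletree==0.0.2.post1", assuming the ruletree name from the Name field is the one used on the registry; note that pygraphviz generally needs the system Graphviz libraries to be present before pip can build it.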
ruletree-0.0.2.post1/README.md
@@ -0,0 +1,2 @@
+ # RuleTree
+ TODO
ruletree-0.0.2.post1/RuleTree/__init__.py
@@ -0,0 +1,9 @@
+ from RuleTree.tree.RuleTreeRegressor import RuleTreeRegressor
+ from RuleTree.tree.RuleTreeClassifier import RuleTreeClassifier
+ from RuleTree.tree.RuleTreeCluster import RuleTreeCluster, RuleTreeClusterRegressor, RuleTreeClusterClassifier
+
+ from RuleTree.ensemble.RuleForestRegressor import RuleForestRegressor
+ from RuleTree.ensemble.RuleForestClassifier import RuleForestClassifier
+ from RuleTree.ensemble.RuleTreeAdaBoostRegressor import RuleTreeAdaBoostRegressor
+ from RuleTree.ensemble.RuleTreeAdaBoostClassifier import RuleTreeAdaBoostClassifier
+
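The top-level __init__.py re-exports every estimator, so users import from the package root rather than the submodules. A minimal usage sketch on synthetic data follows; it assumes only the scikit-learn fit/predict interface these classes inherit and the max_depth parameter visible in the constructors later in this diff:

import numpy as np
from RuleTree import RuleTreeClassifier

# Synthetic tabular data: 100 samples, 4 features, binary target.
X = np.random.rand(100, 4)
y = (X[:, 0] > 0.5).astype(int)

clf = RuleTreeClassifier(max_depth=3)
clf.fit(X, y)                  # scikit-learn convention: fit returns the estimator
print(clf.predict(X[:5]))      # array of 0/1 predictions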
ruletree-0.0.2.post1/RuleTree/base/RuleTreeBase.py
@@ -0,0 +1,7 @@
+ from abc import ABC
+
+ from sklearn.base import BaseEstimator
+
+
+ class RuleTreeBase(BaseEstimator, ABC):
+     pass
ruletree-0.0.2.post1/RuleTree/base/RuleTreeBaseSplit.py
@@ -0,0 +1,7 @@
+ from abc import ABC, abstractmethod
+
+
+ class RuleTreeBaseSplit(ABC):
+     @abstractmethod
+     def __init__(self, ml_task):
+         self.ml_task = ml_task
ruletree-0.0.2.post1/RuleTree/base/RuleTreeBaseStump.py
@@ -0,0 +1,24 @@
+ from abc import ABC, abstractmethod
+
+ from sklearn.base import BaseEstimator
+
+ from RuleTree.utils.define import DATA_TYPE_TABULAR
+
+
+ class RuleTreeBaseStump(BaseEstimator, ABC):
+     @abstractmethod
+     def get_rule(self, columns_names=None, scaler=None, float_precision: int | None = 3):
+         pass
+
+     @abstractmethod
+     def node_to_dict(self):
+         pass
+
+     @classmethod
+     @abstractmethod
+     def dict_to_node(cls, node_dict, X):
+         pass
+
+     @staticmethod
+     def supports(data_type):
+         return data_type in [DATA_TYPE_TABULAR]
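The three abstract methods above define the rule-extraction and serialization contract every stump must honor. A hypothetical implementation sketch follows; ThresholdStump and its attributes are illustrative, not part of the package, and only the method signatures come from the base class:

from RuleTree.base.RuleTreeBaseStump import RuleTreeBaseStump

class ThresholdStump(RuleTreeBaseStump):
    """Illustrative axis-aligned stump: X[:, feature] <= threshold."""

    def __init__(self, feature=0, threshold=0.0):
        self.feature = feature
        self.threshold = threshold

    def get_rule(self, columns_names=None, scaler=None, float_precision: int | None = 3):
        # Human-readable rule, optionally using column names and rounding.
        name = columns_names[self.feature] if columns_names is not None else f"X[{self.feature}]"
        thr = round(self.threshold, float_precision) if float_precision is not None else self.threshold
        return f"{name} <= {thr}"

    def node_to_dict(self):
        # Serialize the split so a tree can be stored as a plain dict.
        return {"feature": self.feature, "threshold": self.threshold}

    @classmethod
    def dict_to_node(cls, node_dict, X):
        # Inverse of node_to_dict; X is available for stumps that need data context.
        return cls(node_dict["feature"], node_dict["threshold"])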
ruletree-0.0.2.post1/RuleTree/base/__init__.py (File without changes)
ruletree-0.0.2.post1/RuleTree/ensemble/RuleForestClassifier.py
@@ -0,0 +1,233 @@
+ from random import random
+ import numpy as np
+ import sklearn.base
+ from sklearn.ensemble import BaggingClassifier
+ from RuleTree.utils.data_utils import _iterative_mean
+
+ from RuleTree import RuleTreeClassifier
+ from RuleTree.base.RuleTreeBase import RuleTreeBase
+ from sklearn.base import ClassifierMixin
+
+ from RuleTree.stumps.classification.MultiplePivotTreeStumpClassifier import MultiplePivotTreeStumpClassifier
+ from RuleTree.stumps.classification.ObliquePivotTreeStumpClassifier import ObliquePivotTreeStumpClassifier
+ from RuleTree.stumps.classification.PivotTreeStumpClassifier import PivotTreeStumpClassifier
+
+
+ class RuleForestClassifier(BaggingClassifier, RuleTreeBase):
+     def __init__(self,
+                  n_estimators=100,
+                  criterion='gini',
+                  max_depth=None,
+                  min_samples_split=2,
+                  min_samples_leaf=1,
+                  min_weight_fraction_leaf=0.0,
+                  min_impurity_decrease=0.0,
+                  max_leaf_nodes=float("inf"),
+                  class_weight=None,
+                  ccp_alpha=0.0,
+                  prune_useless_leaves=False,
+                  splitter='best',
+                  *,
+                  max_samples=None,
+                  max_features=1.0,
+                  bootstrap=True,
+                  oob_score=False,
+                  warm_start=False,
+                  custom_estimator: sklearn.base.ClassifierMixin = None,
+                  n_jobs=None,
+                  random_state=None,
+                  base_stump=None,
+                  distance_matrix=None,
+                  distance_measure=None,
+                  stump_selection='best',
+                  verbose=0):
+
+         self.n_estimators = n_estimators
+         self.criterion = criterion
+         self.max_depth = max_depth
+         self.min_samples_split = min_samples_split
+         self.min_samples_leaf = min_samples_leaf
+         self.min_weight_fraction_leaf = min_weight_fraction_leaf
+         self.min_impurity_decrease = min_impurity_decrease
+         self.max_leaf_nodes = max_leaf_nodes
+         self.class_weight = class_weight
+         self.ccp_alpha = ccp_alpha
+         self.prune_useless_leaves = prune_useless_leaves
+         self.splitter = splitter
+         self.max_samples = max_samples
+         self.max_features = max_features
+         self.bootstrap = bootstrap
+         self.oob_score = oob_score
+         self.warm_start = warm_start
+         self.custom_estimator = custom_estimator
+         self.n_jobs = n_jobs
+         self.random_state = random_state
+         self.verbose = verbose
+         self.base_stump = base_stump
+         self.distance_matrix = distance_matrix
+         self.distance_measure = distance_measure
+         self.stump_selection = stump_selection
+
+     def fit(self, X: np.ndarray, y: np.ndarray, sample_weight=None, **kwargs):
+         if self.max_features is None:
+             self.max_features = X.shape[1]
+
+         if type(self.max_features) is str:
+             if self.max_features == "sqrt":
+                 self.max_features = int(np.sqrt(X.shape[1]))
+             elif self.max_features == "log2":
+                 self.max_features = int(np.log2(X.shape[1]))
+
+         base_estimator = RuleTreeClassifier if self.custom_estimator is None else self.custom_estimator
+         splitter = 0.5 if self.splitter == 'hybrid_forest' else self.splitter
+         if type(splitter) is float:
+             base_estimator = RuleTreeClassifier_choosing_splitter_randomly
+
+         if self.base_stump is not None:
+             if any(isinstance(clf, (PivotTreeStumpClassifier, ObliquePivotTreeStumpClassifier, MultiplePivotTreeStumpClassifier)) for clf in self.base_stump):
+                 base_estimator = ForestEstimatorPivotClassifier
+
+         super().__init__(estimator=base_estimator(criterion=self.criterion,
+                                                   max_depth=self.max_depth,
+                                                   min_samples_split=self.min_samples_split,
+                                                   min_samples_leaf=self.min_samples_leaf,
+                                                   min_weight_fraction_leaf=self.min_weight_fraction_leaf,
+                                                   min_impurity_decrease=self.min_impurity_decrease,
+                                                   random_state=self.random_state,
+                                                   max_leaf_nodes=self.max_leaf_nodes,
+                                                   class_weight=self.class_weight,
+                                                   ccp_alpha=self.ccp_alpha,
+                                                   prune_useless_leaves=self.prune_useless_leaves,
+                                                   splitter=self.splitter,
+                                                   base_stump=self.base_stump,
+                                                   distance_measure=self.distance_measure,
+                                                   distance_matrix=self.distance_matrix,
+                                                   stump_selection=self.stump_selection),
+                          n_estimators=self.n_estimators,
+                          max_samples=X.shape[0] if self.max_samples is None else self.max_samples,
+                          max_features=self.max_features,
+                          bootstrap=self.bootstrap,
+                          bootstrap_features=True,
+                          oob_score=self.oob_score,
+                          warm_start=self.warm_start,
+                          n_jobs=self.n_jobs,
+                          random_state=self.random_state,
+                          verbose=self.verbose)
+
+         return super().fit(X, y, sample_weight=sample_weight, **kwargs)
+
+     def local_interpretation(self, X, joint_contribution=False):
+         if joint_contribution:
+             biases = []
+             contributions = []
+             predictions = []
+
+             for tree in self.estimators_:
+                 pred, bias, contribution = tree.local_interpretation(X, joint_contribution=joint_contribution)
+                 biases.append(bias)
+                 contributions.append(contribution)
+                 predictions.append(pred)
+
+             total_contributions = []
+
+             for i in range(len(X)):
+                 contr = {}
+                 for j, dct in enumerate(contributions):
+                     for k in set(dct[i]).union(set(contr.keys())):
+                         # Running mean of each contribution across the j+1 trees seen so far.
+                         contr[k] = (contr.get(k, 0) * j + dct[i].get(k, 0)) / (j + 1)
+
+                 total_contributions.append(contr)
+
+             return (np.mean(predictions, axis=0), np.mean(biases, axis=0),
+                     total_contributions)
+         else:
+             mean_pred = None
+             mean_bias = None
+             mean_contribution = None
+
+             for i, tree in enumerate(self.estimators_):
+                 pred, bias, contribution = tree.local_interpretation(X)
+
+                 if i < 1:  # first iteration
+                     mean_bias = bias
+                     mean_contribution = contribution
+                     mean_pred = pred
+                 else:
+                     mean_bias = _iterative_mean(i, mean_bias, bias)
+                     mean_contribution = _iterative_mean(i, mean_contribution, contribution)
+                     mean_pred = _iterative_mean(i, mean_pred, pred)
+
+             return mean_pred, mean_bias, mean_contribution
+
+
+ class RuleTreeClassifier_choosing_splitter_randomly(RuleTreeClassifier):
+     def __init__(self, splitter, **kwargs):
+         if random() < splitter:
+             splitter = 'random'
+         else:
+             splitter = 'best'
+         kwargs["splitter"] = splitter
+         super().__init__(**kwargs)
+
+
+ class ForestEstimatorPivotClassifier(RuleTreeClassifier):
+     def __init__(self,
+                  max_leaf_nodes=float('inf'),
+                  min_samples_split=2,
+                  max_depth=float('inf'),
+                  prune_useless_leaves=False,
+                  base_stump: ClassifierMixin | list = None,
+                  stump_selection: str = 'random',
+                  random_state=None,
+                  criterion='gini',
+                  splitter='best',
+                  min_samples_leaf=1,
+                  min_weight_fraction_leaf=0.0,
+                  max_features=None,
+                  min_impurity_decrease=0.0,
+                  class_weight=None,
+                  ccp_alpha=0.0,
+                  monotonic_cst=None,
+                  distance_matrix=None,
+                  distance_measure=None):
+
+         super().__init__(max_leaf_nodes=max_leaf_nodes,
+                          min_samples_split=min_samples_split,
+                          max_depth=max_depth,
+                          prune_useless_leaves=prune_useless_leaves,
+                          base_stump=base_stump,
+                          stump_selection=stump_selection,
+                          random_state=random_state)
+
+         self.max_depth = max_depth
+         self.criterion = criterion
+         self.splitter = splitter
+         self.min_samples_split = min_samples_split
+         self.min_samples_leaf = min_samples_leaf
+         self.min_weight_fraction_leaf = min_weight_fraction_leaf
+         self.max_features = max_features
+         self.random_state = random_state
+         self.min_impurity_decrease = min_impurity_decrease
+         self.class_weight = class_weight
+         self.ccp_alpha = ccp_alpha
+         self.monotonic_cst = monotonic_cst
+         self.distance_matrix = distance_matrix
+         self.distance_measure = distance_measure
+
+     def fit(self, X: np.ndarray, y: np.ndarray = None, **kwargs):
+         return super().fit(X, y, **kwargs)
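local_interpretation decomposes each prediction into a bias term plus per-feature contributions and averages them across the bagged trees (a running mean via _iterative_mean, so memory stays constant in the number of trees). A usage sketch with synthetic data; the (prediction, bias, contributions) triple mirrors the method above, while the exact array shapes depend on RuleTreeClassifier internals not shown in this diff:

import numpy as np
from RuleTree import RuleForestClassifier

X = np.random.rand(200, 6)
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)

forest = RuleForestClassifier(n_estimators=10, max_depth=3, random_state=42)
forest.fit(X, y)

pred, bias, contrib = forest.local_interpretation(X[:3])
# For each sample, prediction is approximately bias plus the sum of its
# feature contributions, averaged over the ensemble.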
ruletree-0.0.2.post1/RuleTree/ensemble/RuleForestRegressor.py
@@ -0,0 +1,103 @@
+ from random import random
+
+ import numpy as np
+ import sklearn.base
+ from sklearn.ensemble import BaggingRegressor
+
+ from RuleTree import RuleTreeRegressor
+ from RuleTree.base.RuleTreeBase import RuleTreeBase
+
+
+ class RuleForestRegressor(BaggingRegressor, RuleTreeBase):
+     def __init__(self,
+                  n_estimators=100,
+                  criterion='squared_error',
+                  max_depth=None,
+                  min_samples_split=2,
+                  min_samples_leaf=1,
+                  min_weight_fraction_leaf=0.0,
+                  min_impurity_decrease=0.0,
+                  max_leaf_nodes=float("inf"),
+                  ccp_alpha=0.0,
+                  prune_useless_leaves=False,
+                  splitter='best',
+                  *,
+                  max_samples=None,
+                  max_features=1.0,
+                  bootstrap=True,
+                  oob_score=False,
+                  warm_start=False,
+                  custom_estimator: sklearn.base.RegressorMixin = None,
+                  n_jobs=None,
+                  random_state=None,
+                  verbose=0):
+         self.n_estimators = n_estimators
+         self.criterion = criterion
+         self.max_depth = max_depth
+         self.min_samples_split = min_samples_split
+         self.min_samples_leaf = min_samples_leaf
+         self.min_weight_fraction_leaf = min_weight_fraction_leaf
+         self.min_impurity_decrease = min_impurity_decrease
+         self.max_leaf_nodes = max_leaf_nodes
+         self.ccp_alpha = ccp_alpha
+         self.prune_useless_leaves = prune_useless_leaves
+         self.splitter = splitter
+
+         self.max_samples = max_samples
+         self.max_features = max_features
+         self.bootstrap = bootstrap
+         self.oob_score = oob_score
+         self.warm_start = warm_start
+         self.custom_estimator = custom_estimator
+         self.n_jobs = n_jobs
+         self.random_state = random_state
+         self.verbose = verbose
+
+     def fit(self, X: np.ndarray, y: np.ndarray, sample_weight=None):
+         if self.max_features is None:
+             self.max_features = X.shape[1]
+
+         if type(self.max_features) is str:
+             if self.max_features == "sqrt":
+                 self.max_features = int(np.sqrt(X.shape[1]))
+             elif self.max_features == "log2":
+                 self.max_features = int(np.log2(X.shape[1]))
+
+         base_estimator = RuleTreeRegressor if self.custom_estimator is None else self.custom_estimator
+         splitter = 0.5 if self.splitter == 'hybrid_forest' else self.splitter
+         if type(splitter) is float:
+             base_estimator = RuleTreeRegressor_choosing_splitter_randomly
+
+         super().__init__(estimator=base_estimator(criterion=self.criterion,
+                                                   max_depth=self.max_depth,
+                                                   min_samples_split=self.min_samples_split,
+                                                   min_samples_leaf=self.min_samples_leaf,
+                                                   min_weight_fraction_leaf=self.min_weight_fraction_leaf,
+                                                   min_impurity_decrease=self.min_impurity_decrease,
+                                                   max_leaf_nodes=self.max_leaf_nodes,
+                                                   ccp_alpha=self.ccp_alpha,
+                                                   prune_useless_leaves=self.prune_useless_leaves,
+                                                   splitter=self.splitter),
+                          n_estimators=self.n_estimators,
+                          max_samples=X.shape[0] if self.max_samples is None else self.max_samples,
+                          max_features=self.max_features,
+                          bootstrap=self.bootstrap,
+                          bootstrap_features=True,
+                          oob_score=self.oob_score,
+                          warm_start=self.warm_start,
+                          n_jobs=self.n_jobs,
+                          random_state=self.random_state,
+                          verbose=self.verbose)
+
+         return super().fit(X, y, sample_weight=sample_weight)
+
+
+ class RuleTreeRegressor_choosing_splitter_randomly(RuleTreeRegressor):
+     def __init__(self, splitter, **kwargs):
+         if random() < splitter:
+             splitter = 'random'
+         else:
+             splitter = 'best'
+         kwargs["splitter"] = splitter
+         super().__init__(**kwargs)
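Both forest classes accept splitter='hybrid_forest', which fit converts to the probability 0.5; the *_choosing_splitter_randomly subclasses then let every base tree independently draw 'random' or 'best'. A standalone sketch of that mapping (resolve_splitter is illustrative, not a package function):

from random import random

def resolve_splitter(splitter):
    # 'hybrid_forest' becomes a coin flip; any float p picks 'random' with probability p.
    p = 0.5 if splitter == 'hybrid_forest' else splitter
    if isinstance(p, float):
        return 'random' if random() < p else 'best'
    return p  # already 'best' or 'random'

print([resolve_splitter('hybrid_forest') for _ in range(5)])  # e.g. ['best', 'random', ...]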
ruletree-0.0.2.post1/RuleTree/ensemble/RuleTreeAdaBoostClassifier.py
@@ -0,0 +1,61 @@
+ from sklearn.ensemble import AdaBoostClassifier
+
+ from RuleTree import RuleTreeClassifier
+ from RuleTree.base.RuleTreeBase import RuleTreeBase
+
+
+ class RuleTreeAdaBoostClassifier(AdaBoostClassifier, RuleTreeBase):
+     def __init__(self,
+                  n_estimators=50,
+                  min_samples_split=2,
+                  prune_useless_leaves=False,
+                  random_state=None,
+                  criterion='gini',
+                  splitter='best',
+                  min_samples_leaf=1,
+                  min_weight_fraction_leaf=0.0,
+                  max_features=None,
+                  min_impurity_decrease=0.0,
+                  class_weight=None,
+                  ccp_alpha=0.0,
+                  monotonic_cst=None,
+                  *,
+                  learning_rate=1.0,
+                  algorithm='SAMME'):
+         self.min_samples_split = min_samples_split
+         self.prune_useless_leaves = prune_useless_leaves
+         self.random_state = random_state
+         self.criterion = criterion
+         self.splitter = splitter
+         self.min_samples_leaf = min_samples_leaf
+         self.min_weight_fraction_leaf = min_weight_fraction_leaf
+         self.max_features = max_features
+         self.min_impurity_decrease = min_impurity_decrease
+         self.class_weight = class_weight
+         self.ccp_alpha = ccp_alpha
+         self.monotonic_cst = monotonic_cst
+         self.n_estimators = n_estimators
+         self.learning_rate = learning_rate
+         self.algorithm = algorithm
+
+         estimator = RuleTreeClassifier(min_samples_split=min_samples_split,
+                                        max_depth=1,  # stump
+                                        prune_useless_leaves=prune_useless_leaves,
+                                        random_state=random_state,
+                                        criterion=criterion,
+                                        splitter=splitter,
+                                        min_samples_leaf=min_samples_leaf,
+                                        min_weight_fraction_leaf=min_weight_fraction_leaf,
+                                        max_features=max_features,
+                                        min_impurity_decrease=min_impurity_decrease,
+                                        class_weight=class_weight,
+                                        ccp_alpha=ccp_alpha,
+                                        monotonic_cst=monotonic_cst)
+
+         super().__init__(
+             estimator=estimator,
+             n_estimators=n_estimators, learning_rate=learning_rate, algorithm=algorithm, random_state=random_state
+         )
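A usage sketch for the boosted classifier (synthetic data; only the constructor parameters shown above and the score method inherited from sklearn's AdaBoostClassifier are assumed):

import numpy as np
from RuleTree import RuleTreeAdaBoostClassifier

X = np.random.rand(150, 5)
y = (X[:, 2] > 0.5).astype(int)

ada = RuleTreeAdaBoostClassifier(n_estimators=25, learning_rate=0.5, random_state=0)
ada.fit(X, y)
print(ada.score(X, y))  # mean accuracy, inherited from AdaBoostClassifier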
ruletree-0.0.2.post1/RuleTree/ensemble/RuleTreeAdaBoostRegressor.py
@@ -0,0 +1,54 @@
+ from sklearn.ensemble import AdaBoostRegressor
+
+ from RuleTree import RuleTreeRegressor
+ from RuleTree.base.RuleTreeBase import RuleTreeBase
+
+
+ class RuleTreeAdaBoostRegressor(AdaBoostRegressor, RuleTreeBase):
+     def __init__(self,
+                  n_estimators=50,
+                  min_samples_split=2,
+                  prune_useless_leaves=False,
+                  random_state=None,
+                  criterion='squared_error',
+                  splitter='best',
+                  min_samples_leaf=1,
+                  min_weight_fraction_leaf=0.0,
+                  max_features=None,
+                  min_impurity_decrease=0.0,
+                  ccp_alpha=0.0,
+                  monotonic_cst=None,
+                  *,
+                  learning_rate=1.0,
+                  loss='linear'):
+         self.min_samples_split = min_samples_split
+         self.prune_useless_leaves = prune_useless_leaves
+         self.random_state = random_state
+         self.criterion = criterion
+         self.splitter = splitter
+         self.min_samples_leaf = min_samples_leaf
+         self.min_weight_fraction_leaf = min_weight_fraction_leaf
+         self.max_features = max_features
+         self.min_impurity_decrease = min_impurity_decrease
+         self.ccp_alpha = ccp_alpha
+         self.monotonic_cst = monotonic_cst
+         self.n_estimators = n_estimators
+         self.learning_rate = learning_rate
+         self.loss = loss
+
+         super().__init__(
+             estimator=RuleTreeRegressor(max_depth=1,  # stump
+                                         prune_useless_leaves=prune_useless_leaves,
+                                         random_state=random_state,
+                                         criterion=criterion,
+                                         splitter=splitter,
+                                         min_samples_leaf=min_samples_leaf,
+                                         min_weight_fraction_leaf=min_weight_fraction_leaf,
+                                         max_features=max_features,
+                                         min_impurity_decrease=min_impurity_decrease,
+                                         ccp_alpha=ccp_alpha,
+                                         monotonic_cst=monotonic_cst),
+             n_estimators=n_estimators, learning_rate=learning_rate, loss=loss, random_state=random_state
+         )
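The regression counterpart works the same way, boosting depth-1 RuleTree stumps under the chosen loss. A matching sketch (synthetic data; predict is inherited from sklearn's AdaBoostRegressor):

import numpy as np
from RuleTree import RuleTreeAdaBoostRegressor

X = np.random.rand(150, 3)
y = 2.0 * X[:, 0] + np.random.normal(scale=0.1, size=150)

ada = RuleTreeAdaBoostRegressor(n_estimators=25, loss='linear', random_state=0)
ada.fit(X, y)
print(ada.predict(X[:3]))  # continuous predictions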
ruletree-0.0.2.post1/RuleTree/ensemble/__init__.py (File without changes)
ruletree-0.0.2.post1/RuleTree/stumps/__init__.py (File without changes)