obliquetree 1.0.1__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of obliquetree might be problematic. Click here for more details.
- obliquetree/__init__.py +3 -0
- obliquetree/_pywrap.py +730 -0
- obliquetree/src/__init__.py +0 -0
- obliquetree/src/base.cpp +36716 -0
- obliquetree/src/base.cpython-311-darwin.so +0 -0
- obliquetree/src/ccp.cpp +25779 -0
- obliquetree/src/ccp.cpython-311-darwin.so +0 -0
- obliquetree/src/metric.cpp +31844 -0
- obliquetree/src/metric.cpython-311-darwin.so +0 -0
- obliquetree/src/oblique.cpp +35180 -0
- obliquetree/src/oblique.cpython-311-darwin.so +0 -0
- obliquetree/src/tree.cpp +30405 -0
- obliquetree/src/tree.cpython-311-darwin.so +0 -0
- obliquetree/src/utils.cpp +32804 -0
- obliquetree/src/utils.cpython-311-darwin.so +0 -0
- obliquetree/utils.py +526 -0
- obliquetree-1.0.1.dist-info/LICENSE +21 -0
- obliquetree-1.0.1.dist-info/METADATA +113 -0
- obliquetree-1.0.1.dist-info/RECORD +21 -0
- obliquetree-1.0.1.dist-info/WHEEL +5 -0
- obliquetree-1.0.1.dist-info/top_level.txt +1 -0
|
Binary file
|
obliquetree/utils.py
ADDED
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .src.utils import export_tree as _export_tree
|
|
4
|
+
from ._pywrap import BaseTree, Classifier, Regressor
|
|
5
|
+
import json
|
|
6
|
+
from typing import Optional, Dict, Any, List, Union
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_tree(tree_data: Union[str, Dict]) -> Union[Classifier, Regressor]:
    """
    Load a decision tree model from a JSON file or dictionary representation.

    This function reconstructs a trained decision tree from its serialized form,
    either from a JSON file on disk or a dictionary containing the tree structure
    and parameters.

    Parameters
    ----------
    tree_data : Union[str, Dict]
        Either:
        - A string containing the file path to a JSON file containing the tree data
        - A dictionary containing the serialized tree structure and parameters.
          The dictionary is not modified by this function.

    Returns
    -------
    Union[Classifier, Regressor]
        A reconstructed decision tree object. The specific type is determined by
        the 'task' entry of ``tree_data["params"]``: a falsy value yields a
        Classifier, a truthy value yields a Regressor.

    Raises
    ------
    FileNotFoundError
        If `tree_data` is a string and no file exists at that path.
    ValueError
        If `tree_data` is neither a string nor a dictionary, or if the loaded
        data is not a dictionary holding a "params" dictionary with a "task" key.
    """
    # Handle input types
    if isinstance(tree_data, str):
        if not os.path.exists(tree_data):
            raise FileNotFoundError(f"The file {tree_data} does not exist")

        with open(tree_data, "r") as f:
            tree = json.load(f)
    elif isinstance(tree_data, dict):
        # Work on a shallow copy: "_fit" is added below and must not leak
        # into the dictionary owned by the caller.
        tree = dict(tree_data)
    else:
        raise ValueError("Input must be a JSON string, file path, or dictionary")

    # Validate tree structure. "params" must itself be a dictionary, otherwise
    # the membership test for "task" could raise TypeError instead of ValueError.
    if (
        not isinstance(tree, dict)
        or not isinstance(tree.get("params"), dict)
        or "task" not in tree["params"]
    ):
        raise ValueError("Invalid tree data structure")

    # Create appropriate object based on task. __new__ is used so that
    # __init__ is skipped; the state is supplied through __setstate__ instead.
    if not tree["params"]["task"]:
        obj = Classifier.__new__(Classifier)
    else:
        obj = Regressor.__new__(Regressor)

    # Mark the restored model as fitted before handing over the state.
    tree["_fit"] = True

    obj.__setstate__(tree)

    return obj
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def export_tree(
    tree: Union[Classifier, Regressor], out_file: str = None
) -> Union[None, dict]:
    """
    Serialize a fitted decision tree to a dictionary or to a JSON file.

    The dictionary is produced by the compiled ``_export_tree`` helper and is
    the format that ``load_tree()`` accepts.

    Parameters
    ----------
    tree : Union[Classifier, Regressor]
        The model to export. It must be an instance of `BaseTree` and must
        already be fitted.

    out_file : str, optional
        Path of the JSON file to write. When omitted (None), nothing is written
        and the dictionary is returned instead.

    Returns
    -------
    Union[None, dict]
        The serialized tree when `out_file` is None; otherwise None, after the
        tree has been written to `out_file` as JSON indented by 4 spaces.

    Raises
    ------
    ValueError
        If `tree` is not a `BaseTree`, if it has not been fitted, or if
        `out_file` is given but is not a string.
    """
    if not isinstance(tree, BaseTree):
        raise ValueError("`tree` must be an instance of `BaseTree`.")

    if not tree._fit:
        raise ValueError(
            "The tree has not been fitted yet. Please call the 'fit' method to train the tree before using this function."
        )

    serialized = _export_tree(tree)

    # No destination requested: hand the dictionary back to the caller.
    if out_file is None:
        return serialized

    if not isinstance(out_file, str):
        raise ValueError("`out_file` must be a string if provided.")

    with open(out_file, "w") as handle:
        json.dump(serialized, handle, indent=4)

    return None
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def visualize_tree(
    tree: Union[Classifier, Regressor],
    feature_names: Optional[List[str]] = None,
    max_cat: Optional[int] = None,
    max_oblique: Optional[int] = None,
    save_path: Optional[str] = None,
    dpi: int = 600,
    figsize: tuple = (20, 10),
    show_gini: bool = True,
    show_n_samples: bool = True,
    show_node_value: bool = True,
) -> None:
    """
    Generate a visual representation of a decision tree model.

    The tree is rendered to a PNG with graphviz, shown in a matplotlib figure
    and optionally saved to disk. Both graphviz and matplotlib are imported
    lazily, so they are only required when this function is called.

    Parameters
    ----------
    tree : Union[Classifier, Regressor]
        The trained decision tree model to visualize. Must be fitted before
        visualization.

    feature_names : List[str], optional
        Human-readable names for the features used in the tree. If provided,
        these names will be used in split conditions instead of generic feature
        indices (e.g., "age ≤ 30" instead of "f0 ≤ 30").

    max_cat : int, optional
        For categorical splits, limits the number of categories shown in the
        visualization. If there are more categories than this limit, they will
        be truncated with an ellipsis.

    max_oblique : int, optional
        For oblique splits (those involving multiple features), limits the number
        of features shown in the split condition.

    save_path : str, optional
        If provided, the figure is also saved to this path via matplotlib's
        ``savefig``.

    dpi : int, default=600
        The resolution (dots per inch) of the saved image. Only relevant if
        save_path is provided.

    figsize : tuple, default=(20, 10)
        The width and height of the figure in inches.

    show_gini : bool, default=True
        Whether to display the node impurity (labelled "impurity") when the
        node carries one.

    show_n_samples : bool, default=True
        Whether to display the number of samples of each node when available.

    show_node_value : bool, default=True
        Whether to display the predicted value/class distribution in internal
        nodes. Leaf nodes always display it.

    Returns
    -------
    None
        The function displays the visualization and optionally saves it to disk.

    Raises
    ------
    ValueError
        If any argument fails the checks in ``_check_visualize_tree_inputs``.
    ImportError
        If graphviz or matplotlib cannot be imported.
    """
    _check_visualize_tree_inputs(
        tree, feature_names, max_cat, max_oblique, save_path, dpi, figsize
    )

    try:
        from graphviz import Digraph
    except ImportError as exc:
        raise ImportError(
            "graphviz is not installed. Please install it to use this function."
        ) from exc

    try:
        from matplotlib.pyplot import figure, imshow, imread, axis, savefig, show
    except ImportError as exc:
        # Only import failures are translated; any other error (or a
        # KeyboardInterrupt) propagates unchanged instead of being reported
        # as a missing installation.
        raise ImportError(
            "matplotlib is not installed. Please install it to use this function."
        ) from exc

    tree_dict = _export_tree(tree)

    node, params = tree_dict["tree"], tree_dict["params"]

    def _visualize_recursive(node, graph=None, parent=None, edge_label=""):
        # The graph is created on the first (root) call and then shared by
        # every recursive call.
        if graph is None:
            graph = Digraph(format="png")
            graph.graph_attr.update(
                {
                    "rankdir": "TB",
                    "ranksep": "0.3",
                    "nodesep": "0.2",
                    "splines": "polyline",
                    "ordering": "out",
                }
            )
            graph.attr(
                "node",
                shape="box",
                style="filled",
                color="lightgrey",
                fontname="Helvetica",
                margin="0.2",
            )

        # The identity of the node dictionary serves as a unique graph id;
        # all node dictionaries stay alive through `tree_dict` while drawing.
        node_id = str(id(node))
        label_parts = []

        is_leaf = "left" not in node and "right" not in node

        if is_leaf:
            # For leaf nodes: the value is always shown.
            label_parts.append("leaf")
            label_parts.append(_format_value_str(node, params))

            # Add impurity for leaf nodes if requested and available
            if show_gini and "impurity" in node:
                label_parts.append(f"impurity: {_format_float(node['impurity'])}")

            # Add n_samples for leaf nodes if requested and available
            if show_n_samples and "n_samples" in node:
                label_parts.append(f"n_samples: {node['n_samples']}")

            graph.node(
                node_id,
                label="\n".join(label_parts),
                shape="box",
                style="filled",
                color="lightblue",
                fontname="Helvetica",
            )
        else:
            # For internal nodes: first add the split information
            if node.get("is_oblique", False):
                split_info = _create_oblique_expression(
                    node["features"],
                    node["weights"],
                    node["threshold"],
                    feature_names,
                    max_oblique,
                )
            elif "category_left" in node:
                categories = node["category_left"]
                cat_str = _format_categories(categories, max_cat)
                feature_label = (
                    feature_names[node["feature_idx"]]
                    if feature_names and node["feature_idx"] < len(feature_names)
                    else f"f{node['feature_idx']}"
                )
                split_info = f"{feature_label} in {cat_str}"
            else:
                threshold = (
                    _format_float(node["threshold"])
                    if isinstance(node["threshold"], float)
                    else node["threshold"]
                )
                feature_label = (
                    feature_names[node["feature_idx"]]
                    if feature_names and node["feature_idx"] < len(feature_names)
                    else f"f{node['feature_idx']}"
                )
                split_info = f"{feature_label} ≤ {threshold}"

            label_parts.append(split_info)

            if show_node_value:
                label_parts.append(_format_value_str(node, params))

            # Add impurity if requested
            if show_gini and "impurity" in node:
                label_parts.append(f"impurity: {_format_float(node['impurity'])}")

            # Add n_samples if requested
            if show_n_samples and "n_samples" in node:
                label_parts.append(f"n_samples: {node['n_samples']}")

            graph.node(
                node_id,
                label="\n".join(label_parts),
                shape="box",
                style="filled",
                color="lightgrey",
                fontname="Helvetica",
            )

        # Connect this node to its parent (the root has none).
        if parent is not None:
            graph.edge(
                parent,
                node_id,
                label=edge_label,
                fontname="Helvetica",
                penwidth="1.0",
                minlen="1",
            )

        if "left" in node:
            _visualize_recursive(node["left"], graph, node_id, "Left")
        if "right" in node:
            _visualize_recursive(node["right"], graph, node_id, "Right")

        return graph

    graph = _visualize_recursive(node)
    png_data = graph.pipe(format="png")

    # Show the rendered PNG inside a matplotlib figure without axes.
    figure(figsize=figsize)
    imshow(imread(BytesIO(png_data)))
    axis("off")

    if save_path:
        savefig(save_path, dpi=dpi, bbox_inches="tight", pad_inches=0)

    show()
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def _format_float(value: float) -> str:
|
|
335
|
+
"""Format float value with 3 decimal places, return '0' for 0.0"""
|
|
336
|
+
if value == 0.0:
|
|
337
|
+
return "0"
|
|
338
|
+
return f"{value:.2f}"
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def _format_value_str(node: Dict[str, Any], params: Dict[str, Any]) -> str:
    """
    Build the value line shown inside a node of the tree visualization.

    Parameters:
    -----------
    node : Dict[str, Any]
        The tree node dictionary. "value" is read for regression and binary
        classification; "values" is read for three or more classes.
    params : Dict[str, Any]
        Tree parameters; "task" is truthy for regression, and "n_classes"
        is consulted for classification.

    Returns:
    --------
    str
        "Value: <v>" for regression, otherwise "values: [<v0>, <v1>, ...]".
    """
    # Regression: a single predicted value.
    if params["task"]:
        raw = node["value"]
        shown = _format_float(raw) if isinstance(raw, float) else raw
        return f"Value: {shown}"

    class_count = params["n_classes"]

    # Binary classification: node["value"] is shown as the second entry and
    # its complement as the first (presumably the positive-class probability).
    if class_count == 2:
        second = node["value"]
        return f"values: [{_format_float(1-second)}, {_format_float(second)}]"

    # Multiclass: one entry per class, zeros when the node has no "values".
    pieces = []
    for entry in node.get("values", [0.0] * class_count):
        pieces.append(_format_float(entry) if isinstance(entry, float) else str(entry))
    joined = ", ".join(pieces)
    return f"values: [{joined}]"
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def _create_oblique_expression(
    features: list,
    weights: list,
    threshold: float,
    feature_names: Optional[List[str]],
    max_oblique: Optional[int] = None,
) -> str:
    """
    Create the textual expression of an oblique split, e.g.
    "(0.50 * f0) + f3 – (0.25 * f1) ≤ 1.20".

    Terms are ordered by decreasing absolute weight and the expression is
    wrapped onto a new line after every 5 terms. Each term after the first
    carries its own operator ("+" or "–"), so a wrapped line never ends with
    a dangling operator.

    Parameters:
    -----------
    features : list
        Feature indices taking part in the split.
    weights : list
        Weight of each feature, paired positionally with `features`.
    threshold : float
        Right-hand side of the split condition.
    feature_names : Optional[List[str]]
        Names used instead of "f<index>" for indices that fall inside the list.
    max_oblique : Optional[int]
        If given, only that many highest-weight terms are shown, followed by
        "..." when terms were dropped.

    Returns:
    --------
    str
        The formatted "<expression> ≤ <threshold>" string.
    """
    # Sort features and weights by absolute weight value
    feature_weight_pairs = sorted(
        zip(features, weights), key=lambda x: abs(x[1]), reverse=True
    )

    # Apply max_oblique limit if specified
    if max_oblique is not None:
        feature_weight_pairs = feature_weight_pairs[:max_oblique]
        if len(features) > max_oblique:
            # Placeholder entry rendered as an ellipsis below.
            feature_weight_pairs.append(("...", 0))

    # Create terms with proper formatting
    lines = []
    current_line = []

    for i, (f, w) in enumerate(feature_weight_pairs):
        if f == "...":
            current_line.append("...")
            continue

        feature_label = (
            feature_names[f] if feature_names and f < len(feature_names) else f"f{f}"
        )

        if w == 1.0:
            term = feature_label
        elif w == -1.0:
            term = f"–{feature_label}"
        else:
            formatted_weight = _format_float(abs(w))
            term = f"{'– ' if w < 0 else ''}({formatted_weight} * {feature_label})"

        # Negative terms already start with their own "–"; every other term
        # after the first (including zero weights) is joined with "+".
        if i > 0 and w >= 0:
            term = f"+ {term}"

        current_line.append(term)

        # Start new line after every 5 terms. No trailing operator is added:
        # the next term brings its own sign, which avoids "+ +" / "+ –".
        if len(current_line) == 5 and i < len(feature_weight_pairs) - 1:
            lines.append(" ".join(current_line))
            current_line = []

    if current_line:
        lines.append(" ".join(current_line))

    formatted_threshold = _format_float(threshold)
    expression = "\n".join(lines)
    return f"{expression} ≤ {formatted_threshold}"
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _format_categories(categories: list, max_cat: Optional[int] = None) -> str:
|
|
436
|
+
"""Format category list with line breaks after every 5 items"""
|
|
437
|
+
if max_cat is not None and len(categories) > max_cat:
|
|
438
|
+
shown_cats = categories[:max_cat]
|
|
439
|
+
return f"[{', '.join(map(str, shown_cats))}, ...]"
|
|
440
|
+
|
|
441
|
+
formatted_cats = []
|
|
442
|
+
current_line = []
|
|
443
|
+
|
|
444
|
+
for i, cat in enumerate(categories):
|
|
445
|
+
current_line.append(str(cat))
|
|
446
|
+
|
|
447
|
+
# Add line break after every 5 items or at the end
|
|
448
|
+
if len(current_line) == 9 and i < len(categories) - 1:
|
|
449
|
+
formatted_cats.append(", ".join(current_line) + ",")
|
|
450
|
+
current_line = []
|
|
451
|
+
|
|
452
|
+
if current_line:
|
|
453
|
+
formatted_cats.append(", ".join(current_line))
|
|
454
|
+
|
|
455
|
+
if len(formatted_cats) > 1:
|
|
456
|
+
return "[" + "\n".join(formatted_cats) + "]"
|
|
457
|
+
return f"[{formatted_cats[0]}]"
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def _check_visualize_tree_inputs(
    tree: BaseTree,
    feature_names: Optional[List[str]] = None,
    max_cat: Optional[int] = None,
    max_oblique: Optional[int] = None,
    save_path: Optional[str] = None,
    dpi: int = 600,
    figsize: tuple = (20, 10),
) -> None:
    """
    Validate the arguments of ``visualize_tree`` and raise on the first problem.

    Parameters:
    -----------
    tree : BaseTree
        Must be a `BaseTree` instance that has already been fitted.
    feature_names : Optional[List[str]]
        If provided, must be a list of strings whose length equals
        ``tree.n_features``.
    max_cat : Optional[int]
        If provided, must be a positive integer.
    max_oblique : Optional[int]
        If provided, must be a positive integer.
    save_path : Optional[str]
        If provided, must be a string (only the type is checked).
    dpi : int
        Must be a positive integer.
    figsize : tuple
        Must be a tuple of two positive numbers.

    Raises:
    -------
    ValueError
        If any of the conditions above is violated.
    """

    def _is_positive_int(candidate) -> bool:
        # Shared by the max_cat, max_oblique and dpi checks.
        return isinstance(candidate, int) and candidate > 0

    if not isinstance(tree, BaseTree):
        raise ValueError("`tree` must be an instance of `BaseTree`.")

    if not tree._fit:
        raise ValueError(
            "The tree has not been fitted yet. Please call the 'fit' method to train the tree before using this function."
        )

    if feature_names is not None:
        names_ok = isinstance(feature_names, list) and all(
            isinstance(name, str) for name in feature_names
        )
        if not names_ok:
            raise ValueError("feature_names must be a list of strings.")
        if len(feature_names) != tree.n_features:
            raise ValueError(
                f"feature_names must match the number of features in the tree ({tree.n_features})."
            )

    if max_cat is not None and not _is_positive_int(max_cat):
        raise ValueError("max_cat must be a positive integer.")

    if max_oblique is not None and not _is_positive_int(max_oblique):
        raise ValueError("max_oblique must be a positive integer.")

    if save_path is not None and not isinstance(save_path, str):
        raise ValueError("save_path must be a string representing a valid file path.")

    if not _is_positive_int(dpi):
        raise ValueError("dpi must be a positive integer.")

    figsize_ok = (
        isinstance(figsize, tuple)
        and len(figsize) == 2
        and all(isinstance(dim, (int, float)) and dim > 0 for dim in figsize)
    )
    if not figsize_ok:
        raise ValueError("figsize must be a tuple of two positive numbers.")
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Samet Çopur
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: obliquetree
|
|
3
|
+
Version: 1.0.1
|
|
4
|
+
Summary: Traditional and Oblique Decision Tree
|
|
5
|
+
Author-email: Samet Copur <sametcopur@yahoo.com>
|
|
6
|
+
License: MIT License
|
|
7
|
+
Project-URL: Documentation, https://obliquetree.readthedocs.io/en/latest/
|
|
8
|
+
Project-URL: Repository, https://github.com/sametcopur/obliquetree
|
|
9
|
+
Project-URL: Tracker, https://github.com/sametcopur/obliquetree/issues
|
|
10
|
+
Keywords: python,data-science,machine-learning,machine-learning-library,explainable-ai,decision-tree,oblique-tree
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Operating System :: OS Independent
|
|
14
|
+
Requires-Python: >=3.10
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: numpy>=2.2.1
|
|
18
|
+
Requires-Dist: scipy>=1.15.0
|
|
19
|
+
|
|
20
|
+
# obliquetree
|
|
21
|
+
|
|
22
|
+
`obliquetree` is an advanced decision tree implementation designed to provide high-performance and interpretable models. It supports both classification and regression tasks, enabling a wide range of applications. By offering traditional and oblique splits, it ensures flexibility and improved generalization with shallow trees. This makes it a powerful alternative to regular decision trees.
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+

|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
----
|
|
29
|
+
|
|
30
|
+
## Getting Started
|
|
31
|
+
|
|
32
|
+
`obliquetree` combines advanced capabilities with efficient performance. It supports **oblique splits**, leveraging **L-BFGS optimization** to determine the best linear weights for splits, ensuring both speed and accuracy.
|
|
33
|
+
|
|
34
|
+
In **traditional mode**, without oblique splits, `obliquetree` outperforms `scikit-learn` in terms of speed and adds support for **categorical variables**, providing a significant advantage over many traditional decision tree implementations.
|
|
35
|
+
|
|
36
|
+
When the **oblique feature** is enabled, `obliquetree` dynamically selects the optimal split type between oblique and traditional splits. If no weights can be found to reduce impurity, it defaults to an **axis-aligned split**, ensuring robustness and adaptability in various scenarios.
|
|
37
|
+
|
|
38
|
+
In very large trees (e.g., depth 10 or more), the performance of `obliquetree` may converge closely with that of **traditional trees**. The true strength of `obliquetree` lies in its ability to perform exceptionally well at **shallower depths**, offering improved generalization with fewer splits. Moreover, thanks to linear projections, `obliquetree` significantly outperforms traditional trees when working with datasets that exhibit **linear relationships**.
|
|
39
|
+
|
|
40
|
+
-----
|
|
41
|
+
## Installation
|
|
42
|
+
To install `obliquetree`, use the following pip command:
|
|
43
|
+
```bash
|
|
44
|
+
pip install obliquetree
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Using the `obliquetree` library is simple and intuitive. Here's a more generic example that works for both classification and regression:
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from obliquetree import Classifier, Regressor
|
|
52
|
+
|
|
53
|
+
# Initialize the model (Classifier or Regressor)
|
|
54
|
+
model = Classifier( # Replace "Classifier" with "Regressor" if performing regression
|
|
55
|
+
use_oblique=True, # Enable oblique splits
|
|
56
|
+
max_depth=2, # Set the maximum depth of the tree
|
|
57
|
+
n_pair=2, # Number of feature pairs for optimization
|
|
58
|
+
random_state=42, # Set a random state for reproducibility
|
|
59
|
+
categories=[0, 10, 32], # Specify which features are categorical
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# Train the model on the training dataset
|
|
63
|
+
model.fit(X_train, y_train)
|
|
64
|
+
|
|
65
|
+
# Predict on the test dataset
|
|
66
|
+
y_pred = model.predict(X_test)
|
|
67
|
+
```
|
|
68
|
+
-----
|
|
69
|
+
|
|
70
|
+
## Documentation
|
|
71
|
+
For example usage, API details, comparisons with axis-aligned trees, and in-depth insights into the algorithmic foundation, we **strongly recommend** referring to the full [documentation](https://obliquetree.readthedocs.io/en/latest/).
|
|
72
|
+
|
|
73
|
+
---
|
|
74
|
+
## Key Features
|
|
75
|
+
|
|
76
|
+
- **Oblique Splits**
|
|
77
|
+
Perform oblique splits using linear combinations of features to capture complex patterns in data. Supports both linear and soft decision tree objectives for flexible and accurate modeling.
|
|
78
|
+
|
|
79
|
+
- **Axis-Aligned Splits**
|
|
80
|
+
Offers conventional (axis-aligned) splits, enabling users to leverage standard decision tree behavior for simplicity and interpretability.
|
|
81
|
+
|
|
82
|
+
- **Feature Constraints**
|
|
83
|
+
Limit the number of features used in oblique splits with the `n_pair` parameter, promoting simpler, more interpretable tree structures while retaining predictive power.
|
|
84
|
+
|
|
85
|
+
- **Seamless Categorical Feature Handling**
|
|
86
|
+
Natively supports categorical columns with minimal preprocessing. Only label encoding is required, removing the need for extensive data transformation.
|
|
87
|
+
|
|
88
|
+
- **Robust Handling of Missing Values**
|
|
89
|
+
Automatically assigns `NaN` values to the optimal leaf for axis-aligned splits.
|
|
90
|
+
|
|
91
|
+
- **Customizable Tree Structures**
|
|
92
|
+
The flexible API empowers users to design their own tree architectures easily.
|
|
93
|
+
|
|
94
|
+
- **Exact Equivalence with `scikit-learn`**
|
|
95
|
+
Guarantees results identical to `scikit-learn`'s decision trees when oblique and categorical splitting are disabled.
|
|
96
|
+
|
|
97
|
+
- **Optimized Performance**
|
|
98
|
+
Outperforms `scikit-learn` in terms of speed and efficiency when oblique and categorical splitting are disabled:
|
|
99
|
+
- Up to **50% faster** for datasets with float columns.
|
|
100
|
+
- Up to **200% faster** for datasets with integer columns.
|
|
101
|
+
|
|
102
|
+

|
|
103
|
+
|
|
104
|
+

|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
----
|
|
108
|
+
### Contributing
|
|
109
|
+
Contributions are welcome! If you'd like to improve `obliquetree` or suggest new features, feel free to fork the repository and submit a pull request.
|
|
110
|
+
|
|
111
|
+
-----
|
|
112
|
+
### License
|
|
113
|
+
`obliquetree` is released under the MIT License. See the LICENSE file for more details.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
obliquetree-1.0.1.dist-info/RECORD,,
|
|
2
|
+
obliquetree-1.0.1.dist-info/LICENSE,sha256=lLpW3hLh8QbaRAWXRaZ76c80HoVfUr6739D0S7ecslU,1069
|
|
3
|
+
obliquetree-1.0.1.dist-info/WHEEL,sha256=NW1RskY9zow1Y68W-gXg0oZyBRAugI1JHywIzAIai5o,109
|
|
4
|
+
obliquetree-1.0.1.dist-info/top_level.txt,sha256=m-5N4-iAS5MsFOdk8y1r2ya_i5rVQBgPayHBA-K26qg,12
|
|
5
|
+
obliquetree-1.0.1.dist-info/METADATA,sha256=_6KQMjXzPh7zw-b0KVnGB9hWUziugEiy_ZUUgaZIfrU,5640
|
|
6
|
+
obliquetree/__init__.py,sha256=hER789IcI7j2GnsVNvUhdVUjPJemBfHd2ztSOZgRwW8,81
|
|
7
|
+
obliquetree/utils.py,sha256=coNk6PATedhTTBoapJ6rb0SI4TSfRYxI6bBnGAwAO_Y,17976
|
|
8
|
+
obliquetree/_pywrap.py,sha256=pzx9tUbzb_4Brw6UOgYFVMv9VIfupDkJBLh5m1tUTEo,25754
|
|
9
|
+
obliquetree/src/base.cpython-311-darwin.so,sha256=ktGj5MbnjyrABkNEwaZgO2-O9eKHG2gQanprBpuaW5k,252696
|
|
10
|
+
obliquetree/src/utils.cpp,sha256=7Q5r4Ko31co4Rf90RTnLzKUmw9fDmrkfjyXr0gGXNJ8,1227838
|
|
11
|
+
obliquetree/src/ccp.cpython-311-darwin.so,sha256=M7smNrFzBuXadpV7oZgkkeLM8G65f-05rwHKLIWsvOg,168776
|
|
12
|
+
obliquetree/src/oblique.cpp,sha256=rqdaDM2dqID-7X8MXdRMz3PMSP6QWddZteiFf1yqskQ,1360036
|
|
13
|
+
obliquetree/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
+
obliquetree/src/base.cpp,sha256=dXMTiOZ4TLupIdvV3-dXmqnpQ04JFTf5J4AIzmg_dKc,1427982
|
|
15
|
+
obliquetree/src/tree.cpp,sha256=u7Bni3IDrVuKOwAnrL1RJ2ZO5GBlCOKNV0VSZtBWNoc,1137319
|
|
16
|
+
obliquetree/src/oblique.cpython-311-darwin.so,sha256=L9cPcpxuUmK-jPJdPNJB-uNOxHpz-clj96h_Ept3jx4,274464
|
|
17
|
+
obliquetree/src/metric.cpp,sha256=a8EQXx-oG4zUaAGizOiETEbOiqmwMkdkfZsLu_Lq_Wo,1174065
|
|
18
|
+
obliquetree/src/utils.cpython-311-darwin.so,sha256=fR27zP7WI7M0w_8Liajyn185E5N-xZmlSU_Svwc4NsU,255832
|
|
19
|
+
obliquetree/src/ccp.cpp,sha256=EE1KaqH9Nt3jDojBcr6TRd4KjKv3LMQAAfkDkNGaYP4,944586
|
|
20
|
+
obliquetree/src/tree.cpython-311-darwin.so,sha256=YfiWiXDAiy8dw8A3ZCQqFlbHcc7RhiaW_0-ea7FYhL4,186344
|
|
21
|
+
obliquetree/src/metric.cpython-311-darwin.so,sha256=cylxsZDHXhDJ-tdXO2a_SLVcV2Pv1I-FMLgHSnzWSis,186528
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
obliquetree
|