obliquetree 1.0.3__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- obliquetree/__init__.py +4 -0
- obliquetree/_pywrap.py +747 -0
- obliquetree/src/__init__.py +0 -0
- obliquetree/src/base.cpp +36355 -0
- obliquetree/src/base.cpython-311-x86_64-linux-gnu.so +0 -0
- obliquetree/src/ccp.cpp +25678 -0
- obliquetree/src/ccp.cpython-311-x86_64-linux-gnu.so +0 -0
- obliquetree/src/metric.cpp +31853 -0
- obliquetree/src/metric.cpython-311-x86_64-linux-gnu.so +0 -0
- obliquetree/src/oblique.cpp +35284 -0
- obliquetree/src/oblique.cpython-311-x86_64-linux-gnu.so +0 -0
- obliquetree/src/tree.cpp +30338 -0
- obliquetree/src/tree.cpython-311-x86_64-linux-gnu.so +0 -0
- obliquetree/src/utils.cpp +32642 -0
- obliquetree/src/utils.cpython-311-x86_64-linux-gnu.so +0 -0
- obliquetree/utils.py +882 -0
- obliquetree-1.0.3.dist-info/METADATA +114 -0
- obliquetree-1.0.3.dist-info/RECORD +21 -0
- obliquetree-1.0.3.dist-info/WHEEL +6 -0
- obliquetree-1.0.3.dist-info/licenses/LICENSE +21 -0
- obliquetree-1.0.3.dist-info/top_level.txt +1 -0
obliquetree/utils.py
ADDED
@@ -0,0 +1,882 @@
from __future__ import annotations

from .src.utils import export_tree as _export_tree
from ._pywrap import BaseTree, Classifier, Regressor
import json
from typing import Optional, Dict, Any, List, Union
from io import BytesIO
import os
import numpy as np


def load_tree(tree_data: Union[str, Dict]) -> Union[Classifier, Regressor]:
    """
    Load a decision tree model from a JSON file or dictionary representation.

    This function reconstructs a trained decision tree from its serialized form,
    either from a JSON file on disk or a dictionary containing the tree structure
    and parameters.

    Parameters
    ----------
    tree_data : Union[str, Dict]
        Either:
        - The path to a JSON file containing the tree data
        - A dictionary containing the serialized tree structure and parameters

    Returns
    -------
    Union[Classifier, Regressor]
        A reconstructed decision tree object. The specific type (Classifier or
        Regressor) is determined by the 'task' parameter in the tree data.
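
    Examples
    --------
    >>> # Illustrative sketch: "tree.json" is assumed to have been written earlier
    >>> # by export_tree() for an already-fitted model.
    >>> from obliquetree.utils import load_tree
    >>> clf = load_tree("tree.json")
    >>> y_pred = clf.predict(X)  # assuming the fitted tree exposes the usual predict() API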
"""
|
|
33
|
+
# Handle input types
|
|
34
|
+
if isinstance(tree_data, str):
|
|
35
|
+
if not os.path.exists(tree_data):
|
|
36
|
+
raise FileNotFoundError(f"The file {tree_data} does not exist")
|
|
37
|
+
|
|
38
|
+
with open(tree_data, "r") as f:
|
|
39
|
+
tree = json.load(f)
|
|
40
|
+
elif isinstance(tree_data, dict):
|
|
41
|
+
tree = tree_data
|
|
42
|
+
else:
|
|
43
|
+
raise ValueError("Input must be a JSON string, file path, or dictionary")
|
|
44
|
+
|
|
45
|
+
# Validate tree structure
|
|
46
|
+
if (
|
|
47
|
+
not isinstance(tree, dict)
|
|
48
|
+
or "params" not in tree
|
|
49
|
+
or "task" not in tree["params"]
|
|
50
|
+
):
|
|
51
|
+
raise ValueError("Invalid tree data structure")
|
|
52
|
+
|
|
53
|
+
# Create appropriate object based on task
|
|
54
|
+
if not tree["params"]["task"]:
|
|
55
|
+
obj = Classifier.__new__(Classifier)
|
|
56
|
+
else:
|
|
57
|
+
obj = Regressor.__new__(Regressor)
|
|
58
|
+
|
|
59
|
+
tree["_fit"] = True
|
|
60
|
+
|
|
61
|
+
obj.__setstate__(tree)
|
|
62
|
+
|
|
63
|
+
return obj
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def export_tree(
|
|
67
|
+
tree: Union[Classifier, Regressor], out_file: str = None
|
|
68
|
+
) -> Union[None, dict]:
|
|
69
|
+
"""
|
|
70
|
+
Serialize a decision tree model to a dictionary or JSON file.
|
|
71
|
+
|
|
72
|
+
This function converts a trained decision tree into a portable format that can
|
|
73
|
+
be saved to disk or transmitted. The serialized format preserves all necessary
|
|
74
|
+
information to reconstruct the tree using load_tree().
|
|
75
|
+
|
|
76
|
+
Parameters
|
|
77
|
+
----------
|
|
78
|
+
tree : Union[Classifier, Regressor]
|
|
79
|
+
The trained decision tree model to export. Must be an instance of either
|
|
80
|
+
Classifier or Regressor and must have been fitted.
|
|
81
|
+
|
|
82
|
+
out_file : str, optional
|
|
83
|
+
If provided, the path where the serialized tree should be saved as a JSON
|
|
84
|
+
file. If None, the function returns the dictionary representation instead
|
|
85
|
+
of saving to disk.
|
|
86
|
+
|
|
87
|
+
Returns
|
|
88
|
+
-------
|
|
89
|
+
Union[None, dict]
|
|
90
|
+
If out_file is None:
|
|
91
|
+
Returns a dictionary containing the serialized tree structure and parameters
|
|
92
|
+
If out_file is provided:
|
|
93
|
+
Returns None after saving the tree to the specified JSON file
|
|
94
|
+
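
    Examples
    --------
    >>> # Illustrative sketch: `clf` is assumed to be an already-fitted Classifier.
    >>> from obliquetree.utils import export_tree
    >>> tree_dict = export_tree(clf)             # get the dictionary representation
    >>> export_tree(clf, out_file="tree.json")   # or write it to disk as JSON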
"""
|
|
95
|
+
if not isinstance(tree, BaseTree):
|
|
96
|
+
raise ValueError("`tree` must be an instance of `BaseTree`.")
|
|
97
|
+
|
|
98
|
+
if not tree._fit:
|
|
99
|
+
raise ValueError(
|
|
100
|
+
"The tree has not been fitted yet. Please call the 'fit' method to train the tree before using this function."
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
tree_dict = _export_tree(tree) # Assuming this function is implemented elsewhere.
|
|
104
|
+
|
|
105
|
+
if out_file is not None:
|
|
106
|
+
if isinstance(out_file, str):
|
|
107
|
+
with open(out_file, "w") as f:
|
|
108
|
+
json.dump(tree_dict, f, indent=4)
|
|
109
|
+
else:
|
|
110
|
+
raise ValueError("`out_file` must be a string if provided.")
|
|
111
|
+
|
|
112
|
+
else:
|
|
113
|
+
return tree_dict
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def visualize_tree(
|
|
117
|
+
tree: Union[Classifier, Regressor],
|
|
118
|
+
feature_names: Optional[List[str]] = None,
|
|
119
|
+
max_cat: Optional[int] = None,
|
|
120
|
+
max_oblique: Optional[int] = None,
|
|
121
|
+
save_path: Optional[str] = None,
|
|
122
|
+
dpi: int = 600,
|
|
123
|
+
figsize: tuple = (20, 10),
|
|
124
|
+
show_gini: bool = True,
|
|
125
|
+
show_n_samples: bool = True,
|
|
126
|
+
show_node_value: bool = True,
|
|
127
|
+
) -> None:
|
|
128
|
+
"""
|
|
129
|
+
Generate a visual representation of a decision tree model.
|
|
130
|
+
|
|
131
|
+
Creates a graphical visualization of the tree structure showing decision nodes,
|
|
132
|
+
leaf nodes, and split criteria. The visualization can be customized to show
|
|
133
|
+
various node statistics and can be displayed or saved to a file.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
tree : Union[Classifier, Regressor]
|
|
138
|
+
The trained decision tree model to visualize. Must be fitted before
|
|
139
|
+
visualization.
|
|
140
|
+
|
|
141
|
+
feature_names : List[str], optional
|
|
142
|
+
Human-readable names for the features used in the tree. If provided,
|
|
143
|
+
these names will be used in split conditions instead of generic feature
|
|
144
|
+
indices (e.g., "age <= 30" instead of "f0 <= 30").
|
|
145
|
+
|
|
146
|
+
max_cat : int, optional
|
|
147
|
+
For categorical splits, limits the number of categories shown in the
|
|
148
|
+
visualization. If there are more categories than this limit, they will
|
|
149
|
+
be truncated with an ellipsis. Useful for splits with many categories.
|
|
150
|
+
|
|
151
|
+
max_oblique : int, optional
|
|
152
|
+
For oblique splits (those involving multiple features), limits the number
|
|
153
|
+
of features shown in the split condition. Helps manage complex oblique
|
|
154
|
+
splits in the visualization.
|
|
155
|
+
|
|
156
|
+
save_path : str, optional
|
|
157
|
+
If provided, saves the visualization to this file path. The file format
|
|
158
|
+
is determined by the file extension (e.g., '.png', '.pdf').
|
|
159
|
+
|
|
160
|
+
dpi : int, default=600
|
|
161
|
+
The resolution (dots per inch) of the saved image. Only relevant if
|
|
162
|
+
save_path is provided.
|
|
163
|
+
|
|
164
|
+
figsize : tuple, default=(20, 10)
|
|
165
|
+
The width and height of the figure in inches.
|
|
166
|
+
|
|
167
|
+
show_gini : bool, default=True
|
|
168
|
+
Whether to display Gini impurity values in the nodes.
|
|
169
|
+
|
|
170
|
+
show_n_samples : bool, default=True
|
|
171
|
+
Whether to display the number of samples that reach each node.
|
|
172
|
+
|
|
173
|
+
show_node_value : bool, default=True
|
|
174
|
+
Whether to display the predicted value/class distributions in each node.
|
|
175
|
+
|
|
176
|
+
Returns
|
|
177
|
+
-------
|
|
178
|
+
None
|
|
179
|
+
The function displays the visualization and optionally saves it to disk.
|
|
180
|
+
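
    Examples
    --------
    >>> # Illustrative sketch: requires the optional graphviz and matplotlib
    >>> # dependencies; `clf` is assumed to be an already-fitted Classifier and
    >>> # the feature_names list must match the number of training features.
    >>> from obliquetree.utils import visualize_tree
    >>> visualize_tree(clf, feature_names=["age", "income", "tenure"], save_path="tree.png")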
"""
|
|
181
|
+
_check_visualize_tree_inputs(
|
|
182
|
+
tree, feature_names, max_cat, max_oblique, save_path, dpi, figsize
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
try:
|
|
186
|
+
from graphviz import Digraph
|
|
187
|
+
except ImportError:
|
|
188
|
+
raise ImportError(
|
|
189
|
+
"graphviz is not installed. Please install it to use this function."
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
try:
|
|
193
|
+
from matplotlib.pyplot import figure, imshow, imread, axis, savefig, show
|
|
194
|
+
except:
|
|
195
|
+
raise ImportError(
|
|
196
|
+
"matplotlib is not installed. Please install it to use this function."
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
tree_dict = _export_tree(tree) # Assuming this function is implemented elsewhere.
|
|
200
|
+
|
|
201
|
+
node, params = tree_dict["tree"], tree_dict["params"]
|
|
202
|
+
|
|
203
|
+
def _visualize_recursive(node, graph=None, parent=None, edge_label=""):
|
|
204
|
+
if graph is None:
|
|
205
|
+
graph = Digraph(format="png")
|
|
206
|
+
graph.graph_attr.update(
|
|
207
|
+
{
|
|
208
|
+
"rankdir": "TB",
|
|
209
|
+
"ranksep": "0.3",
|
|
210
|
+
"nodesep": "0.2",
|
|
211
|
+
"splines": "polyline",
|
|
212
|
+
"ordering": "out",
|
|
213
|
+
}
|
|
214
|
+
)
|
|
215
|
+
graph.attr(
|
|
216
|
+
"node",
|
|
217
|
+
shape="box",
|
|
218
|
+
style="filled",
|
|
219
|
+
color="lightgrey",
|
|
220
|
+
fontname="Helvetica",
|
|
221
|
+
margin="0.2",
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
node_id = str(id(node))
|
|
225
|
+
label_parts = []
|
|
226
|
+
|
|
227
|
+
is_leaf = "left" not in node and "right" not in node
|
|
228
|
+
|
|
229
|
+
if is_leaf:
|
|
230
|
+
# For leaf nodes
|
|
231
|
+
label_parts.append("leaf")
|
|
232
|
+
label_parts.append(_format_value_str(node, params))
|
|
233
|
+
|
|
234
|
+
# Add impurity for leaf nodes if requested and available
|
|
235
|
+
if show_gini and "impurity" in node:
|
|
236
|
+
label_parts.append(f"impurity: {_format_float(node['impurity'])}")
|
|
237
|
+
|
|
238
|
+
# Add n_samples for leaf nodes if requested and available
|
|
239
|
+
if show_n_samples and "n_samples" in node:
|
|
240
|
+
label_parts.append(f"n_samples: {node['n_samples']}")
|
|
241
|
+
|
|
242
|
+
graph.node(
|
|
243
|
+
node_id,
|
|
244
|
+
label="\n".join(label_parts),
|
|
245
|
+
shape="box",
|
|
246
|
+
style="filled",
|
|
247
|
+
color="lightblue",
|
|
248
|
+
fontname="Helvetica",
|
|
249
|
+
)
|
|
250
|
+
else:
|
|
251
|
+
# For internal nodes
|
|
252
|
+
# First add the split information
|
|
253
|
+
if node.get("is_oblique", False):
|
|
254
|
+
split_info = _create_oblique_expression(
|
|
255
|
+
node["features"],
|
|
256
|
+
node["weights"],
|
|
257
|
+
node["threshold"],
|
|
258
|
+
feature_names,
|
|
259
|
+
max_oblique,
|
|
260
|
+
)
|
|
261
|
+
elif "category_left" in node:
|
|
262
|
+
categories = node["category_left"]
|
|
263
|
+
cat_str = _format_categories(categories, max_cat)
|
|
264
|
+
feature_label = (
|
|
265
|
+
feature_names[node["feature_idx"]]
|
|
266
|
+
if feature_names and node["feature_idx"] < len(feature_names)
|
|
267
|
+
else f"f{node['feature_idx']}"
|
|
268
|
+
)
|
|
269
|
+
split_info = f"{feature_label} in {cat_str}"
|
|
270
|
+
else:
|
|
271
|
+
threshold = (
|
|
272
|
+
_format_float(node["threshold"])
|
|
273
|
+
if isinstance(node["threshold"], float)
|
|
274
|
+
else node["threshold"]
|
|
275
|
+
)
|
|
276
|
+
feature_label = (
|
|
277
|
+
feature_names[node["feature_idx"]]
|
|
278
|
+
if feature_names and node["feature_idx"] < len(feature_names)
|
|
279
|
+
else f"f{node['feature_idx']}"
|
|
280
|
+
)
|
|
281
|
+
split_info = f"{feature_label} ≤ {threshold}"
|
|
282
|
+
|
|
283
|
+
label_parts.append(split_info)
|
|
284
|
+
|
|
285
|
+
if show_node_value:
|
|
286
|
+
label_parts.append(_format_value_str(node, params))
|
|
287
|
+
|
|
288
|
+
# Add Gini impurity if requested
|
|
289
|
+
if show_gini and "impurity" in node:
|
|
290
|
+
label_parts.append(f"impurity: {_format_float(node['impurity'])}")
|
|
291
|
+
|
|
292
|
+
# Add n_samples if requested
|
|
293
|
+
if show_n_samples and "n_samples" in node:
|
|
294
|
+
label_parts.append(f"n_samples: {node['n_samples']}")
|
|
295
|
+
|
|
296
|
+
graph.node(
|
|
297
|
+
node_id,
|
|
298
|
+
label="\n".join(label_parts),
|
|
299
|
+
shape="box",
|
|
300
|
+
style="filled",
|
|
301
|
+
color="lightgrey",
|
|
302
|
+
fontname="Helvetica",
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
if parent is not None:
|
|
306
|
+
graph.edge(
|
|
307
|
+
parent,
|
|
308
|
+
node_id,
|
|
309
|
+
label=edge_label,
|
|
310
|
+
fontname="Helvetica",
|
|
311
|
+
penwidth="1.0",
|
|
312
|
+
minlen="1",
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
if "left" in node:
|
|
316
|
+
_visualize_recursive(node["left"], graph, node_id, "Left")
|
|
317
|
+
if "right" in node:
|
|
318
|
+
_visualize_recursive(node["right"], graph, node_id, "Right")
|
|
319
|
+
|
|
320
|
+
return graph
|
|
321
|
+
|
|
322
|
+
graph = _visualize_recursive(node)
|
|
323
|
+
png_data = graph.pipe(format="png")
|
|
324
|
+
|
|
325
|
+
figure(figsize=figsize)
|
|
326
|
+
imshow(imread(BytesIO(png_data)))
|
|
327
|
+
axis("off")
|
|
328
|
+
|
|
329
|
+
if save_path:
|
|
330
|
+
savefig(save_path, dpi=dpi, bbox_inches="tight", pad_inches=0)
|
|
331
|
+
|
|
332
|
+
show()
|
|
333
|
+
|
|
334
|
+
def export_tree_to_onnx(tree: Union[Classifier, Regressor]) -> None:
|
|
335
|
+
"""
|
|
336
|
+
Convert an oblique decision tree (Classifier or Regressor) into an ONNX model.
|
|
337
|
+
|
|
338
|
+
.. important::
|
|
339
|
+
- This implementation currently does **not** support batch processing.
|
|
340
|
+
Only a single row (1D NumPy array) and np.float64 dtype can be passed as input.
|
|
341
|
+
- The input variable name must be **"X"** and its shape should be (n_features,).
|
|
342
|
+
- In binary classification, the output is a single-dimensional value representing
|
|
343
|
+
the probability of belonging to the positive class.
|
|
344
|
+
|
|
345
|
+
Parameters
|
|
346
|
+
----------
|
|
347
|
+
tree : Union[Classifier, Regressor]
|
|
348
|
+
The oblique decision tree (classifier or regressor) to be converted to ONNX.
|
|
349
|
+
|
|
350
|
+
Returns
|
|
351
|
+
-------
|
|
352
|
+
onnx.ModelProto
|
|
353
|
+
The constructed ONNX model.
|
|
354
|
+
|
|
355
|
+
Examples
|
|
356
|
+
--------
|
|
357
|
+
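    >>> # Illustrative first step: convert a fitted tree (here assumed to be `clf`)
    >>> # and save it with onnx.save(); "tree.onnx" is an arbitrary path reused below.
    >>> onnx_model = export_tree_to_onnx(clf)
    >>> import onnx
    >>> onnx.save(onnx_model, "tree.onnx")
    >>>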
    >>> # Suppose we have a 2D NumPy array X of shape (num_samples, num_features).
    >>> # We only take a single row for prediction:
    >>> X_sample = X[0, :]
    >>>
    >>> # Create an inference session using onnxruntime:
    >>> import onnxruntime
    >>> session = onnxruntime.InferenceSession("tree.onnx")
    >>>
    >>> # Retrieve the output name of the model
    >>> out_name = session.get_outputs()[0].name
    >>>
    >>> # Perform inference on the sample
    >>> y_pred = session.run([out_name], {"X": X_sample})[0]
    >>> print(y_pred)
    """
    try:
        from onnx import helper, TensorProto
    except ImportError as e:
        raise ImportError(
            "Failed to import onnx dependencies. Please make sure the 'onnx' "
            "package is installed."
        ) from e

    tree_dict = export_tree(tree)

    # Closure for unique name generation
    name_counter = [0]

    def _unique_name(prefix="Node"):
        name_counter[0] += 1
        return f"{prefix}_{name_counter[0]}"

    def _make_constant_int_node(name, value, shape=None):
        """
        Creates an ONNX Constant node containing int64 data.
        Useful for indices in Gather or other integer-only parameters.
        """
        if shape is None:
            shape = [len(value)] if isinstance(value, list) else []
        arr = (
            np.array(value, dtype=np.int64)
            if isinstance(value, list)
            else (
                np.array([value], dtype=np.int64)
                if shape == []
                else np.array(value, dtype=np.int64)
            )
        )

        const_tensor = helper.make_tensor(
            name=_unique_name("const_data_int"),
            data_type=TensorProto.INT64,
            dims=arr.shape,
            vals=arr.flatten().tolist(),
        )

        node = helper.make_node(
            "Constant", inputs=[], outputs=[name], value=const_tensor
        )
        return node

    def _make_constant_float_node(name, value, shape=None):
        """
        Creates an ONNX Constant node containing float64 data.
        Useful for thresholds, weights, etc.
        """
        if shape is None:
            shape = [len(value)] if isinstance(value, list) else []
        arr = (
            np.array(value, dtype=np.float64)
            if isinstance(value, list)
            else np.array([value], dtype=np.float64)
        )

        if shape and arr.shape != tuple(shape):
            arr = arr.reshape(shape)

        const_tensor = helper.make_tensor(
            name=_unique_name("const_data_float"),
            data_type=TensorProto.DOUBLE,
            dims=arr.shape,
            vals=arr.flatten().tolist(),
        )
        node = helper.make_node(
            "Constant", inputs=[], outputs=[name], value=const_tensor
        )
        return node

    def _build_subgraph_for_node(node_dict, n_classes):
        """
        Recursively builds a subgraph (for 'If' branches) from the given node definition.
        The subgraph uses 'X' as an outer-scope input (not declared in inputs[]).
        """
        nodes = []
        graph_name = _unique_name("SubGraph")

        # Subgraph output
        out_name = _unique_name("sub_out")
        out_info = helper.make_tensor_value_info(out_name, TensorProto.DOUBLE, None)

        # Reference to 'X' from the outer scope
        x_info = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [None])

        # If this is a leaf node
        if node_dict["is_leaf"]:
            if "values" in node_dict and isinstance(node_dict["values"], list):
                # Multi-class leaf
                val_array = node_dict["values"]
                shape = [len(val_array)]
                cnode = _make_constant_float_node(out_name, val_array, shape)
                nodes.append(cnode)
            else:
                # Single-value leaf (binary or regression)
                val = node_dict["value"]
                cnode = _make_constant_float_node(out_name, val, [])
                nodes.append(cnode)

            subgraph = helper.make_graph(
                nodes=nodes,
                name=graph_name,
                inputs=[],
                outputs=[out_info],
                value_info=[x_info],
            )
            return subgraph, out_name

        # Otherwise, this node is a split
        cond_name = _unique_name("cond_bool")
        is_oblique = node_dict.get("is_oblique", False)
        cat_list = node_dict.get("category_left", [])
        n_category = len(cat_list)

        # Oblique split
        if is_oblique:
            w_list = node_dict["weights"]
            f_list = node_dict["features"]
            thr_val = node_dict["threshold"]

            partials = []
            for w, f_idx in zip(w_list, f_list):
                gather_idx = _make_constant_int_node(
                    _unique_name("gather_idx"), [f_idx], [1]
                )
                nodes.append(gather_idx)

                gather_out = _unique_name("gather_out")
                gnode = helper.make_node(
                    "Gather",
                    inputs=["X", gather_idx.output[0]],
                    outputs=[gather_out],
                    axis=0,
                )
                nodes.append(gnode)

                w_node = _make_constant_float_node(_unique_name("weight"), w, [])
                nodes.append(w_node)

                mul_out = _unique_name("mul_out")
                mul_node = helper.make_node(
                    "Mul", inputs=[gather_out, w_node.output[0]], outputs=[mul_out]
                )
                nodes.append(mul_node)
                partials.append(mul_out)

            # Summation of partial products
            if len(partials) == 1:
                final_dot = partials[0]
            else:
                tmp = partials[0]
                for p in partials[1:]:
                    add_out = _unique_name("add_out")
                    add_node = helper.make_node(
                        "Add", inputs=[tmp, p], outputs=[add_out]
                    )
                    nodes.append(add_node)
                    tmp = add_out
                final_dot = tmp

            thr_node = _make_constant_float_node(_unique_name("thr"), thr_val, [])
            nodes.append(thr_node)

            less_node = helper.make_node(
                "Less", inputs=[final_dot, thr_node.output[0]], outputs=[cond_name]
            )
            nodes.append(less_node)

        # Categorical split
        elif n_category > 0:
            f_idx = node_dict["feature_idx"]
            fnode = _make_constant_int_node(_unique_name("catf_idx"), [f_idx], [1])
            nodes.append(fnode)

            gout = _unique_name("cat_gather_out")
            gnode = helper.make_node(
                "Gather", inputs=["X", fnode.output[0]], outputs=[gout], axis=0
            )
            nodes.append(gnode)

            eq_outputs = []
            for c_val in cat_list:
                cat_node = _make_constant_float_node(_unique_name("cat_val"), c_val, [])
                nodes.append(cat_node)

                eq_out = _unique_name("eq_out")
                eq_node = helper.make_node(
                    "Equal", inputs=[gout, cat_node.output[0]], outputs=[eq_out]
                )
                nodes.append(eq_node)
                eq_outputs.append(eq_out)

            if len(eq_outputs) == 1:
                final_eq = eq_outputs[0]
            else:
                tmp = eq_outputs[0]
                for eqo in eq_outputs[1:]:
                    or_out = _unique_name("or_out")
                    or_node = helper.make_node(
                        "Or", inputs=[tmp, eqo], outputs=[or_out]
                    )
                    nodes.append(or_node)
                    tmp = or_out
                final_eq = tmp

            id_node = helper.make_node(
                "Identity", inputs=[final_eq], outputs=[cond_name]
            )
            nodes.append(id_node)

        # Axis-aligned numeric split
        else:
            f_idx = node_dict["feature_idx"]
            thr_val = node_dict["threshold"]

            fnode = _make_constant_int_node(_unique_name("f_idx"), [f_idx], [1])
            nodes.append(fnode)

            gout = _unique_name("gather_out")
            gnode = helper.make_node(
                "Gather", inputs=["X", fnode.output[0]], outputs=[gout], axis=0
            )
            nodes.append(gnode)

            thr_node = _make_constant_float_node(_unique_name("thr_val"), thr_val, [])
            nodes.append(thr_node)

            less_node = helper.make_node(
                "Less", inputs=[gout, thr_node.output[0]], outputs=[cond_name]
            )
            nodes.append(less_node)

        # Recursively build subgraphs for left and right
        left_sub, left_out = _build_subgraph_for_node(node_dict["left"], n_classes)
        right_sub, right_out = _build_subgraph_for_node(node_dict["right"], n_classes)

        if_out = _unique_name("if_out")
        if_info = helper.make_tensor_value_info(if_out, TensorProto.DOUBLE, None)

        if_node = helper.make_node(
            "If",
            inputs=[cond_name],
            outputs=[if_out],
            name=_unique_name("IfNode"),
            then_branch=left_sub,
            else_branch=right_sub,
        )
        nodes.append(if_node)

        subgraph = helper.make_graph(
            nodes=nodes,
            name=graph_name,
            inputs=[],
            outputs=[if_info],
            value_info=[x_info],
        )
        return subgraph, if_out

    # Retrieve tree parameters
    params = tree_dict["params"]
    n_classes = params.get("n_classes", 2)
    n_features = params.get("n_features", 4)

    # Build the root subgraph from the tree
    root_subgraph, root_out_name = _build_subgraph_for_node(
        tree_dict["tree"], n_classes
    )

    # Main graph I/O
    main_input = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [n_features])
    main_output = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, None)

    # Extract nodes and value_info from the root subgraph
    nodes = list(root_subgraph.node)
    val_info = list(root_subgraph.value_info)
    if_out_name = root_subgraph.output[0].name

    # Add a final Identity node to map the subgraph output to the model output
    final_out_node_name = _unique_name("final_y")
    identity_node = helper.make_node(
        "Identity", inputs=[if_out_name], outputs=[final_out_node_name]
    )
    nodes.append(identity_node)
    main_output.name = final_out_node_name

    # Construct the main graph
    main_graph = helper.make_graph(
        nodes=nodes,
        name="MainGraph",
        inputs=[main_input],
        outputs=[main_output],
        value_info=val_info,
    )

    # Fix output shape to [1] or [n_classes]
    if n_classes > 2:
        dim = main_graph.output[0].type.tensor_type.shape.dim.add()
        dim.dim_value = n_classes
    else:
        dim = main_graph.output[0].type.tensor_type.shape.dim.add()
        dim.dim_value = 1

    # Fix input shape to [n_features]
    main_graph.input[0].type.tensor_type.shape.dim[0].dim_value = n_features

    onnx_model = helper.make_model(
        main_graph,
        producer_name="custom_oblique_categorical_tree",
        opset_imports=[helper.make_opsetid("", 13)],
    )
    onnx_model.ir_version = 7

    return onnx_model


def _format_float(value: float) -> str:
    """Format a float with two decimal places; return '0' for 0.0."""
    if value == 0.0:
        return "0"
    return f"{value:.2f}"


def _format_value_str(node: Dict[str, Any], params: Dict[str, Any]) -> str:
    """
    Format the value string based on task type (regression vs classification) and number of classes.

    Parameters
    ----------
    node : Dict[str, Any]
        The tree node dictionary containing "value" or "values".
    params : Dict[str, Any]
        Tree parameters containing task and n_classes information.
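
    Examples
    --------
    >>> # Illustrative: a binary-classification leaf whose "value" holds the
    >>> # probability of the positive class.
    >>> _format_value_str({"value": 0.25}, {"task": False, "n_classes": 2})
    'values: [0.75, 0.25]'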
"""
|
|
708
|
+
# Check if it's a regression task
|
|
709
|
+
if params["task"]:
|
|
710
|
+
value = (
|
|
711
|
+
_format_float(node["value"])
|
|
712
|
+
if isinstance(node["value"], float)
|
|
713
|
+
else node["value"]
|
|
714
|
+
)
|
|
715
|
+
return f"Value: {value}"
|
|
716
|
+
|
|
717
|
+
# Classification task
|
|
718
|
+
else:
|
|
719
|
+
if params["n_classes"] == 2: # Binary classification
|
|
720
|
+
# For binary case, node["values"] contains probability for positive class
|
|
721
|
+
p = node["value"]
|
|
722
|
+
return f"values: [{_format_float(1-p)}, {_format_float(p)}]"
|
|
723
|
+
else: # Multiclass (3 or more classes)
|
|
724
|
+
values_str = ", ".join(
|
|
725
|
+
_format_float(v) if isinstance(v, float) else str(v)
|
|
726
|
+
for v in node.get("values", [0.0] * params["n_classes"])
|
|
727
|
+
)
|
|
728
|
+
return f"values: [{values_str}]"
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def _create_oblique_expression(
|
|
732
|
+
features: list,
|
|
733
|
+
weights: list,
|
|
734
|
+
threshold: float,
|
|
735
|
+
feature_names: Optional[List[str]],
|
|
736
|
+
max_oblique: Optional[int] = None,
|
|
737
|
+
) -> str:
|
|
738
|
+
"""Create mathematical expression for oblique split with line breaks after 5 terms"""
    # Sort features and weights by absolute weight value
    feature_weight_pairs = sorted(
        zip(features, weights), key=lambda x: abs(x[1]), reverse=True
    )

    # Apply max_oblique limit if specified
    if max_oblique is not None:
        feature_weight_pairs = feature_weight_pairs[:max_oblique]
        if len(features) > max_oblique:
            feature_weight_pairs.append(("...", 0))

    # Create terms with proper formatting
    lines = []
    current_line = []

    for i, (f, w) in enumerate(feature_weight_pairs):
        if f == "...":
            current_line.append("...")
            continue

        feature_label = (
            feature_names[f] if feature_names and f < len(feature_names) else f"f{f}"
        )

        if w == 1.0:
            term = feature_label  # No parentheses for a coefficient of 1
        elif w == -1.0:
            term = f"–{feature_label}"  # No parentheses for a coefficient of -1
        else:
            formatted_weight = _format_float(abs(w))
            term = f"{'– ' if w < 0 else ''}({formatted_weight} * {feature_label})"

        if i > 0:
            term = f"+ {term}" if w > 0 else f" {term}"

        current_line.append(term)

        # Start a new line after every 5 terms
        if len(current_line) == 5 and i < len(feature_weight_pairs) - 1:
            lines.append(" ".join(current_line) + " +")
            current_line = []

    if current_line:
        lines.append(" ".join(current_line))

    formatted_threshold = _format_float(threshold)
    expression = "\n".join(lines)
    return f"{expression} ≤ {formatted_threshold}"


def _format_categories(categories: list, max_cat: Optional[int] = None) -> str:
    """
    Format a category list, truncating it with an ellipsis at max_cat entries and
    wrapping long lists onto multiple lines.
    if max_cat is not None and len(categories) > max_cat:
        shown_cats = categories[:max_cat]
        return f"[{', '.join(map(str, shown_cats))}, ...]"

    formatted_cats = []
    current_line = []

    for i, cat in enumerate(categories):
        current_line.append(str(cat))

        # Wrap onto a new line after every 9 items
        if len(current_line) == 9 and i < len(categories) - 1:
            formatted_cats.append(", ".join(current_line) + ",")
            current_line = []

    if current_line:
        formatted_cats.append(", ".join(current_line))

    if len(formatted_cats) > 1:
        return "[" + "\n".join(formatted_cats) + "]"
    return f"[{formatted_cats[0]}]"


def _check_visualize_tree_inputs(
    tree: BaseTree,
    feature_names: Optional[List[str]] = None,
    max_cat: Optional[int] = None,
    max_oblique: Optional[int] = None,
    save_path: Optional[str] = None,
    dpi: int = 600,
    figsize: tuple = (20, 10),
) -> None:
    """
    Validate the inputs for the visualize_tree function.

    Parameters
    ----------
    tree : BaseTree
        The tree object to be visualized; must be a fitted BaseTree instance.
    feature_names : Optional[List[str]]
        If provided, must be a list of strings matching the number of features in the tree.
    max_cat : Optional[int]
        If provided, must be a positive integer.
    max_oblique : Optional[int]
        If provided, must be a positive integer.
    save_path : Optional[str]
        If provided, must be a valid file path ending in a supported image format (e.g., '.png').
    dpi : int
        Must be a positive integer.
    figsize : tuple
        Must be a tuple of two positive numbers.
    """
    if not isinstance(tree, BaseTree):
        raise ValueError("`tree` must be an instance of `BaseTree`.")

    if not tree._fit:
        raise ValueError(
            "The tree has not been fitted yet. Please call the 'fit' method to train the tree before using this function."
        )

    if feature_names is not None:
        if not isinstance(feature_names, list) or not all(
            isinstance(f, str) for f in feature_names
        ):
            raise ValueError("feature_names must be a list of strings.")
        if len(feature_names) != tree.n_features:
            raise ValueError(
                f"feature_names must match the number of features in the tree ({tree.n_features})."
            )

    if max_cat is not None and (not isinstance(max_cat, int) or max_cat <= 0):
        raise ValueError("max_cat must be a positive integer.")

    if max_oblique is not None and (
        not isinstance(max_oblique, int) or max_oblique <= 0
    ):
        raise ValueError("max_oblique must be a positive integer.")

    if save_path is not None and not isinstance(save_path, str):
        raise ValueError("save_path must be a string representing a valid file path.")

    if not isinstance(dpi, int) or dpi <= 0:
        raise ValueError("dpi must be a positive integer.")

    if (
        not isinstance(figsize, tuple)
        or len(figsize) != 2
        or not all(isinstance(dim, (int, float)) and dim > 0 for dim in figsize)
    ):
        raise ValueError("figsize must be a tuple of two positive numbers.")