obliquetree 1.0.3__cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of obliquetree might be problematic. Click here for more details.

obliquetree/utils.py ADDED
@@ -0,0 +1,882 @@
1
+ from __future__ import annotations
2
+
3
+ from .src.utils import export_tree as _export_tree
4
+ from ._pywrap import BaseTree, Classifier, Regressor
5
+ import json
6
+ from typing import Optional, Dict, Any, List, Union
7
+ from io import BytesIO
8
+ import os
9
+ import numpy as np
10
+
11
+
12
def load_tree(tree_data: Union[str, Dict]) -> Union[Classifier, Regressor]:
    """
    Load a decision tree model from a JSON file or dictionary representation.

    This function reconstructs a trained decision tree from its serialized form,
    either from a JSON file on disk or a dictionary containing the tree structure
    and parameters.

    Parameters
    ----------
    tree_data : Union[str, Dict]
        Either:
        - A string containing the file path to a JSON file containing the tree data
        - A dictionary containing the serialized tree structure and parameters

        A string is always interpreted as a file path, never as raw JSON text.
        A dictionary argument is not modified.

    Returns
    -------
    Union[Classifier, Regressor]
        A reconstructed decision tree object. A falsy ``params["task"]`` yields
        a Classifier, a truthy one yields a Regressor.

    Raises
    ------
    FileNotFoundError
        If ``tree_data`` is a string and no file exists at that path.
    ValueError
        If ``tree_data`` is neither a string nor a dictionary, or if the loaded
        data lacks a ``"params"`` mapping containing a ``"task"`` entry.
    """
    # Handle input types
    if isinstance(tree_data, str):
        if not os.path.exists(tree_data):
            raise FileNotFoundError(f"The file {tree_data} does not exist")

        with open(tree_data, "r", encoding="utf-8") as f:
            tree = json.load(f)
    elif isinstance(tree_data, dict):
        # Work on a shallow copy: the "_fit" marker added below must not leak
        # into the dictionary owned by the caller.
        tree = dict(tree_data)
    else:
        raise ValueError("Input must be a path to a JSON file or a dictionary")

    # Validate tree structure. `params` is checked to be a mapping first so a
    # malformed payload is reported as ValueError rather than TypeError.
    if (
        not isinstance(tree, dict)
        or not isinstance(tree.get("params"), dict)
        or "task" not in tree["params"]
    ):
        raise ValueError("Invalid tree data structure")

    # Create appropriate object based on task; __new__ bypasses __init__
    # because the whole state is restored through __setstate__ below.
    if not tree["params"]["task"]:
        obj = Classifier.__new__(Classifier)
    else:
        obj = Regressor.__new__(Regressor)

    # Mark the restored model as fitted.
    tree["_fit"] = True

    obj.__setstate__(tree)

    return obj
64
+
65
+
66
def export_tree(
    tree: Union[Classifier, Regressor], out_file: Optional[str] = None
) -> Union[None, dict]:
    """
    Serialize a decision tree model to a dictionary or JSON file.

    This function converts a trained decision tree into a portable format that can
    be saved to disk or transmitted. The serialized format preserves all necessary
    information to reconstruct the tree using load_tree().

    Parameters
    ----------
    tree : Union[Classifier, Regressor]
        The trained decision tree model to export. Must be an instance of
        BaseTree and must have been fitted.

    out_file : str, optional
        If provided, the path where the serialized tree should be saved as a JSON
        file. If None, the function returns the dictionary representation instead
        of saving to disk.

    Returns
    -------
    Union[None, dict]
        If out_file is None:
            Returns a dictionary containing the serialized tree structure and parameters
        If out_file is provided:
            Returns None after saving the tree to the specified JSON file

    Raises
    ------
    ValueError
        If ``tree`` is not a BaseTree instance, has not been fitted, or if
        ``out_file`` is given but is not a string.
    """
    if not isinstance(tree, BaseTree):
        raise ValueError("`tree` must be an instance of `BaseTree`.")

    if not tree._fit:
        raise ValueError(
            "The tree has not been fitted yet. Please call the 'fit' method to train the tree before using this function."
        )

    # Reject an unusable destination before doing the serialization work.
    if out_file is not None and not isinstance(out_file, str):
        raise ValueError("`out_file` must be a string if provided.")

    tree_dict = _export_tree(tree)

    if out_file is None:
        return tree_dict

    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(tree_dict, f, indent=4)
    return None
114
+
115
+
116
def visualize_tree(
    tree: Union[Classifier, Regressor],
    feature_names: Optional[List[str]] = None,
    max_cat: Optional[int] = None,
    max_oblique: Optional[int] = None,
    save_path: Optional[str] = None,
    dpi: int = 600,
    figsize: tuple = (20, 10),
    show_gini: bool = True,
    show_n_samples: bool = True,
    show_node_value: bool = True,
) -> None:
    """
    Generate a visual representation of a decision tree model.

    Creates a graphical visualization of the tree structure showing decision nodes,
    leaf nodes, and split criteria. The graph is rendered to PNG with graphviz and
    then displayed (and optionally saved) through matplotlib.

    Parameters
    ----------
    tree : Union[Classifier, Regressor]
        The trained decision tree model to visualize. Must be fitted before
        visualization.

    feature_names : List[str], optional
        Human-readable names for the features used in the tree. If provided,
        these names will be used in split conditions instead of generic feature
        indices (e.g., "age ≤ 30" instead of "f0 ≤ 30").

    max_cat : int, optional
        For categorical splits, limits the number of categories shown in the
        visualization. If there are more categories than this limit, they will
        be truncated with an ellipsis.

    max_oblique : int, optional
        For oblique splits (those involving multiple features), limits the number
        of features shown in the split condition.

    save_path : str, optional
        If provided, the figure is also saved to this path via matplotlib's
        ``savefig``.

    dpi : int, default=600
        The resolution (dots per inch) of the saved image. Only relevant if
        save_path is provided.

    figsize : tuple, default=(20, 10)
        The width and height of the figure in inches.

    show_gini : bool, default=True
        Whether to display the node's ``impurity`` value (when present).

    show_n_samples : bool, default=True
        Whether to display the node's ``n_samples`` value (when present).

    show_node_value : bool, default=True
        Whether to display the predicted value/class distribution in internal
        nodes. Leaf nodes always display their value.

    Returns
    -------
    None
        The function displays the visualization and optionally saves it to disk.

    Raises
    ------
    ValueError
        If any argument fails the checks in ``_check_visualize_tree_inputs``.
    ImportError
        If graphviz or matplotlib cannot be imported.
    """
    _check_visualize_tree_inputs(
        tree, feature_names, max_cat, max_oblique, save_path, dpi, figsize
    )

    try:
        from graphviz import Digraph
    except ImportError as exc:
        raise ImportError(
            "graphviz is not installed. Please install it to use this function."
        ) from exc

    # Catch ImportError only: a bare `except:` would also relabel unrelated
    # failures (even KeyboardInterrupt) as "matplotlib is not installed".
    try:
        from matplotlib.pyplot import figure, imshow, imread, axis, savefig, show
    except ImportError as exc:
        raise ImportError(
            "matplotlib is not installed. Please install it to use this function."
        ) from exc

    tree_dict = _export_tree(tree)

    node, params = tree_dict["tree"], tree_dict["params"]

    def _feature_label(feature_idx):
        """Return the user-supplied name for a feature, or the generic 'f<idx>'."""
        if feature_names and feature_idx < len(feature_names):
            return feature_names[feature_idx]
        return f"f{feature_idx}"

    def _split_description(node):
        """Build the first label line of an internal node (its split condition)."""
        if node.get("is_oblique", False):
            return _create_oblique_expression(
                node["features"],
                node["weights"],
                node["threshold"],
                feature_names,
                max_oblique,
            )

        feature_label = _feature_label(node["feature_idx"])

        if "category_left" in node:
            cat_str = _format_categories(node["category_left"], max_cat)
            return f"{feature_label} in {cat_str}"

        threshold = (
            _format_float(node["threshold"])
            if isinstance(node["threshold"], float)
            else node["threshold"]
        )
        return f"{feature_label} ≤ {threshold}"

    def _visualize_recursive(node, graph=None, parent=None, edge_label=""):
        # The first (root) call creates the graph; recursive calls share it.
        if graph is None:
            graph = Digraph(format="png")
            graph.graph_attr.update(
                {
                    "rankdir": "TB",
                    "ranksep": "0.3",
                    "nodesep": "0.2",
                    "splines": "polyline",
                    "ordering": "out",
                }
            )
            graph.attr(
                "node",
                shape="box",
                style="filled",
                color="lightgrey",
                fontname="Helvetica",
                margin="0.2",
            )

        # Every node dict stays alive inside tree_dict during the traversal,
        # so id() gives a unique graph identifier per node.
        node_id = str(id(node))

        is_leaf = "left" not in node and "right" not in node

        if is_leaf:
            # Leaves always show their value, regardless of show_node_value.
            label_parts = ["leaf", _format_value_str(node, params)]
            node_color = "lightblue"
        else:
            label_parts = [_split_description(node)]
            if show_node_value:
                label_parts.append(_format_value_str(node, params))
            node_color = "lightgrey"

        # Optional statistics, shown for both leaf and internal nodes.
        if show_gini and "impurity" in node:
            label_parts.append(f"impurity: {_format_float(node['impurity'])}")

        if show_n_samples and "n_samples" in node:
            label_parts.append(f"n_samples: {node['n_samples']}")

        graph.node(
            node_id,
            label="\n".join(label_parts),
            shape="box",
            style="filled",
            color=node_color,
            fontname="Helvetica",
        )

        if parent is not None:
            graph.edge(
                parent,
                node_id,
                label=edge_label,
                fontname="Helvetica",
                penwidth="1.0",
                minlen="1",
            )

        if "left" in node:
            _visualize_recursive(node["left"], graph, node_id, "Left")
        if "right" in node:
            _visualize_recursive(node["right"], graph, node_id, "Right")

        return graph

    graph = _visualize_recursive(node)
    png_data = graph.pipe(format="png")

    figure(figsize=figsize)
    imshow(imread(BytesIO(png_data)))
    axis("off")

    if save_path:
        savefig(save_path, dpi=dpi, bbox_inches="tight", pad_inches=0)

    show()
333
+
334
def export_tree_to_onnx(tree: Union[Classifier, Regressor]) -> Any:
    """
    Convert an oblique decision tree (Classifier or Regressor) into an ONNX model.

    The tree is first serialized with ``export_tree`` and then translated into a
    graph of nested ONNX ``If`` nodes: each split computes a boolean condition,
    the left child becomes the ``then`` branch and the right child the ``else``
    branch, and each leaf becomes a ``Constant`` node.

    .. important::
        - This implementation currently does **not** support batch processing.
          Only a single row (1D NumPy array) and np.float64 dtype can be passed as input.
        - The input variable name must be **"X"** and its shape should be (n_features,).
        - In binary classification, the output is a single-dimensional value representing
          the probability of belonging to the positive class.

    Parameters
    ----------
    tree : Union[Classifier, Regressor]
        The oblique decision tree (classifier or regressor) to be converted to ONNX.
        It is validated by ``export_tree`` (must be a fitted BaseTree).

    Returns
    -------
    onnx.ModelProto
        The constructed ONNX model (opset 13, ``ir_version`` 7). The model is
        returned in memory; this function does not write it to disk.

    Raises
    ------
    ImportError
        If the ``onnx`` package cannot be imported.
    ValueError
        Propagated from ``export_tree`` when ``tree`` is not a fitted BaseTree.

    Examples
    --------
    >>> # Build the model and save it so that onnxruntime can load it:
    >>> import onnx
    >>> onnx_model = export_tree_to_onnx(tree)
    >>> onnx.save(onnx_model, "tree.onnx")
    >>>
    >>> # Suppose we have a 2D NumPy array X of shape (num_samples, num_features).
    >>> # We only take a single row for prediction:
    >>> X_sample = X[0, :]
    >>>
    >>> # Create an inference session using onnxruntime:
    >>> import onnxruntime
    >>> session = onnxruntime.InferenceSession("tree.onnx")
    >>>
    >>> # Retrieve the output name of the model
    >>> out_name = session.get_outputs()[0].name
    >>>
    >>> # Perform inference on the sample
    >>> y_pred = session.run([out_name], {"X": X_sample})[0]
    >>> print(y_pred)
    """
    try:
        from onnx import helper, TensorProto
    except ImportError as e:
        raise ImportError(
            "Failed to import onnx dependencies. Please make sure the 'onnx' "
            "package is installed."
        ) from e

    tree_dict = export_tree(tree)

    # Closure for unique name generation; a one-element list is used so the
    # nested function can mutate the counter without `nonlocal`.
    name_counter = [0]

    def _unique_name(prefix="Node"):
        name_counter[0] += 1
        return f"{prefix}_{name_counter[0]}"

    def _make_constant_int_node(name, value, shape=None):
        """
        Creates an ONNX Constant node containing int64 data.
        Useful for indices in Gather or other integer-only parameters.

        ``name`` is the output name of the Constant node. The tensor dims are
        taken from the NumPy array built from ``value``; ``shape`` only
        influences how a non-list ``value`` is wrapped.
        """
        if shape is None:
            shape = [len(value)] if isinstance(value, list) else []
        # list -> array of that list; non-list with empty shape -> 1-element
        # array; non-list with explicit shape -> np.array(value) as given.
        arr = (
            np.array(value, dtype=np.int64)
            if isinstance(value, list)
            else (
                np.array([value], dtype=np.int64)
                if shape == []
                else np.array(value, dtype=np.int64)
            )
        )

        const_tensor = helper.make_tensor(
            name=_unique_name("const_data_int"),
            data_type=TensorProto.INT64,
            dims=arr.shape,
            vals=arr.flatten().tolist(),
        )

        node = helper.make_node(
            "Constant", inputs=[], outputs=[name], value=const_tensor
        )
        return node

    def _make_constant_float_node(name, value, shape=None):
        """
        Creates an ONNX Constant node containing float64 data.
        Useful for thresholds, weights, etc.

        ``name`` is the output name of the Constant node. A non-list ``value``
        is wrapped as a 1-element array; because an empty ``shape`` is falsy,
        callers passing ``[]`` get a tensor of shape (1,), not a 0-d scalar.
        """
        if shape is None:
            shape = [len(value)] if isinstance(value, list) else []
        arr = (
            np.array(value, dtype=np.float64)
            if isinstance(value, list)
            else np.array([value], dtype=np.float64)
        )

        # Reshape only when a non-empty target shape differs from the array's.
        if shape and arr.shape != tuple(shape):
            arr = arr.reshape(shape)

        const_tensor = helper.make_tensor(
            name=_unique_name("const_data_float"),
            data_type=TensorProto.DOUBLE,
            dims=arr.shape,
            vals=arr.flatten().tolist(),
        )
        node = helper.make_node(
            "Constant", inputs=[], outputs=[name], value=const_tensor
        )
        return node

    def _build_subgraph_for_node(node_dict, n_classes):
        """
        Recursively builds a subgraph (for 'If' branches) from the given node definition.
        The subgraph uses 'X' as an outer-scope input (not declared in inputs[]).

        Returns a ``(subgraph, output_name)`` tuple. ``n_classes`` is only
        forwarded to the recursive calls; it is not otherwise used here.
        """
        nodes = []
        graph_name = _unique_name("SubGraph")

        # Subgraph output
        out_name = _unique_name("sub_out")
        out_info = helper.make_tensor_value_info(out_name, TensorProto.DOUBLE, None)

        # Reference to 'X' from the outer scope
        x_info = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [None])

        # If this is a leaf node
        if node_dict["is_leaf"]:
            if "values" in node_dict and isinstance(node_dict["values"], list):
                # Multi-class leaf
                val_array = node_dict["values"]
                shape = [len(val_array)]
                cnode = _make_constant_float_node(out_name, val_array, shape)
                nodes.append(cnode)
            else:
                # Single-value leaf (binary or regression)
                val = node_dict["value"]
                cnode = _make_constant_float_node(out_name, val, [])
                nodes.append(cnode)

            subgraph = helper.make_graph(
                nodes=nodes,
                name=graph_name,
                inputs=[],
                outputs=[out_info],
                value_info=[x_info],
            )
            return subgraph, out_name

        # Otherwise, this node is a split
        cond_name = _unique_name("cond_bool")
        is_oblique = node_dict.get("is_oblique", False)
        cat_list = node_dict.get("category_left", [])
        n_category = len(cat_list)

        # Oblique split: condition is sum(w_i * X[f_i]) < threshold
        # NOTE(review): strict `Less` is used here while visualize_tree renders
        # the same split with "≤" — confirm which matches the tree's predict.
        if is_oblique:
            w_list = node_dict["weights"]
            f_list = node_dict["features"]
            thr_val = node_dict["threshold"]

            partials = []
            for w, f_idx in zip(w_list, f_list):
                gather_idx = _make_constant_int_node(
                    _unique_name("gather_idx"), [f_idx], [1]
                )
                nodes.append(gather_idx)

                gather_out = _unique_name("gather_out")
                gnode = helper.make_node(
                    "Gather",
                    inputs=["X", gather_idx.output[0]],
                    outputs=[gather_out],
                    axis=0,
                )
                nodes.append(gnode)

                w_node = _make_constant_float_node(_unique_name("weight"), w, [])
                nodes.append(w_node)

                mul_out = _unique_name("mul_out")
                mul_node = helper.make_node(
                    "Mul", inputs=[gather_out, w_node.output[0]], outputs=[mul_out]
                )
                nodes.append(mul_node)
                partials.append(mul_out)

            # Summation of partial products (left-to-right chain of Add nodes)
            if len(partials) == 1:
                final_dot = partials[0]
            else:
                tmp = partials[0]
                for p in partials[1:]:
                    add_out = _unique_name("add_out")
                    add_node = helper.make_node(
                        "Add", inputs=[tmp, p], outputs=[add_out]
                    )
                    nodes.append(add_node)
                    tmp = add_out
                final_dot = tmp

            thr_node = _make_constant_float_node(_unique_name("thr"), thr_val, [])
            nodes.append(thr_node)

            less_node = helper.make_node(
                "Less", inputs=[final_dot, thr_node.output[0]], outputs=[cond_name]
            )
            nodes.append(less_node)

        # Categorical split: condition is X[f] == c for any c in category_left
        elif n_category > 0:
            f_idx = node_dict["feature_idx"]
            fnode = _make_constant_int_node(_unique_name("catf_idx"), [f_idx], [1])
            nodes.append(fnode)

            gout = _unique_name("cat_gather_out")
            gnode = helper.make_node(
                "Gather", inputs=["X", fnode.output[0]], outputs=[gout], axis=0
            )
            nodes.append(gnode)

            eq_outputs = []
            for c_val in cat_list:
                # Categories are compared as float64 constants against X.
                cat_node = _make_constant_float_node(_unique_name("cat_val"), c_val, [])
                nodes.append(cat_node)

                eq_out = _unique_name("eq_out")
                eq_node = helper.make_node(
                    "Equal", inputs=[gout, cat_node.output[0]], outputs=[eq_out]
                )
                nodes.append(eq_node)
                eq_outputs.append(eq_out)

            # OR all equality tests together
            if len(eq_outputs) == 1:
                final_eq = eq_outputs[0]
            else:
                tmp = eq_outputs[0]
                for eqo in eq_outputs[1:]:
                    or_out = _unique_name("or_out")
                    or_node = helper.make_node(
                        "Or", inputs=[tmp, eqo], outputs=[or_out]
                    )
                    nodes.append(or_node)
                    tmp = or_out
                final_eq = tmp

            # Identity renames the combined result to the condition name.
            id_node = helper.make_node(
                "Identity", inputs=[final_eq], outputs=[cond_name]
            )
            nodes.append(id_node)

        # Axis-aligned numeric split: condition is X[f] < threshold
        else:
            f_idx = node_dict["feature_idx"]
            thr_val = node_dict["threshold"]

            fnode = _make_constant_int_node(_unique_name("f_idx"), [f_idx], [1])
            nodes.append(fnode)

            gout = _unique_name("gather_out")
            gnode = helper.make_node(
                "Gather", inputs=["X", fnode.output[0]], outputs=[gout], axis=0
            )
            nodes.append(gnode)

            thr_node = _make_constant_float_node(_unique_name("thr_val"), thr_val, [])
            nodes.append(thr_node)

            less_node = helper.make_node(
                "Less", inputs=[gout, thr_node.output[0]], outputs=[cond_name]
            )
            nodes.append(less_node)

        # Recursively build subgraphs for left and right
        # (left_out / right_out are returned but not used below)
        left_sub, left_out = _build_subgraph_for_node(node_dict["left"], n_classes)
        right_sub, right_out = _build_subgraph_for_node(node_dict["right"], n_classes)

        if_out = _unique_name("if_out")
        if_info = helper.make_tensor_value_info(if_out, TensorProto.DOUBLE, None)

        # Condition true -> left child, false -> right child.
        if_node = helper.make_node(
            "If",
            inputs=[cond_name],
            outputs=[if_out],
            name=_unique_name("IfNode"),
            then_branch=left_sub,
            else_branch=right_sub,
        )
        nodes.append(if_node)

        subgraph = helper.make_graph(
            nodes=nodes,
            name=graph_name,
            inputs=[],
            outputs=[if_info],
            value_info=[x_info],
        )
        return subgraph, if_out

    # Retrieve tree parameters
    # NOTE(review): the fallbacks (2 classes, 4 features) apply only when the
    # keys are missing from the exported params — confirm they are ever needed.
    params = tree_dict["params"]
    n_classes = params.get("n_classes", 2)
    n_features = params.get("n_features", 4)

    # Build the root subgraph from the tree
    root_subgraph, root_out_name = _build_subgraph_for_node(
        tree_dict["tree"], n_classes
    )

    # Main graph I/O
    main_input = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [n_features])
    main_output = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, None)

    # Extract nodes and value_info from the root subgraph; the root's nodes are
    # lifted into the main graph rather than wrapped in another If.
    nodes = list(root_subgraph.node)
    val_info = list(root_subgraph.value_info)
    if_out_name = root_subgraph.output[0].name

    # Add a final Identity node to expose the root output as the graph output.
    # The output value-info is renamed to the generated "final_y_<n>" name, so
    # the model's output is not literally called "Y".
    final_out_node_name = _unique_name("final_y")
    identity_node = helper.make_node(
        "Identity", inputs=[if_out_name], outputs=[final_out_node_name]
    )
    nodes.append(identity_node)
    main_output.name = final_out_node_name

    # Construct the main graph
    main_graph = helper.make_graph(
        nodes=nodes,
        name="MainGraph",
        inputs=[main_input],
        outputs=[main_output],
        value_info=val_info,
    )

    # Fix output shape to [1] or [n_classes]
    if n_classes > 2:
        dim = main_graph.output[0].type.tensor_type.shape.dim.add()
        dim.dim_value = n_classes
    else:
        dim = main_graph.output[0].type.tensor_type.shape.dim.add()
        dim.dim_value = 1

    # Fix input shape to [n_features] (main_input was already created with
    # this shape above, so this re-assigns the same value).
    main_graph.input[0].type.tensor_type.shape.dim[0].dim_value = n_features

    onnx_model = helper.make_model(
        main_graph,
        producer_name="custom_oblique_categorical_tree",
        opset_imports=[helper.make_opsetid("", 13)],
    )
    onnx_model.ir_version = 7

    return onnx_model
688
+
689
+
690
def _format_float(value: float) -> str:
    """
    Format a number with two decimal places.

    Exact zero (including ``-0.0`` and integer ``0``) is rendered as ``"0"``;
    every other value uses fixed-point notation with two decimals, so a small
    non-zero value such as ``0.001`` is rendered as ``"0.00"``.
    """
    if value == 0.0:
        return "0"
    return f"{value:.2f}"
695
+
696
+
697
def _format_value_str(node: Dict[str, Any], params: Dict[str, Any]) -> str:
    """
    Build the value line shown inside a tree node.

    Parameters
    ----------
    node : Dict[str, Any]
        The tree node dictionary. ``node["value"]`` is read for regression and
        binary classification; ``node["values"]`` for multiclass.
    params : Dict[str, Any]
        Tree parameters. A truthy ``params["task"]`` selects the regression
        format; otherwise ``params["n_classes"]`` selects binary or multiclass.

    Returns
    -------
    str
        ``"Value: <v>"`` for regression, ``"values: [...]"`` for classification.
    """
    # Regression: a single predicted value.
    if params["task"]:
        raw = node["value"]
        shown = _format_float(raw) if isinstance(raw, float) else raw
        return f"Value: {shown}"

    n_classes = params["n_classes"]

    # Binary classification: node["value"] is the positive-class probability,
    # so the negative-class entry is its complement.
    if n_classes == 2:
        p = node["value"]
        return f"values: [{_format_float(1-p)}, {_format_float(p)}]"

    # Multiclass (3 or more classes): one entry per class, zeros if absent.
    class_values = node.get("values", [0.0] * n_classes)
    rendered = []
    for v in class_values:
        rendered.append(_format_float(v) if isinstance(v, float) else str(v))
    values_str = ", ".join(rendered)
    return f"values: [{values_str}]"
729
+
730
+
731
def _create_oblique_expression(
    features: list,
    weights: list,
    threshold: float,
    feature_names: Optional[List[str]],
    max_oblique: Optional[int] = None,
) -> str:
    """
    Create the text of an oblique split condition, e.g.
    ``(0.50 * f0) + f2 – (0.25 * f1) ≤ 1.00``.

    Terms are ordered by decreasing absolute weight and wrapped onto a new
    line after every 5 terms.

    Parameters
    ----------
    features : list
        Feature indices taking part in the split.
    weights : list
        Weight of each feature, paired positionally with ``features``.
    threshold : float
        Right-hand side of the condition.
    feature_names : Optional[List[str]]
        Display names indexed by feature index; ``f<idx>`` is used when a name
        is unavailable.
    max_oblique : Optional[int]
        If given, only this many highest-|weight| terms are shown, followed by
        ``...`` when terms were dropped.

    Returns
    -------
    str
        The formatted expression followed by ``" ≤ <threshold>"``.
    """
    # Sort features and weights by absolute weight value
    ranked = sorted(zip(features, weights), key=lambda fw: abs(fw[1]), reverse=True)

    # Apply max_oblique limit if specified
    truncated = False
    if max_oblique is not None:
        truncated = len(features) > max_oblique
        ranked = ranked[:max_oblique]

    lines = []
    current_line = []
    last_index = len(ranked) - 1

    for i, (f, w) in enumerate(ranked):
        feature_label = (
            feature_names[f] if feature_names and f < len(feature_names) else f"f{f}"
        )

        if w == 1.0:
            term = feature_label
        elif w == -1.0:
            term = f"–{feature_label}"
        else:
            formatted_weight = _format_float(abs(w))
            term = f"{'– ' if w < 0 else ''}({formatted_weight} * {feature_label})"

        # Negative terms already start with their own minus sign; every other
        # non-leading term (including a zero weight) needs an explicit plus.
        if i > 0 and not w < 0:
            term = f"+ {term}"

        current_line.append(term)

        # Start a new line after every 5 terms when something still follows.
        # No trailing operator is added: the next term carries its own sign.
        if len(current_line) == 5 and (i < last_index or truncated):
            lines.append(" ".join(current_line))
            current_line = []

    if truncated:
        current_line.append("...")

    if current_line:
        lines.append(" ".join(current_line))

    formatted_threshold = _format_float(threshold)
    expression = "\n".join(lines)
    return f"{expression} ≤ {formatted_threshold}"
789
+
790
+
791
def _format_categories(categories: list, max_cat: Optional[int] = None) -> str:
    """
    Format a category list as a bracketed, comma-separated string.

    Parameters
    ----------
    categories : list
        Category values; each is rendered with ``str``.
    max_cat : Optional[int]
        If given and the list is longer than this, only the first ``max_cat``
        categories are shown on a single line, followed by ``...``.

    Returns
    -------
    str
        ``"[a, b, c]"``. When not truncated, a line break is inserted after
        every 9 items. An empty list yields ``"[]"``.
    """
    if max_cat is not None and len(categories) > max_cat:
        shown_cats = categories[:max_cat]
        return f"[{', '.join(map(str, shown_cats))}, ...]"

    # Nothing to show: avoid indexing into an empty list of lines below.
    if not categories:
        return "[]"

    formatted_cats = []
    current_line = []
    last_index = len(categories) - 1

    for i, cat in enumerate(categories):
        current_line.append(str(cat))

        # Break the line after every 9 items, unless this is the final item.
        if len(current_line) == 9 and i < last_index:
            formatted_cats.append(", ".join(current_line) + ",")
            current_line = []

    if current_line:
        formatted_cats.append(", ".join(current_line))

    return "[" + "\n".join(formatted_cats) + "]"
814
+
815
+
816
def _check_visualize_tree_inputs(
    tree: BaseTree,
    feature_names: Optional[List[str]] = None,
    max_cat: Optional[int] = None,
    max_oblique: Optional[int] = None,
    save_path: Optional[str] = None,
    dpi: int = 600,
    figsize: tuple = (20, 10),
) -> None:
    """
    Validate the arguments of visualize_tree, raising ValueError on the first
    problem found.

    Parameters
    ----------
    tree : BaseTree
        Must be a BaseTree instance whose ``_fit`` flag is truthy.
    feature_names : Optional[List[str]]
        If provided, must be a list of strings whose length equals
        ``tree.n_features``.
    max_cat : Optional[int]
        If provided, must be a positive integer.
    max_oblique : Optional[int]
        If provided, must be a positive integer.
    save_path : Optional[str]
        If provided, must be a string (the path itself is not inspected).
    dpi : int
        Must be a positive integer.
    figsize : tuple
        Must be a tuple of two positive numbers.

    Raises
    ------
    ValueError
        If any of the conditions above is violated.
    """

    def _is_positive_int(candidate) -> bool:
        return isinstance(candidate, int) and candidate > 0

    # --- tree -----------------------------------------------------------
    if not isinstance(tree, BaseTree):
        raise ValueError("`tree` must be an instance of `BaseTree`.")

    if not tree._fit:
        raise ValueError(
            "The tree has not been fitted yet. Please call the 'fit' method to train the tree before using this function."
        )

    # --- feature names --------------------------------------------------
    if feature_names is not None:
        names_ok = isinstance(feature_names, list) and all(
            isinstance(name, str) for name in feature_names
        )
        if not names_ok:
            raise ValueError("feature_names must be a list of strings.")
        if len(feature_names) != tree.n_features:
            raise ValueError(
                f"feature_names must match the number of features in the tree ({tree.n_features})."
            )

    # --- display limits -------------------------------------------------
    if max_cat is not None and not _is_positive_int(max_cat):
        raise ValueError("max_cat must be a positive integer.")

    if max_oblique is not None and not _is_positive_int(max_oblique):
        raise ValueError("max_oblique must be a positive integer.")

    # --- output options -------------------------------------------------
    if save_path is not None and not isinstance(save_path, str):
        raise ValueError("save_path must be a string representing a valid file path.")

    if not _is_positive_int(dpi):
        raise ValueError("dpi must be a positive integer.")

    figsize_ok = (
        isinstance(figsize, tuple)
        and len(figsize) == 2
        and all(isinstance(dim, (int, float)) and dim > 0 for dim in figsize)
    )
    if not figsize_ok:
        raise ValueError("figsize must be a tuple of two positive numbers.")