obliquetree-1.0.1-cp313-cp313-win_amd64.whl → obliquetree-1.0.3-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of obliquetree might be problematic. Click here for more details.

obliquetree/utils.py CHANGED
@@ -6,6 +6,7 @@ import json
6
6
  from typing import Optional, Dict, Any, List, Union
7
7
  from io import BytesIO
8
8
  import os
9
+ import numpy as np
9
10
 
10
11
 
11
12
  def load_tree(tree_data: Union[str, Dict]) -> Union[Classifier, Regressor]:
@@ -329,6 +330,361 @@ def visualize_tree(
329
330
  savefig(save_path, dpi=dpi, bbox_inches="tight", pad_inches=0)
330
331
 
331
332
  show()
333
+
334
def export_tree_to_onnx(tree: Union[Classifier, Regressor]) -> Any:
    """
    Convert an oblique decision tree (Classifier or Regressor) into an ONNX model.

    The tree is first exported with ``export_tree``. Each split node becomes an
    ONNX ``If`` operator whose ``then`` branch is the left child and whose
    ``else`` branch is the right child. Each leaf becomes a ``Constant`` node.

    .. important::
        - This implementation currently does **not** support batch processing.
          Only a single row (1D NumPy array) of dtype np.float64 can be passed
          as input.
        - The input variable is named **"X"** and has shape (n_features,).
          If the exported parameters carry no ``n_features``, that dimension
          is left symbolic.
        - The output variable is named **"Y"**. It has shape (n_classes,) when
          ``n_classes > 2`` and shape (1,) otherwise.
        - In binary classification, the output is a single value representing
          the probability of belonging to the positive class.
        - Splits are evaluated with ONNX ``Less`` (strict ``<``) and
          ``Equal``. A NaN feature makes both false, so NaN always routes to
          the right child. NOTE(review): confirm this matches the native
          predictor's comparison and missing-value handling.

    Parameters
    ----------
    tree : Union[Classifier, Regressor]
        The oblique decision tree (classifier or regressor) to be converted to ONNX.

    Returns
    -------
    onnx.ModelProto
        The constructed ONNX model (default-domain opset 13, IR version 7).

    Raises
    ------
    ImportError
        If the optional ``onnx`` package is not installed.
    ValueError
        If an oblique split has no weights, or its weight and feature lists
        differ in length.

    Examples
    --------
    >>> import onnx
    >>> import onnxruntime
    >>>
    >>> model = export_tree_to_onnx(tree)
    >>> onnx.save(model, "tree.onnx")
    >>>
    >>> # Only a single row (shape (n_features,), dtype float64) is supported:
    >>> X_sample = X[0, :].astype(np.float64)
    >>> session = onnxruntime.InferenceSession("tree.onnx")
    >>> y_pred = session.run(["Y"], {"X": X_sample})[0]
    >>> print(y_pred)
    """
    try:
        from onnx import helper, TensorProto
    except ImportError as e:
        raise ImportError(
            "Failed to import onnx dependencies. Please make sure the 'onnx' "
            "package is installed."
        ) from e

    tree_dict = export_tree(tree)

    # ONNX requires every node output, initializer tensor and graph to have a
    # unique name within the model, so all generated names share one counter.
    name_counter = 0

    def _unique_name(prefix="Node"):
        """Return ``f"{prefix}_{k}"`` with ``k`` strictly increasing per export."""
        nonlocal name_counter
        name_counter += 1
        return f"{prefix}_{name_counter}"

    def _make_constant_node(name, values, np_dtype, onnx_dtype):
        """Build a ``Constant`` node whose output ``name`` is a 1-D tensor.

        A scalar becomes a length-1 tensor; a list keeps its length. Length-1
        constants broadcast against the length-1 ``Gather`` outputs below.
        """
        arr = np.atleast_1d(np.asarray(values, dtype=np_dtype))
        tensor = helper.make_tensor(
            name=_unique_name("const_data"),
            data_type=onnx_dtype,
            dims=arr.shape,
            vals=arr.ravel().tolist(),
        )
        return helper.make_node("Constant", inputs=[], outputs=[name], value=tensor)

    def _float_constant(name, values):
        """Constant node holding float64 data (thresholds, weights, leaf values)."""
        return _make_constant_node(name, values, np.float64, TensorProto.DOUBLE)

    def _gather_feature(feature_idx, nodes):
        """Append nodes selecting ``X[feature_idx]`` (shape (1,)); return its name."""
        idx_node = _make_constant_node(
            _unique_name("feature_idx"), [feature_idx], np.int64, TensorProto.INT64
        )
        gather_out = _unique_name("gather_out")
        nodes.append(idx_node)
        nodes.append(
            helper.make_node(
                "Gather",
                inputs=["X", idx_node.output[0]],
                outputs=[gather_out],
                axis=0,
            )
        )
        return gather_out

    def _fold(op_type, operands, nodes, prefix):
        """Left-fold ``operands`` with binary ``op_type``; return the result name."""
        acc = operands[0]
        for operand in operands[1:]:
            out = _unique_name(prefix)
            nodes.append(
                helper.make_node(op_type, inputs=[acc, operand], outputs=[out])
            )
            acc = out
        return acc

    def _append_less(lhs, threshold, cond_name, nodes):
        """Append nodes computing ``lhs < threshold`` into ``cond_name``."""
        thr_node = _float_constant(_unique_name("thr"), threshold)
        nodes.append(thr_node)
        nodes.append(
            helper.make_node(
                "Less", inputs=[lhs, thr_node.output[0]], outputs=[cond_name]
            )
        )

    def _oblique_condition(node_dict, cond_name, nodes):
        """Append nodes computing ``sum_i w_i * X[f_i] < threshold``."""
        weights = node_dict["weights"]
        features = node_dict["features"]
        if len(weights) == 0 or len(weights) != len(features):
            raise ValueError(
                "Oblique split must have the same, non-zero number of weights "
                f"and features (got {len(weights)} weights and "
                f"{len(features)} features)."
            )

        products = []
        for weight, feature_idx in zip(weights, features):
            feature_val = _gather_feature(feature_idx, nodes)
            w_node = _float_constant(_unique_name("weight"), weight)
            mul_out = _unique_name("mul_out")
            nodes.append(w_node)
            nodes.append(
                helper.make_node(
                    "Mul", inputs=[feature_val, w_node.output[0]], outputs=[mul_out]
                )
            )
            products.append(mul_out)

        # Sum left to right with one Add per term, which fixes the
        # floating-point accumulation order.
        dot = _fold("Add", products, nodes, "add_out")
        _append_less(dot, node_dict["threshold"], cond_name, nodes)

    def _categorical_condition(node_dict, categories, cond_name, nodes):
        """Append nodes computing ``X[f] in categories`` into ``cond_name``."""
        feature_val = _gather_feature(node_dict["feature_idx"], nodes)
        matches = []
        for category in categories:
            cat_node = _float_constant(_unique_name("cat_val"), category)
            eq_out = _unique_name("eq_out")
            nodes.append(cat_node)
            nodes.append(
                helper.make_node(
                    "Equal", inputs=[feature_val, cat_node.output[0]], outputs=[eq_out]
                )
            )
            matches.append(eq_out)
        any_match = _fold("Or", matches, nodes, "or_out")
        nodes.append(
            helper.make_node("Identity", inputs=[any_match], outputs=[cond_name])
        )

    def _threshold_condition(node_dict, cond_name, nodes):
        """Append nodes computing the axis-aligned test ``X[f] < threshold``."""
        feature_val = _gather_feature(node_dict["feature_idx"], nodes)
        _append_less(feature_val, node_dict["threshold"], cond_name, nodes)

    def _build_subgraph(node_dict):
        """Recursively build the graph evaluating the subtree rooted at ``node_dict``.

        Returns ``(graph, output_name)``. ``X`` is read from the enclosing
        scope, which ONNX permits inside ``If`` branches. It is therefore not
        declared as an input or in ``value_info`` of the subgraph.
        """
        nodes = []
        graph_name = _unique_name("SubGraph")

        if node_dict["is_leaf"]:
            # Multi-class leaves carry a list under "values"; binary and
            # regression leaves carry a single scalar under "value".
            values = node_dict.get("values")
            leaf_value = values if isinstance(values, list) else node_dict["value"]
            out_name = _unique_name("sub_out")
            nodes.append(_float_constant(out_name, leaf_value))
        else:
            cond_name = _unique_name("cond_bool")
            categories = node_dict.get("category_left") or []
            if node_dict.get("is_oblique", False):
                _oblique_condition(node_dict, cond_name, nodes)
            elif len(categories) > 0:
                _categorical_condition(node_dict, categories, cond_name, nodes)
            else:
                _threshold_condition(node_dict, cond_name, nodes)

            left_graph, _ = _build_subgraph(node_dict["left"])
            right_graph, _ = _build_subgraph(node_dict["right"])
            out_name = _unique_name("if_out")
            nodes.append(
                helper.make_node(
                    "If",
                    inputs=[cond_name],
                    outputs=[out_name],
                    name=_unique_name("IfNode"),
                    then_branch=left_graph,  # condition true -> left child
                    else_branch=right_graph,
                )
            )

        out_info = helper.make_tensor_value_info(out_name, TensorProto.DOUBLE, None)
        graph = helper.make_graph(
            nodes=nodes, name=graph_name, inputs=[], outputs=[out_info]
        )
        return graph, out_name

    params = tree_dict["params"]
    n_classes = params.get("n_classes", 2)
    n_features = params.get("n_features")
    # Leave the feature dimension symbolic when it is unknown instead of
    # guessing a size that would reject valid inputs. int() also accepts
    # numpy integers, which onnx's shape helper would otherwise refuse.
    feature_dim = int(n_features) if n_features is not None else None
    out_len = int(n_classes) if n_classes is not None and n_classes > 2 else 1

    # The root subgraph's nodes are lifted directly into the main graph.
    root_graph, root_out = _build_subgraph(tree_dict["tree"])
    nodes = list(root_graph.node)
    nodes.append(helper.make_node("Identity", inputs=[root_out], outputs=["Y"]))

    main_graph = helper.make_graph(
        nodes=nodes,
        name="MainGraph",
        inputs=[
            helper.make_tensor_value_info("X", TensorProto.DOUBLE, [feature_dim])
        ],
        outputs=[helper.make_tensor_value_info("Y", TensorProto.DOUBLE, [out_len])],
    )

    onnx_model = helper.make_model(
        main_graph,
        producer_name="custom_oblique_categorical_tree",
        opset_imports=[helper.make_opsetid("", 13)],
    )
    # IR version 7 is the IR paired with opset 13. Pinning it presumably keeps
    # the model loadable by runtimes older than the installed onnx package.
    onnx_model.ir_version = 7

    return onnx_model
332
688
 
333
689
 
334
690
  def _format_float(value: float) -> str:
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: obliquetree
3
- Version: 1.0.1
3
+ Version: 1.0.3
4
4
  Summary: Traditional and Oblique Decision Tree
5
5
  Author-email: Samet Copur <sametcopur@yahoo.com>
6
6
  License: MIT License
@@ -16,6 +16,7 @@ Description-Content-Type: text/markdown
16
16
  License-File: LICENSE
17
17
  Requires-Dist: numpy>=2.2.1
18
18
  Requires-Dist: scipy>=1.15.0
19
+ Dynamic: license-file
19
20
 
20
21
  # obliquetree
21
22
 
@@ -0,0 +1,21 @@
1
+ obliquetree/__init__.py,sha256=Dx8l4yy_b_7HPRYpkz39tiKcLBtoS4HsWXOwz1VpiGc,106
2
+ obliquetree/_pywrap.py,sha256=iz7FI7u8MP_iLKiGRtBaEfsw_dnpK8epxLJCI1T8XUw,27079
3
+ obliquetree/utils.py,sha256=ewYp_3HZRvgBMysMlnelZdDuMCv49wsPAuYD_6du4GY,31149
4
+ obliquetree/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ obliquetree/src/base.cp313-win_amd64.pyd,sha256=Gl6jx5itauxWVSPCTNP_Tb4uGAd-O7oaOibLq62EHog,178176
6
+ obliquetree/src/base.cpp,sha256=IwnIkavQxfU0cC4-AVjPN-5aSHgnCjtOm4K88R156UE,1455964
7
+ obliquetree/src/ccp.cp313-win_amd64.pyd,sha256=LQqRsSBycCUF2ixiwKJGbUji3BH3wuB3twSOU2zcuvg,107520
8
+ obliquetree/src/ccp.cpp,sha256=t46G1BB4xXnUPYirTU1g3-vy75f8Y77iwOoXu-svsk8,974220
9
+ obliquetree/src/metric.cp313-win_amd64.pyd,sha256=zRuRAglR5fuyogpS1w4sYiwzgsUZ29r6QfVZtO5S2FM,128512
10
+ obliquetree/src/metric.cpp,sha256=L-nfMk16gP5BQqiqC-TRguNjOgqSKnhDHRZbZRnbMkA,1214827
11
+ obliquetree/src/oblique.cp313-win_amd64.pyd,sha256=fPcIzM2Bbkf19LgWXCza1fR9qzSJIFq67_-2jRoTlU8,201216
12
+ obliquetree/src/oblique.cpp,sha256=M-vcBquzSJ0lGx0cY63g6WnV13GnpVGBjTXDeRpw-co,1402472
13
+ obliquetree/src/tree.cp313-win_amd64.pyd,sha256=U_ZYjKsGI6-Qz5ogCmmXstLQTDi0neAi85BxjpKgruI,121344
14
+ obliquetree/src/tree.cpp,sha256=VVPtQxz8Qq0kJb1uAne_dB9Q_qUXPEmc2BwumlYqFIw,1164673
15
+ obliquetree/src/utils.cp313-win_amd64.pyd,sha256=I7ZOJZ-akf6uvJEpZZXc38BarkD5SDBB8_4sBR09NOw,159744
16
+ obliquetree/src/utils.cpp,sha256=FhJTiyLfEKZeYB78f_n08L13mHyN4IJskoQ9gwv64ww,1268066
17
+ obliquetree-1.0.3.dist-info/licenses/LICENSE,sha256=WfEPDyxevjVRSepuK7Lvl2VDl5pqtwjvMoNu2iWhcD8,1090
18
+ obliquetree-1.0.3.dist-info/METADATA,sha256=gRRugN8nEvyw9a-_SDWNQ-cwZXziJp5Xo8kvH9ogBmc,5776
19
+ obliquetree-1.0.3.dist-info/WHEEL,sha256=qV0EIPljj1XC_vuSatRWjn02nZIz3N1t8jsZz7HBr2U,101
20
+ obliquetree-1.0.3.dist-info/top_level.txt,sha256=m-5N4-iAS5MsFOdk8y1r2ya_i5rVQBgPayHBA-K26qg,12
21
+ obliquetree-1.0.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp313-cp313-win_amd64
5
5
 
@@ -1,21 +0,0 @@
1
- obliquetree/__init__.py,sha256=facekeN72UJTEDPLHy4qMR3gREAZ7GPa3vG4TyHKa0w,83
2
- obliquetree/_pywrap.py,sha256=xFbIFoMNAAYyDSO6b5A1NytxA5eV8-KxXgF2wxbysWo,26484
3
- obliquetree/utils.py,sha256=U0aRNA_FxiIiEagABAAQZ1Ssx1nIMdo-XjN3qv7V98w,18502
4
- obliquetree/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- obliquetree/src/base.cp313-win_amd64.pyd,sha256=6463FQS9jjxbrfOseDnOgdLwyYgQe8XPRxW15S8wJts,193536
6
- obliquetree/src/base.cpp,sha256=0ZILP5yX1ixNIfIn6FDiSmQdTakW1VI1qk-OEmOmec4,1421617
7
- obliquetree/src/ccp.cp313-win_amd64.pyd,sha256=5rmHkpVJl6PA_vPa7_Hr6Rl-v5khtedgoz7bar_GZ9k,118272
8
- obliquetree/src/ccp.cpp,sha256=sAmjaZwmzsSBbj35VMJ3H5FBB18vJRjNjeHDvQXQ3Og,944547
9
- obliquetree/src/metric.cp313-win_amd64.pyd,sha256=hWLdoiDmNZ_scv3F4Su1I_8plqFy1V8rOA2T219z1zw,136704
10
- obliquetree/src/metric.cpp,sha256=XmHbXjazt8upGm4pejkF4cqHjDQR1xY-0X00H3jrXxk,1174026
11
- obliquetree/src/oblique.cp313-win_amd64.pyd,sha256=i09mWgrHL4BTWhyQPkt-A9yDPISGurXdDW9NySedqa4,210432
12
- obliquetree/src/oblique.cpp,sha256=eXd3pS6_mE2w9xLzgx3QXo4S64dNUH2CGPoanUtPCp8,1353669
13
- obliquetree/src/tree.cp313-win_amd64.pyd,sha256=4u6mBgaqdW3a1XK9RsSZVyr_8F8BrkKcBG78yo9jVR4,132608
14
- obliquetree/src/tree.cpp,sha256=54KJ8my79R9h4JS6cb1rZ09HXt9IbMy0JtJkKj5vLE8,1130952
15
- obliquetree/src/utils.cp313-win_amd64.pyd,sha256=KxACVs9GSijGykwzTtJ0SFxB845CAXZQ3KPP6GHT7g8,171520
16
- obliquetree/src/utils.cpp,sha256=eEPf0KdnQoD1-CqVeOM57ElLOQIXvgt5KRhqFP-aygY,1227801
17
- obliquetree-1.0.1.dist-info/LICENSE,sha256=WfEPDyxevjVRSepuK7Lvl2VDl5pqtwjvMoNu2iWhcD8,1090
18
- obliquetree-1.0.1.dist-info/METADATA,sha256=Shil4GEgde1VDY0cBdvglKurdxy9SBoEWZJoZkaeG4Q,5753
19
- obliquetree-1.0.1.dist-info/WHEEL,sha256=6bXTkCllrWLYPW3gCPkeRA91N4604g9hqNhQqZWsUzQ,101
20
- obliquetree-1.0.1.dist-info/top_level.txt,sha256=m-5N4-iAS5MsFOdk8y1r2ya_i5rVQBgPayHBA-K26qg,12
21
- obliquetree-1.0.1.dist-info/RECORD,,