PyDecisionGraph 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of PyDecisionGraph might be problematic. Click here for more details.

decision_tree/abc.py ADDED
@@ -0,0 +1,1147 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import json
5
+ import operator
6
+ import sys
7
+ import traceback
8
+ from collections.abc import Callable, Iterable
9
+ from typing import Any, Self, final
10
+
11
+ from . import LOGGER
12
+ from .exc import TooFewChildren, TooManyChildren, EdgeValueError, NodeValueError, NodeNotFountError
13
+
14
+ __all__ = ['LGM', 'LogicGroup', 'SkipContextsBlock', 'LogicExpression', 'ExpressionCollection', 'LogicNode', 'ActionNode', 'ELSE_CONDITION']
15
+
16
+
17
class Singleton(type):
    """
    Metaclass allowing each class that uses it to be instantiated only once.

    The first call constructs and stores the instance. Every later call returns the
    stored instance and ignores its arguments. All classes using this metaclass share
    one storage dict, keyed by class.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        created = registry[cls] = super().__call__(*args, **kwargs)
        return created
24
+
25
+
26
class ConditionElse:
    """
    Sentinel type for the catch-all ("else") edge of a decision node.

    It converts to an empty string, so an edge labelled with it is shown without text.
    """

    def __str__(self):
        return ""


# A single shared sentinel. Both names are aliases of the same object, so
# identity checks such as `condition is NO_CONDITION` work with either name.
NO_CONDITION = ConditionElse()
ELSE_CONDITION = NO_CONDITION
34
+
35
+
36
class LogicGroupManager(metaclass=Singleton):
    """
    A singleton that caches LogicGroup instances by name and tracks what is currently active.

    Besides the name -> LogicGroup cache it keeps:

    * a stack of entered LogicGroups and a stack of entered LogicNodes (innermost last);
    * "exit nodes": action nodes registered by ``LogicGroup.break_`` (in inspection mode)
      as an early exit of a specific logic group;
    * "pending" exit nodes: exit nodes whose logic group has been exited. They are
      connected to the next node passed to ``enter_expression``.
    """

    def __init__(self):
        # name -> cached LogicGroup instance
        self._cache = {}

        # Stacks of the currently entered groups / expression nodes (innermost last).
        self._active_groups: list[LogicGroup] = []
        self._active_nodes: list[LogicNode] = []
        # Action nodes (usually NoAction() nodes) marked as an early exit of a logic group.
        self._exit_nodes: list[ActionNode] = []
        # Exit nodes whose group has been exited; wired to the next entered expression.
        self._pending_connection_nodes: list[ActionNode] = []
        self.inspection_mode = False

    def __call__(self, name: str, cls: type[LogicGroup], **kwargs) -> LogicGroup:
        """
        Retrieve a cached LogicGroup instance or create a new one if not cached.

        :param name: The name of the LogicGroup.
        :param cls: The class of the LogicGroup to create if not cached.
        :param kwargs: Additional arguments for LogicGroup initialization.
        :return: A LogicGroup instance.
        """
        if name in self._cache:
            return self._cache[name]

        logic_group = cls(name=name, **kwargs)
        self._cache[name] = logic_group
        return logic_group

    def __contains__(self, name: str) -> bool:
        return name in self._cache

    def __getitem__(self, name: str) -> LogicGroup:
        return self._cache[name]

    def __setitem__(self, name: str, value: LogicGroup):
        self._cache[name] = value

    def enter_logic_group(self, logic_group: LogicGroup):
        """
        Push a LogicGroup onto the active-group stack when its context is entered.

        :param logic_group: The LogicGroup entering the context.
        """
        self._active_groups.append(logic_group)

    def exit_logic_group(self, logic_group: LogicGroup):
        """
        Pop ``logic_group`` from the active-group stack and release its exit nodes.

        Exit nodes whose ``break_from`` is ``logic_group`` are moved from the exit-node
        list to the pending list. They are removed so that a later re-entry and exit of
        the same (cached) group does not queue them a second time.

        :param logic_group: The LogicGroup exiting the context.
        :raises ValueError: if ``logic_group`` is not the innermost active group.
        """
        if not self._active_groups or self._active_groups[-1] is not logic_group:
            raise ValueError("The LogicGroup is not currently active.")

        self._active_groups.pop(-1)

        still_waiting = []
        for node in self._exit_nodes:
            if getattr(node, 'break_from', None) is logic_group:
                self._pending_connection_nodes.append(node)
            else:
                still_waiting.append(node)
        self._exit_nodes[:] = still_waiting

    def enter_expression(self, node: LogicNode):
        """
        Push ``node`` onto the active-expression stack.

        Pending exit nodes are connected to ``node`` first:

        * a NoAction placeholder is replaced by ``node`` in its parent;
        * any other action node gets ``node`` as its unconditional (NO_CONDITION) child.

        If an expression is already active, ``node`` is also recorded as one of its
        subordinates.

        :raises NodeNotFountError: if a pending NoAction exit node has no parent.
        """
        if isinstance(node, ActionNode):
            LOGGER.error('Enter the with code block of an ActionNode rejected. Check is this intentional?')

        if self._pending_connection_nodes:
            from .node import NoAction

            for _exit_node in self._pending_connection_nodes:
                if isinstance(_exit_node, NoAction):
                    if (parent := _exit_node.parent) is None:
                        raise NodeNotFountError('ActionNode must have a parent node!')
                    parent.replace(original_node=_exit_node, new_node=node)
                else:
                    _exit_node.edges.append(NO_CONDITION)
                    _exit_node.nodes[NO_CONDITION] = node

            self._pending_connection_nodes.clear()

        if (active_node := self.active_expression) is not None:
            active_node.subordinates.append(node)

        self._active_nodes.append(node)

    def exit_expression(self, node: LogicNode):
        """
        Pop ``node`` from the active-expression stack.

        :raises ValueError: if ``node`` is not the innermost active expression.
        """
        if not self._active_nodes or self._active_nodes[-1] is not node:
            raise ValueError(f"The {node} is not currently active.")

        self._active_nodes.pop(-1)

    def clear(self):
        """
        Clear the LogicGroup cache and reset all tracking state.

        This resets the active groups, active nodes, exit nodes and pending exit nodes.
        ``inspection_mode`` is left unchanged.
        """
        self._cache.clear()
        self._active_groups.clear()
        self._active_nodes.clear()
        self._exit_nodes.clear()
        self._pending_connection_nodes.clear()

    @property
    def active_logic_group(self) -> LogicGroup | None:
        """Innermost entered LogicGroup, or None."""
        if self._active_groups:
            return self._active_groups[-1]

        return None

    @property
    def active_expression(self) -> LogicNode | None:
        """Innermost entered LogicNode, or None."""
        if self._active_nodes:
            return self._active_nodes[-1]

        return None


LGM = LogicGroupManager()
156
+
157
+
158
class LogicGroupMeta(type):
    """
    Metaclass for LogicGroup that registers classes and caches instances.

    * Every class created with this metaclass is recorded in ``_registry_`` under its
      class name. ``LogicGroup.sub_logics`` uses the registry to re-create sub groups.
    * Instantiation is cached in ``LGM`` by ``name``: asking for an existing name returns
      the cached instance and the other arguments are ignored.
    """
    _registry_ = {}

    def __new__(cls, name, bases, dct):
        new_class = super().__new__(cls, name, bases, dct)
        cls._registry_[name] = new_class
        return new_class

    def __call__(cls, name, *args, **kwargs):
        """
        Return the cached instance for ``name`` or create, cache and return a new one.

        :raises ValueError: if ``name`` is None.
        """
        if name is None:
            raise ValueError("LogicGroup instances must have a 'name'.")

        if name in LGM:
            return LGM[name]

        if args:
            # With extra positional arguments, ``name`` must stay positional. Passing it
            # as ``name=`` would also bind the first positional to the first __init__
            # parameter: that is a "multiple values" error when that parameter is
            # ``name``, or silently swapped arguments when it is something else.
            instance = super().__call__(name, *args, **kwargs)
        else:
            instance = super().__call__(name=name, **kwargs)
        LGM[name] = instance
        return instance

    @property
    def registry(self):
        """Mapping of class name -> class for every class created by this metaclass."""
        return self._registry_
185
+
186
+
187
class LogicGroup(object, metaclass=LogicGroupMeta):
    """
    A minimal context manager to save/restore state from the `.contexts` dict.

    A logic group maintains no status itself; the status should be restored
    from the outer `.contexts` dict. A child group stores its type, contexts and
    own sub-groups inside its parent's ``_sub_logics`` registry, so re-creating a
    group with the same name under the same parent recovers the same state.
    """

    def __init__(self, name: str, parent: Self = None, contexts: dict[str, Any] = None):
        self.name = name
        self.parent = parent
        # One exception type per instance, so ``break_`` can target a specific group.
        self.Break = type(f"{self.__class__.__name__}Break", (Exception,), {})

        if parent is None:
            # A root logic group: its state lives only in this instance.
            info_dict = {}
            if contexts is None:
                contexts = {}
        else:
            # A child group: recover (or create) its entry in the parent's registry.
            info_dict = parent._sub_logics.setdefault(name, {})
            logic_type = self.__class__.__name__
            assert info_dict.setdefault('logic_type', logic_type) == logic_type, f"Logic {info_dict['logic_type']} already registered in {parent.name}!"
            contexts = info_dict.setdefault('contexts', {} if contexts is None else contexts)

        self.contexts: dict[str, Any] = contexts
        self._sub_logics = info_dict.setdefault('sub_logics', {})

    def __repr__(self):
        return f'<{self.__class__.__name__}>({self.name!r})'

    def __enter__(self) -> Self:
        LGM.enter_logic_group(self)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        LGM.exit_logic_group(self)

        if exc_type is None:
            return

        # Swallow only this group's own Break; another group's Break keeps propagating
        # outwards until it reaches the group it targets.
        if exc_type is self.Break:
            return True

        # Explicitly re-raise other exceptions
        return False

    def break_(self, scope: LogicGroup = None):
        """
        Leave ``scope`` (default: this group) early.

        Outside inspection mode this raises ``scope.Break``, which ``scope.__exit__``
        swallows. In inspection mode the tree is being built instead: the last leaf of
        the active expression is marked with ``break_from = scope`` and registered in
        ``LGM._exit_nodes``, and nothing is raised.

        Raises:
            TooFewChildren: in inspection mode, if the active expression has no child.
            NodeValueError: in inspection mode, if the last leaf is not an ActionNode.
        """
        if scope is None:
            scope = self

        # will not break from scope in inspection mode
        if LGM.inspection_mode:
            active_node = LGM.active_expression

            if active_node is not None:
                active_node: LogicNode
                if not active_node.nodes:
                    raise TooFewChildren()
                last_node = active_node.last_leaf
                # Explicit check (not ``assert``) so it also holds under ``python -O``.
                if not isinstance(last_node, ActionNode):
                    raise NodeValueError('An ActionNode is required before breaking a LogicGroup.')
                last_node.break_from = scope
                LGM._exit_nodes.append(last_node)
            return

        raise scope.Break()

    @property
    def sub_logics(self) -> dict[str, Self]:
        """
        Re-create the sub groups registered in ``_sub_logics``.

        Constructor arguments come from the stored info dict. Otherwise ``name`` is the
        registry key, ``parent`` is this group, and ``contexts`` falls back to ``{}``
        (with a warning). ``*args``/``**kwargs`` parameters are never filled. Instances
        go through the metaclass cache, so an already cached name returns the cached group.

        Raises:
            ValueError: if a registered logic type is not in the class registry.
            TypeError: if a required constructor argument cannot be inferred.
        """
        sub_logic_instances = {}
        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        for logic_name, info in self._sub_logics.items():
            logic_type = info["logic_type"]

            # Dynamically retrieve the class using meta registry
            logic_class = self.__class__.registry.get(logic_type)

            if logic_class is None:
                raise ValueError(f"Class {logic_type} not found in registry.")

            init_params = inspect.signature(logic_class.__init__).parameters

            init_args = {}
            for param_name, param in init_params.items():
                if param_name == "self" or param.kind in variadic_kinds:
                    continue

                if param_name in info:
                    init_args[param_name] = info[param_name]
                elif param_name == "name":
                    init_args["name"] = logic_name
                elif param_name == "parent":
                    init_args["parent"] = self
                elif param_name == "contexts":
                    LOGGER.warning(f"Contexts dict not found for {logic_name}!")
                    init_args["contexts"] = {}
                elif param.default is inspect.Parameter.empty:
                    # Missing a required argument that cannot be inferred
                    raise TypeError(f"Missing required argument '{param_name}' for {logic_type}.")

            sub_logic_instances[logic_name] = logic_class(**init_args)

        return sub_logic_instances
295
+
296
+
297
class SkipContextsBlock:
    """
    Context manager whose ``with`` body may be skipped.

    Subclasses override ``_entry_check``. A truthy result runs ``_on_enter`` and the
    body normally. A falsy result skips the body: a no-op global tracer is installed,
    and the caller's frame gets a local tracer that raises ``_Skip`` on its next traced
    event. ``__exit__`` then swallows ``_Skip`` and restores the previous tracer.
    ``_on_exit`` runs when the body was entered, whether it finished normally or raised;
    it does not run for a skipped body.
    """

    class _Skip(Exception):
        pass

    def _entry_check(self) -> Any:
        """
        Return a truthy value to run the ``with`` body, or a falsy value to skip it.

        The base implementation returns None, so the body is skipped.
        """
        pass

    @final
    def __enter__(self):
        if not self._entry_check():
            # Skip path. Order matters: remember the current tracer, enable tracing with
            # a no-op global tracer, then arm the caller's frame so it raises _Skip.
            self._original_trace = self.get_trace()
            caller_frame = inspect.currentframe().f_back
            sys.settrace(self.empty_trace)
            caller_frame.f_trace = self.err_trace
            return None

        self._on_enter()
        return self

    @final
    def __exit__(self, exc_type, exc_value, exc_traceback):
        skipped = exc_type is not None and issubclass(exc_type, self._Skip)
        if skipped:
            if not hasattr(self, '_original_trace'):
                raise Exception('original_trace not found! Debugger broken! This should never happened.')
            # Put back whatever tracer (e.g. a debugger) was active before the skip.
            sys.settrace(self._original_trace)
            return True

        self._on_exit()
        # None after a clean body, False after an error; neither suppresses exceptions.
        return None if exc_type is None else False

    def _on_enter(self):
        """Hook run after a successful entry check; no-op by default."""
        pass

    def _on_exit(self):
        """Hook run when an entered body finishes or raises; no-op by default."""
        pass

    @staticmethod
    def get_trace():
        """
        Return the tracer to restore after a skip.

        If the PyDev debugger (``pydevd``) can be imported and reports a global debugger,
        its ``trace_dispatch`` is returned. Otherwise ``sys.gettrace()`` is returned.
        """
        try:
            # noinspection PyUnresolvedReferences
            import pydevd
            active_debugger = pydevd.GetGlobalDebugger()
            if active_debugger is not None:
                return active_debugger.trace_dispatch
        except ImportError:
            # PyDev is not installed.
            pass

        return sys.gettrace()

    @classmethod
    def empty_trace(cls, *args, **kwargs) -> None:
        """Global tracer that does nothing and requests no local tracing."""
        pass

    @classmethod
    def err_trace(cls, frame, event, arg):
        """Local tracer for the caller's frame: aborts the ``with`` body by raising ``_Skip``."""
        raise cls._Skip("Expression evaluated to be False, cannot enter the block.")
367
+
368
+
369
class LogicExpression(SkipContextsBlock):
    """
    Represents a logical or mathematical expression that supports deferred evaluation.

    Operators (``&``, ``|``, ``==``, ``!=``, ``+ - * / // **`` including their reflected
    forms, and ``< <= > >=``) return new ``LogicExpression`` objects. Operands are only
    evaluated when ``eval`` is called on the resulting expression. Used as a ``with``
    block, the body runs only when ``eval()`` is truthy.
    """

    def __init__(
            self,
            expression: float | int | bool | str | Exception | Callable[[], Any],
            dtype: type = None,
            repr: str = None,
    ):
        """
        Initialize the LogicExpression.

        Args:
            expression: A static value (float, int, bool, str), a zero-argument callable,
                or an Exception to raise on evaluation.
            dtype (type, optional): The expected type of the evaluated value.
            repr (str, optional): A string representation of the expression.
        """
        self.expression = expression
        self.dtype = dtype
        self.repr = repr if repr is not None else str(expression)

        super().__init__()

    def _entry_check(self) -> Any:
        return self.eval()

    @staticmethod
    def _dtype_name(dtype: Any) -> str:
        """Readable name of ``dtype``; ``None`` is shown as ``'Any'``."""
        if dtype is None:
            return 'Any'
        # typing special forms do not always expose ``__name__``.
        return getattr(dtype, '__name__', str(dtype))

    def eval(self, enforce_dtype: bool = False) -> Any:
        """
        Evaluate the expression.

        Static values are returned as is, callables are called, and Exceptions are raised.
        With ``enforce_dtype`` the value is converted with ``dtype(value)``. Otherwise a
        dtype mismatch only logs a warning.

        Raises:
            TypeError: if the stored expression has an unsupported type.
        """
        if isinstance(self.expression, (float, int, bool, str)):
            value = self.expression
        elif callable(self.expression):
            value = self.expression()
        elif isinstance(self.expression, Exception):
            raise self.expression
        else:
            raise TypeError(f"Unsupported expression type: {type(self.expression)}.")

        if self.dtype is Any or self.dtype is None:
            pass  # No type enforcement
        elif enforce_dtype:
            value = self.dtype(value)
        elif not isinstance(value, self.dtype):
            LOGGER.warning(f"Evaluated value {value} does not match dtype {self._dtype_name(self.dtype)}.")

        return value

    # Logical operators
    @classmethod
    def cast(cls, value: int | float | bool | str | Exception | Callable[[], Any] | Self, dtype: type = None) -> Self:
        """
        Convert a static value, callable, or error into a LogicExpression.

        Args:
            value: The value to convert. Can be:
                - A static value (int, float, bool or str).
                - A callable returning a value.
                - A pre-existing LogicExpression (returned unchanged).
                - An Exception to raise during evaluation.
            dtype (type, optional): The expected type of the resulting LogicExpression.
                If None, it is inferred from the value (``Any`` for callables/errors).

        Returns:
            LogicExpression: The resulting LogicExpression.

        Raises:
            TypeError: If the value type is unsupported.
        """
        if isinstance(value, LogicExpression):
            return value
        if isinstance(value, (int, float, bool, str)):
            return LogicExpression(
                expression=value,
                dtype=dtype or type(value),
                # Quote strings so they are distinguishable from identifiers in reprs.
                repr=repr(value) if isinstance(value, str) else str(value),
            )
        if callable(value):
            return LogicExpression(
                expression=value,
                dtype=dtype or Any,
                repr=f"Eval({value})"
            )
        if isinstance(value, Exception):
            return LogicExpression(
                expression=value,
                dtype=dtype or Any,
                repr=f"Raises({type(value).__name__}: {value})"
            )
        raise TypeError(f"Unsupported type for LogicExpression conversion: {type(value)}.")

    def __bool__(self) -> bool:
        return bool(self.eval())

    def __and__(self, other: Self | bool) -> Self:
        other_expr = self.cast(value=other, dtype=bool)
        return LogicExpression(
            expression=lambda: self.eval() and other_expr.eval(),
            dtype=bool,
            repr=f"({self.repr} and {other_expr.repr})"
        )

    def __rand__(self, other: bool) -> Self:
        # ``other & self`` with a plain left operand: evaluate ``other`` first.
        return self.cast(value=other, dtype=bool) & self

    def __or__(self, other: Self | bool) -> Self:
        other_expr = self.cast(value=other, dtype=bool)
        return LogicExpression(
            expression=lambda: self.eval() or other_expr.eval(),
            dtype=bool,
            repr=f"({self.repr} or {other_expr.repr})"
        )

    def __ror__(self, other: bool) -> Self:
        return self.cast(value=other, dtype=bool) | self

    @staticmethod
    def _lazy_operand(other: Any) -> tuple[Callable[[], Any], str]:
        """Return ``(thunk, text)`` for ``other``: deferred for expressions, constant otherwise."""
        if isinstance(other, LogicExpression):
            return other.eval, other.repr
        return (lambda: other), repr(other)

    def __eq__(self, other: int | float | bool | str | Self) -> Self:
        # Both sides are evaluated only when the resulting expression is evaluated;
        # evaluating ``other`` here would freeze its value at construction time.
        other_value, other_repr = self._lazy_operand(other)
        return LogicExpression(
            expression=lambda: self.eval() == other_value(),
            dtype=bool,
            repr=f"({self.repr} == {other_repr})"
        )

    def __ne__(self, other: int | float | bool | str | Self) -> Self:
        other_value, other_repr = self._lazy_operand(other)
        return LogicExpression(
            expression=lambda: self.eval() != other_value(),
            dtype=bool,
            repr=f"({self.repr} != {other_repr})"
        )

    # Math operators
    @classmethod
    def _math_op(cls, self: Self, other: int | float | Self, op: Callable, operator_str: str, dtype: type = None) -> Self:
        """Build the deferred expression ``op(self, other)``; dtype defaults to ``self.dtype``."""
        other_expr = LogicExpression.cast(other)

        if dtype is None:
            dtype = self.dtype

        return LogicExpression(
            expression=lambda: op(self.eval(), other_expr.eval()),
            dtype=dtype,
            repr=f"({self.repr} {operator_str} {other_expr.repr})",
        )

    def _reflected_math_op(self, other: int | float | bool, op: Callable, operator_str: str) -> Self:
        """Build ``other <op> self`` for a plain value on the left-hand side."""
        return self._math_op(self=self.cast(other), other=self, op=op, operator_str=operator_str, dtype=self.dtype)

    def __add__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.add, operator_str="+")

    def __radd__(self, other: int | float | bool) -> Self:
        return self._reflected_math_op(other, operator.add, "+")

    def __sub__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.sub, operator_str="-")

    def __rsub__(self, other: int | float | bool) -> Self:
        return self._reflected_math_op(other, operator.sub, "-")

    def __mul__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.mul, operator_str="*")

    def __rmul__(self, other: int | float | bool) -> Self:
        return self._reflected_math_op(other, operator.mul, "*")

    def __truediv__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.truediv, operator_str="/")

    def __rtruediv__(self, other: int | float | bool) -> Self:
        return self._reflected_math_op(other, operator.truediv, "/")

    def __floordiv__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.floordiv, operator_str="//")

    def __rfloordiv__(self, other: int | float | bool) -> Self:
        return self._reflected_math_op(other, operator.floordiv, "//")

    def __pow__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.pow, operator_str="**")

    def __rpow__(self, other: int | float | bool) -> Self:
        return self._reflected_math_op(other, operator.pow, "**")

    # Ordering comparisons (reflection is handled by Python swapping < and >, etc.)
    def __lt__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.lt, operator_str="<", dtype=bool)

    def __le__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.le, operator_str="<=", dtype=bool)

    def __gt__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.gt, operator_str=">", dtype=bool)

    def __ge__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.ge, operator_str=">=", dtype=bool)

    def __repr__(self) -> str:
        return f"LogicExpression(dtype={self._dtype_name(self.dtype)}, repr={self.repr})"
542
+
543
+
544
class ExpressionCollection(LogicGroup):
    """
    A LogicGroup that keeps a ``data`` payload in its ``contexts``.

    The group is nested under ``kwargs['logic_group']`` when that keyword is given
    (even if it is None). Otherwise it is nested under the currently active logic group,
    if any. Its name is ``repr`` when given; otherwise ``name``, prefixed by the parent
    group's name when there is a parent.
    """

    def __init__(self, data: Any, name: str, repr: str = None, **kwargs):
        # The previous condition was inverted: an explicit ``logic_group`` was ignored,
        # and the active group was never used when the keyword was absent.
        if 'logic_group' in kwargs:
            logic_group = kwargs['logic_group']
        else:
            logic_group = LGM.active_logic_group

        if repr is not None:
            group_name = repr
        elif logic_group is None:
            group_name = name
        else:
            group_name = f'{logic_group.name}.{name}'

        super().__init__(name=group_name, parent=logic_group)

        # setdefault: a group recovered from its parent's registry keeps its stored data.
        self.data = self.contexts.setdefault('data', data)
557
+
558
+
559
+ class LogicNode(LogicExpression):
560
def __init__(
    self,
    expression: float | int | bool | Exception | Callable[[], Any],
    dtype: type = None,
    repr: str = None,
):
    """
    Create a decision-tree node that wraps ``expression``.

    Args:
        expression: A static value, a zero-argument callable, or an Exception;
            forwarded unchanged to LogicExpression.
        dtype (type, optional): Expected type of the evaluated value.
        repr (str, optional): Display string of the node.
    """
    super().__init__(expression=expression, dtype=dtype, repr=repr)

    # Names of every logic group active when the node was created (outermost first).
    self.labels = [group.name for group in LGM._active_groups]
    # Children keyed by edge condition, and the edge conditions in insertion order.
    self.nodes: dict[Any, LogicNode] = {}
    self.edges = []
    self.parent: LogicNode | None = None
    # Nodes entered via ``with`` while this node was the active expression.
    self.subordinates = []
581
+
582
def _entry_check(self) -> Any:
    """
    Decide whether this node's ``with`` body is executed.

    While ``LGM.inspection_mode`` is on, the body is always entered so the whole tree
    can be built. Otherwise the node's expression is evaluated and its value returned.
    """
    return True if LGM.inspection_mode else self.eval()
593
+
594
+ def __rshift__(self, expression: Self):
595
+ """Overloads >> operator for adding child nodes."""
596
+ self.append(expression)
597
+ return expression # Allow chaining
598
+
599
def __call__(self, default=None) -> Any:
    """
    Evaluate the decision tree from this node and evaluate the leaf that was reached.

    Inspection mode is switched off while the tree is walked. It is always restored
    afterwards, even if evaluation raises.

    Keyword Args:
        default (Any, optional): Fallback handed to ``eval_recursively``. When omitted,
            a ``NoAction(auto_connect=False)`` is created.

    Returns:
        Any: The result of ``eval()`` on the last node of the decision path.

    Raises:
        TooFewChildren: If evaluation produced an empty decision path.
    """
    if default is None:
        from .node import NoAction
        default = NoAction(auto_connect=False)

    previous_mode = LGM.inspection_mode
    if previous_mode:
        LOGGER.info('LGM inspection mode temporally disabled to evaluate correctly.')
        LGM.inspection_mode = False

    try:
        _, path = self.eval_recursively(default=default)
    finally:
        # Restore even on error so a failed evaluation cannot leave the manager
        # stuck outside inspection mode.
        LGM.inspection_mode = previous_mode

    if not path:
        raise TooFewChildren()

    leaf = path[-1]
    return leaf.eval()
628
+
629
+ def __repr__(self):
630
+ return f'<{self.__class__.__name__}>({self.repr!r})'
631
+
632
def _on_enter(self):
    """
    Attach this node to the tree under construction when its ``with`` block is entered.

    * No active expression: this node is pushed as the active expression.
    * The active expression has no subordinate yet: this node is linked under it on
      the ``True`` edge.
    * Otherwise the node is chained to the last subordinate ``prev`` (elif/else style):
        - ``prev`` has no child: TooFewChildren;
        - one child: this node is attached on the opposite boolean edge
          (EdgeValueError if that edge is not a bool);
        - two children: the second must be a NoAction placeholder (else NodeValueError);
          it is removed and this node takes its edge;
        - more children: TooManyChildren.
    When there is an active expression, an ActionNode is attached but not pushed.
    """
    active_node: LogicNode = LGM.active_expression

    if active_node is None:
        return LGM.enter_expression(node=self)

    siblings = active_node.subordinates
    if not siblings:
        active_node.append(expression=self, edge_condition=True)
    else:
        prev = siblings[-1]
        child_count = len(prev.nodes)

        if child_count == 0:
            raise TooFewChildren()
        elif child_count == 1:
            edge_condition = prev.last_edge
            if not isinstance(edge_condition, bool):
                raise EdgeValueError(f'{prev} Edge condition must be a Boolean!')
            prev.append(expression=self, edge_condition=not edge_condition)
        elif child_count == 2:
            from .node import NoAction
            edge_condition, placeholder = prev.last_edge, prev.last_node
            if not isinstance(placeholder, NoAction):
                raise NodeValueError(f'{prev} second child node must be a NoAction node!')
            prev.pop(-1)
            prev.append(expression=self, edge_condition=edge_condition)
        else:
            raise TooManyChildren()

    if not isinstance(self, ActionNode):
        LGM.enter_expression(node=self)
666
+
667
def _on_exit(self):
    """Complete this node's boolean branches, then pop it from the active-expression stack."""
    node_type = type(self)
    node_type.fill_binary_branch(node=self)
    LGM.exit_expression(node=self)
670
+
671
@classmethod
def fill_binary_branch(cls, node: LogicNode, with_action: ActionNode = None):
    """
    Ensure a decision node has both boolean branches when its ``with`` block closes.

    * ActionNodes are leaves and are left untouched.
    * Two children: both branches already exist, so nothing is done.
    * No child: ``with_action`` is attached on the ``False`` edge and a warning is logged.
    * One boolean child: ``with_action`` is attached on the opposite boolean edge.

    Args:
        node (LogicNode): The node to complete.
        with_action (ActionNode, optional): Node placed on the missing branch. When
            omitted, a ``NoAction(auto_connect=False)`` is created (only if needed).

    Raises:
        EdgeValueError: If the single existing edge condition is not a bool.
        TooManyChildren: If ``node`` has more than two children.
    """
    if isinstance(node, ActionNode):
        return

    child_count = len(node.nodes)
    if child_count == 2:
        # Both branches were already provided; previously this wrongly raised TooManyChildren.
        return
    if child_count > 2:
        raise TooManyChildren()

    if with_action is None:
        from .node import NoAction
        with_action = NoAction(auto_connect=False)

    if child_count == 0:
        LOGGER.warning(f"It is rare that {node} has no True branch. Check the <with> statement code block to see if this is intended.")
        node.append(expression=with_action, edge_condition=False)
        return

    edge_condition = node.last_edge
    if not isinstance(edge_condition, bool):
        raise EdgeValueError(f'{node} Edge condition must be a Boolean!')
    node.append(expression=with_action, edge_condition=not edge_condition)
698
+
699
@classmethod
def traverse(cls, node: Self, G=None, node_map: dict[int, Self] = None, parent: Self = None, edge_condition: Any = None):
    """
    Depth-first walk of the tree, recording every node and edge in a networkx graph.

    Graph nodes are keyed by ``id(node)`` and carry ``description=node.repr``. Each edge
    from ``parent`` is labelled with ``str(edge_condition)``. Nodes reachable through
    several paths are visited once per path.

    Args:
        node (LogicNode): Node to visit.
        G (networkx.DiGraph, optional): Graph to extend; a new DiGraph when None.
        node_map (dict, optional): ``id(node) -> node`` map to extend; a new dict when None.
        parent (LogicNode, optional): Node through which ``node`` was reached.
        edge_condition (Any, optional): Condition on the edge from ``parent``.

    Returns:
        tuple: ``(G, node_map)``.
    """
    import networkx as nx

    graph = nx.DiGraph() if G is None else G
    id_map = {} if node_map is None else node_map

    current_id = id(node)
    id_map[current_id] = node
    graph.add_node(current_id, description=node.repr)

    if parent is not None:
        graph.add_edge(id(parent), current_id, label=str(edge_condition))

    for condition, child in node.nodes.items():
        cls.traverse(node=child, G=graph, node_map=id_map, parent=node, edge_condition=condition)

    return graph, id_map
733
+
734
def append(self, expression: LogicNode, edge_condition: Any = None):
    """
    Attach ``expression`` as the child reached through ``edge_condition``.

    Args:
        expression (LogicNode): The child node; its ``parent`` is set to this node.
        edge_condition (Any, optional): The branch condition. ``None`` means the
            unconditional ``NO_CONDITION`` edge.

    Raises:
        ValueError: If a child already exists for ``edge_condition``.
    """
    # ``None`` is mapped to NO_CONDITION, so no "missing condition" case remains
    # (the former second ``is None`` check was unreachable).
    if edge_condition is None:
        edge_condition = NO_CONDITION

    if edge_condition in self.nodes:
        raise ValueError(f"Edge {edge_condition} already exists.")

    self.edges.append(edge_condition)
    self.nodes[edge_condition] = expression
    expression.parent = self
757
+
758
def pop(self, index: int = -1) -> tuple[Any, LogicNode]:
    """
    Detach the child whose edge sits at position ``index`` of ``edges``.

    Returns:
        tuple: ``(edge_condition, child)``. The child's ``parent`` is not modified.
    """
    removed_edge = self.edges.pop(index)
    return removed_edge, self.nodes.pop(removed_edge)
762
+
763
+ def replace(self, original_node: LogicNode, new_node: LogicNode):
764
+ for condition, node in self.nodes.items():
765
+ if node is original_node:
766
+ break
767
+ else:
768
+ raise NodeNotFountError()
769
+
770
+ self.nodes[condition] = new_node
771
+
772
def eval_recursively(self, **kwargs):
    """
    Recursively evaluate the decision tree starting from this node.

    Each node visited is appended to the decision path. The first child whose edge
    condition equals this node's value (or whose edge is ``NO_CONDITION``) is followed.

    Keyword Args:
        path (list, optional): Decision path to extend; a new list when omitted.
        default (Any, optional): Fallback value if no matching condition is found.
            It is forwarded to every descendant.

    Returns:
        tuple: (final_value, decision_path)
            - final_value (Any): The value of the last node reached, or ``default``.
            - decision_path (list): The nodes traversed, from this node to the last one.

    Raises:
        ValueError: If no matching condition is found and no default value is provided.
    """
    path = kwargs.get('path')
    if path is None:
        path = []
    # Previously only the root was ever recorded, so the "leaf" was always the root.
    path.append(self)

    value = self.eval()

    if not self.nodes:
        return value, path

    for condition, child in self.nodes.items():
        if condition == value or condition is NO_CONDITION:
            child_kwargs = {'path': path}
            if 'default' in kwargs:
                child_kwargs['default'] = kwargs['default']
            return child.eval_recursively(**child_kwargs)

    if 'default' in kwargs:
        default = kwargs['default']
        LOGGER.info(f"No matching condition found for value {value} at '{self.repr}', using default {default}.")
        return default, path

    raise ValueError(f"No matching condition found for value {value} at '{self.repr}'.")
808
+
809
def list_labels(self) -> dict[str, list[LogicNode]]:
    """
    Map each logic-group name to the nodes carrying that label, starting from this node.

    The walk is depth-first pre-order: a node comes before its children, and children
    follow edge insertion order. A node reachable through several edges is listed once
    per visit.
    """
    grouped: dict[str, list[LogicNode]] = {}

    def visit(current):
        for group_name in current.labels:
            grouped.setdefault(group_name, []).append(current)
        for child in current.nodes.values():
            visit(child)

    visit(self)
    return grouped
825
+
826
def select_node(self, label: str) -> LogicNode | None:
    """
    Return the root node of logic group ``label`` and check that the group is chained.

    A node of the group is a root when it has no parent, or when its parent does not
    carry ``label``. Nodes are compared by identity: the ``in``/``==`` operators of
    LogicExpression build and evaluate expressions, so they would run node code as a
    side effect.

    Returns:
        The unique root node, or None if no node carries ``label``.

    Raises:
        ValueError: If the group has more than one root.
    """
    labels = self.list_labels()
    if label not in labels:
        return None

    root = None
    seen_ids = set()
    for node in labels[label]:
        # The same node can be listed several times when reachable via several edges.
        if id(node) in seen_ids:
            continue
        seen_ids.add(id(node))

        parent = node.parent
        if parent is not None and label in parent.labels:
            continue

        if root is not None:
            raise ValueError(f"Logic group '{label}' has multiple roots.")
        root = node

    return root
844
+
845
+ def to_html(self, with_group=True, dry_run=True, filename="decision_tree.html", **kwargs):
846
+ """
847
+ Visualizes the decision tree using PyVis.
848
+ If dry_run=True, shows structure without highlighting active path.
849
+ If dry_run=False, evaluates the tree and highlights the decision path.
850
+ If with_group=True, uses grouped logic view.
851
+ """
852
+ from pyvis.network import Network
853
+
854
+ G, node_map = self.traverse(self)
855
+ # Highlight path if not in dry run
856
+ activated_path = []
857
+ if not dry_run:
858
+ try:
859
+ _, path = self.eval_recursively()
860
+ activated_path = [id(node) for node in path]
861
+ except Exception:
862
+ activated_path.clear()
863
+ dry_run = True
864
+ LOGGER.error(f"Failed to evaluate decision tree.\n{traceback.format_exc()}")
865
+
866
+ # Visualization using PyVis
867
+ net = Network(
868
+ height=kwargs.get('height', "750px"),
869
+ width=kwargs.get('width', "100%"),
870
+ directed=True,
871
+ notebook=False,
872
+ neighborhood_highlight=True
873
+ )
874
+ default_color = kwargs.get('default_color', "lightblue")
875
+ highlight_color = kwargs.get('highlight_color', "lightgreen")
876
+ activated_color = kwargs.get('selected_color', "lightyellow")
877
+ dimmed_color = kwargs.get('dimmed_color', "#e0e0e0")
878
+ logic_shape = kwargs.get('logic_shape', "box")
879
+ action_shape = kwargs.get('action_shape', "ellipse")
880
+
881
+ original_colors = {}
882
+
883
+ # Add nodes with group information
884
+ for node_id, node in node_map.items():
885
+ label = node.repr
886
+ title = f"Node: {node.repr}"
887
+
888
+ # Track the original color for each node
889
+ node_color = activated_color if node_id in activated_path else default_color
890
+ original_colors[node_id] = node_color
891
+
892
+ if with_group:
893
+ net.add_node(node_id, label=label, title=title, color=node_color, shape=action_shape if isinstance(node, ActionNode) else logic_shape, groups=node.labels)
894
+ else:
895
+ net.add_node(node_id, label=label, title=title, color=node_color, shape=action_shape if isinstance(node, ActionNode) else logic_shape)
896
+
897
+ # Add edges
898
+ for source, target, data in G.edges(data=True):
899
+ edge_label = data.get("label", "")
900
+ edge_color = "black" if dry_run else ("green" if source in activated_path and target in activated_path else "black")
901
+ net.add_edge(source, target, label=edge_label, title=edge_label, color=edge_color, arrows="to")
902
+
903
+ # Configure layout and options
904
+ options = {
905
+ "layout": {
906
+ "hierarchical": {
907
+ "enabled": True,
908
+ "direction": "UD", # UD = Up-Down (root at top, leaves at bottom)
909
+ "sortMethod": "directed",
910
+ "nodeSpacing": 150,
911
+ "levelSeparation": 200
912
+ }
913
+ },
914
+ "physics": {
915
+ "hierarchicalRepulsion": {
916
+ "centralGravity": 0.0,
917
+ "springLength": 200,
918
+ "springConstant": 0.01,
919
+ "nodeDistance": 200,
920
+ "damping": 0.09
921
+ },
922
+ "minVelocity": 0.75,
923
+ "solver": "hierarchicalRepulsion"
924
+ },
925
+ "nodes": {
926
+ "shape": "box",
927
+ "shapeProperties": {"borderRadius": 10},
928
+ "font": {"size": 14}
929
+ },
930
+ "edges": {
931
+ "color": "black",
932
+ "smooth": True
933
+ }
934
+ }
935
+
936
+ net.set_options(json.dumps(options))
937
+
938
+ # Generate the base HTML
939
+ html = net.generate_html()
940
+
941
+ # Inject custom controls and JavaScript
942
+ buttons_html = """
943
+ <div style="position: absolute; top: 10px; left: 10px; z-index: 1000;
944
+ background: rgba(255, 255, 255, 0.9); padding: 12px;
945
+ border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.3);
946
+ font-family: Arial, sans-serif;">
947
+
948
+ <h4 style="margin: 0 0 10px; font-size: 16px; text-align: center; color: #333;">
949
+ Decision Tree Controls
950
+ </h4>
951
+
952
+ <button onclick="resetColors()" class="control-btn">Reset</button>
953
+ """
954
+
955
+ if with_group:
956
+ groups = {group for node in node_map.values() for group in node.labels}
957
+ for group in sorted(groups):
958
+ buttons_html += f'<button onclick="highlightGroup(\'{group}\')" class="control-btn">{group}</button>'
959
+ buttons_html += "</div>"
960
+
961
+ js_code = f"""
962
+ <script>
963
+ function resetColors() {{
964
+ // Reset all nodes to their original color and opacity
965
+ nodes.forEach(function(node) {{
966
+ nodes.update([{{
967
+ id: node.id,
968
+ color: originalColors[node.id], // Reset to original color
969
+ opacity: 1
970
+ }}]);
971
+ }});
972
+
973
+ // Reset all edges to default color and opacity
974
+ edges.forEach(function(edge) {{
975
+ edges.update([{{
976
+ id: edge.id,
977
+ color: "black",
978
+ opacity: 1
979
+ }}]);
980
+ }});
981
+ }}
982
+
983
+ function highlightGroup(group) {{
984
+ // Dim all nodes and edges first
985
+ nodes.update([...nodes.getIds().map(id => ({{
986
+ id: id,
987
+ color: "{dimmed_color}",
988
+ opacity: 0.3
989
+ }}))]);
990
+
991
+ edges.update([...edges.getIds().map(id => ({{
992
+ id: id,
993
+ color: "gray",
994
+ opacity: 0.2
995
+ }}))]);
996
+
997
+ // Highlight nodes in the selected group
998
+ const groupNodes = nodes.get({{
999
+ filter: node => node.groups.includes(group)
1000
+ }});
1001
+
1002
+ nodes.update([...groupNodes.map(node => ({{
1003
+ id: node.id,
1004
+ color: "{highlight_color}",
1005
+ opacity: 1
1006
+ }}))]);
1007
+
1008
+ // Highlight connected edges
1009
+ const connectedEdges = edges.get({{
1010
+ filter: edge =>
1011
+ groupNodes.some(n => n.id === edge.from) ||
1012
+ groupNodes.some(n => n.id === edge.to)
1013
+ }});
1014
+
1015
+ edges.update([...connectedEdges.map(edge => ({{
1016
+ id: edge.id,
1017
+ color: "black",
1018
+ opacity: 1
1019
+ }}))]);
1020
+ }}
1021
+
1022
+ // Store the original node colors for reset functionality
1023
+ const originalColors = {json.dumps(original_colors)};
1024
+ </script>
1025
+ """
1026
+
1027
+ # Inject better styles for buttons
1028
+ css_styles = """
1029
+ <style>
1030
+ .control-btn {
1031
+ background-color: #007BFF;
1032
+ color: white;
1033
+ border: none;
1034
+ padding: 8px 14px;
1035
+ margin: 5px;
1036
+ font-size: 14px;
1037
+ border-radius: 5px;
1038
+ cursor: pointer;
1039
+ transition: background 0.3s ease;
1040
+ }
1041
+
1042
+ .control-btn:hover {
1043
+ background-color: #0056b3;
1044
+ }
1045
+
1046
+ .control-btn:active {
1047
+ background-color: #003f7f;
1048
+ }
1049
+ </style>
1050
+ """
1051
+
1052
+ # Insert custom elements into the HTML
1053
+ html = html.replace("</head>", f"{css_styles}</head>")
1054
+ html = html.replace("</body>", f"{buttons_html}{js_code}</body>")
1055
+
1056
+ # Save the modified HTML
1057
+ with open(filename, "w") as f:
1058
+ f.write(html)
1059
+
1060
+ LOGGER.info(f"Decision tree saved to {filename}")
1061
+
1062
+ @property
1063
+ def children(self) -> Iterable[tuple[Any, LogicNode]]:
1064
+ """Returns an iterable of (edge, node) pairs."""
1065
+ return iter(self.nodes.items())
1066
+
1067
+ @property
1068
+ def leaves(self) -> Iterable[LogicNode]:
1069
+ """Recursively finds and returns all leaf nodes (nodes without children)."""
1070
+ if not self.nodes: # If no children, this node is a leaf
1071
+ yield self
1072
+ else:
1073
+ for _, child in self.nodes.items(): # Recursively get leaves from children
1074
+ yield from child.leaves
1075
+
1076
+ @property
1077
+ def last_edge(self) -> Any:
1078
+ return self.edges[-1]
1079
+
1080
+ @property
1081
+ def last_node(self) -> LogicNode:
1082
+ return self.nodes[self.last_edge]
1083
+
1084
+ @property
1085
+ def last_leaf(self) -> LogicNode:
1086
+ if not self.nodes:
1087
+ return self
1088
+ return self.last_node.last_leaf
1089
+
1090
+ @property
1091
+ def last_leaf_expression(self) -> LogicNode:
1092
+ last_leaf = self.last_leaf
1093
+ if isinstance(last_leaf, ActionNode):
1094
+ return last_leaf.parent
1095
+ return last_leaf
1096
+
1097
+
1098
+ class ActionNode(LogicNode):
1099
+ def __init__(
1100
+ self,
1101
+ action: float | int | bool | None | Exception | Callable[[], Any],
1102
+ dtype: type = None,
1103
+ repr: str = None,
1104
+ auto_connect: bool = True
1105
+ ):
1106
+ """
1107
+ Initialize the LogicExpression.
1108
+
1109
+ Args:
1110
+ action (Union[Any, Callable[[], Any]]): The action to execute.
1111
+ dtype (type, optional): The expected type of the evaluated value (float, int, or bool).
1112
+ repr (str, optional): A string representation of the expression.
1113
+ """
1114
+ super().__init__(expression=True, dtype=dtype, repr=repr)
1115
+ self.action = action
1116
+
1117
+ if auto_connect:
1118
+ super()._on_enter()
1119
+
1120
+ def _on_enter(self):
1121
+ LOGGER.warning(f'{self.__class__.__name__} should not use with claude')
1122
+
1123
+ def _on_exit(self):
1124
+ pass
1125
+
1126
+ def eval_recursively(self, path=None):
1127
+ """
1128
+ Evaluates the decision tree from this node based on the given state.
1129
+ Returns the final action and records the decision path.
1130
+ """
1131
+ if path is None:
1132
+ path = []
1133
+ path.append(self)
1134
+
1135
+ value = self.eval()
1136
+
1137
+ if self.action is not None:
1138
+ self.action()
1139
+
1140
+ for condition, child in self.nodes.items():
1141
+ if condition == value or condition is NO_CONDITION:
1142
+ return child.eval_recursively(path=path)
1143
+
1144
+ return value, path
1145
+
1146
+ def append(self, expression: Self, edge_condition: Any = None):
1147
+ raise TooManyChildren("Cannot append child to an ActionNode!")