PyDecisionGraph 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of PyDecisionGraph might be problematic. Click here for more details.
- decision_tree/__init__.py +36 -0
- decision_tree/abc.py +1147 -0
- decision_tree/collection.py +93 -0
- decision_tree/exc.py +38 -0
- decision_tree/expression.py +478 -0
- decision_tree/logic_group.py +307 -0
- decision_tree/node.py +180 -0
- pydecisiongraph-0.1.0.dist-info/LICENSE +373 -0
- pydecisiongraph-0.1.0.dist-info/METADATA +21 -0
- pydecisiongraph-0.1.0.dist-info/RECORD +12 -0
- pydecisiongraph-0.1.0.dist-info/WHEEL +5 -0
- pydecisiongraph-0.1.0.dist-info/top_level.txt +1 -0
decision_tree/abc.py
ADDED
|
@@ -0,0 +1,1147 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import json
|
|
5
|
+
import operator
|
|
6
|
+
import sys
|
|
7
|
+
import traceback
|
|
8
|
+
from collections.abc import Callable, Iterable
|
|
9
|
+
from typing import Any, Self, final
|
|
10
|
+
|
|
11
|
+
from . import LOGGER
|
|
12
|
+
from .exc import TooFewChildren, TooManyChildren, EdgeValueError, NodeValueError, NodeNotFountError
|
|
13
|
+
|
|
14
|
+
# Public names exported by `from decision_tree.abc import *`.
# NOTE(review): ActionNode is not defined in the reviewed portion of this module;
# presumably it is defined further down — confirm.
__all__ = [
    'LGM',
    'LogicGroup',
    'SkipContextsBlock',
    'LogicExpression',
    'ExpressionCollection',
    'LogicNode',
    'ActionNode',
    'ELSE_CONDITION',
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Singleton(type):
    """
    Metaclass giving each class that uses it exactly one shared instance.

    The first call constructs the instance with the given arguments; every later call
    returns that same object and ignores its arguments.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]

        instance = super(Singleton, cls).__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ConditionElse(object):
    """
    Sentinel edge condition for the "else" / unconditional branch of a decision node.

    A single shared instance is used as an edge key in ``LogicNode.nodes`` and is
    compared by identity. Its string form is empty, so such edges get a blank label.
    """

    def __str__(self):
        return ''


# One shared sentinel instance, exposed under two names.
NO_CONDITION = ConditionElse()
ELSE_CONDITION = NO_CONDITION
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class LogicGroupManager(metaclass=Singleton):
    """
    A singleton that caches LogicGroup instances by name and tracks what is active.

    Besides the name -> LogicGroup cache, it keeps stacks of the currently active
    LogicGroups and LogicNodes (innermost last), plus the "exit" action nodes registered
    by ``LogicGroup.break_`` in inspection mode. Those exit nodes are connected to the
    next entered expression once the group they break from has exited.
    """

    def __init__(self):
        # Cached LogicGroup instances, keyed by name.
        self._cache: dict[str, LogicGroup] = {}

        # Stacks of the currently active LogicGroups / LogicNodes (innermost last).
        self._active_groups: list[LogicGroup] = []
        self._active_nodes: list[LogicNode] = []
        # Action nodes (usually NoAction() nodes) marked as an early exit of a logic group.
        self._exit_nodes: list[ActionNode] = []
        # Exit nodes whose logic group has exited; they are connected to the next
        # expression passed to `enter_expression`.
        self._pending_connection_nodes: list[ActionNode] = []
        # While True, LogicNode bodies are always entered and `LogicGroup.break_` does not raise.
        self.inspection_mode = False

    def __call__(self, name: str, cls: type[LogicGroup], **kwargs) -> LogicGroup:
        """
        Retrieve a cached LogicGroup instance or create a new one if not cached.

        :param name: The name of the LogicGroup.
        :param cls: The class of the LogicGroup to create if not cached.
        :param kwargs: Additional arguments for LogicGroup initialization.
        :return: A LogicGroup instance.
        """
        if name in self._cache:
            return self._cache[name]

        # Create a new instance and add it to the cache
        logic_group = cls(name=name, **kwargs)
        self._cache[name] = logic_group
        return logic_group

    def __contains__(self, name: str) -> bool:
        return name in self._cache

    def __getitem__(self, name: str) -> LogicGroup:
        return self._cache[name]

    def __setitem__(self, name: str, value: LogicGroup):
        self._cache[name] = value

    def enter_logic_group(self, logic_group: LogicGroup):
        """
        Push a LogicGroup onto the active stack when its context is entered.

        :param logic_group: The LogicGroup entering the context.
        """
        self._active_groups.append(logic_group)

    def exit_logic_group(self, logic_group: LogicGroup):
        """
        Pop a LogicGroup from the active stack when its context exits.

        Exit nodes that break from this group are moved to the pending-connection list.
        They are also removed from the exit-node list, so re-entering and exiting the
        same (cached) group later does not schedule already-connected nodes again.

        :param logic_group: The LogicGroup exiting the context.
        :raises ValueError: if `logic_group` is not the innermost active group.
        """
        if not self._active_groups or self._active_groups[-1] is not logic_group:
            raise ValueError("The LogicGroup is not currently active.")

        self._active_groups.pop(-1)

        still_waiting = []
        for node in self._exit_nodes:
            if node.break_from is logic_group:
                self._pending_connection_nodes.append(node)
            else:
                still_waiting.append(node)
        self._exit_nodes[:] = still_waiting

    def enter_expression(self, node: LogicNode):
        """
        Push `node` as the active expression.

        First, every pending exit node is connected to `node`. A NoAction exit node is
        replaced by `node` in its parent; any other exit node gets `node` as its
        ``NO_CONDITION`` child. If another expression is currently active, `node` is
        also recorded as one of its subordinates.

        :param node: The LogicNode being entered.
        :raises NodeNotFountError: if a NoAction exit node has no parent.
        """
        # Check the node being entered (the manager itself is never an ActionNode).
        if isinstance(node, ActionNode):
            LOGGER.error('Enter the with code block of an ActionNode rejected. Check is this intentional?')

        if self._pending_connection_nodes:
            from .node import NoAction

            for exit_node in self._pending_connection_nodes:
                if isinstance(exit_node, NoAction):
                    if (parent := exit_node.parent) is None:
                        raise NodeNotFountError('ActionNode must have a parent node!')
                    parent.replace(original_node=exit_node, new_node=node)
                else:
                    exit_node.edges.append(NO_CONDITION)
                    exit_node.nodes[NO_CONDITION] = node

            self._pending_connection_nodes.clear()

        if (active_node := self.active_expression) is not None:
            active_node.subordinates.append(node)

        self._active_nodes.append(node)

    def exit_expression(self, node: LogicNode):
        """
        Pop `node` from the active-expression stack.

        :raises ValueError: if `node` is not the innermost active expression.
        """
        if not self._active_nodes or self._active_nodes[-1] is not node:
            raise ValueError(f"The {node} is not currently active.")

        self._active_nodes.pop(-1)

    def clear(self):
        """
        Clear the LogicGroup cache and reset all tracking state.

        This also drops exit nodes and pending connections, so they cannot leak into the
        next tree being built. `inspection_mode` is left unchanged.
        """
        self._cache.clear()
        self._active_groups.clear()
        self._active_nodes.clear()
        self._exit_nodes.clear()
        self._pending_connection_nodes.clear()

    @property
    def active_logic_group(self) -> LogicGroup | None:
        """The innermost active LogicGroup, or None."""
        return self._active_groups[-1] if self._active_groups else None

    @property
    def active_expression(self) -> LogicNode | None:
        """The innermost active LogicNode, or None."""
        return self._active_nodes[-1] if self._active_nodes else None


# The module-wide manager instance (the Singleton metaclass makes it unique).
LGM = LogicGroupManager()
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class LogicGroupMeta(type):
    """
    Metaclass for LogicGroup that registers classes and caches instances by name.

    * Every class it creates is recorded in a shared registry (class name -> class),
      exposed as the ``registry`` property; ``LogicGroup.sub_logics`` uses it to rebuild
      sub-groups from their stored type name.
    * Instantiation is name-cached: if ``LGM`` already holds an instance under `name`,
      that instance is returned instead of constructing a new one.
    """
    _registry_ = {}

    def __new__(mcs, name, bases, dct):
        created = super().__new__(mcs, name, bases, dct)
        mcs._registry_[name] = created
        return created

    def __call__(cls, name, *args, **kwargs):
        if name is None:
            raise ValueError("LogicGroup instances must have a 'name'.")

        # Reuse the cached instance for this name.
        # NOTE(review): the cached instance is returned even if it is of another class.
        if name in LGM:
            return LGM[name]

        # NOTE(review): `name` is bound by keyword, so extra *positional* arguments fill the
        # parameters starting at `name` (e.g. `LogicGroup('x', parent)` raises TypeError
        # "multiple values for argument 'name'"); pass everything after `name` by keyword.
        instance = super().__call__(*args, name=name, **kwargs)
        LGM[name] = instance
        return instance

    @property
    def registry(cls):
        return cls._registry_
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class LogicGroup(object, metaclass=LogicGroupMeta):
    """
    A minimal context manager that saves/restores its state via the `.contexts` dict.

    A logic group keeps no state of its own.
    * A root group (no `parent`) uses the `contexts` dict it is given, or a new one.
    * A child group keeps a record in ``parent._sub_logics[name]``. The record holds its
      type name, its contexts and its own sub-logic records, so a later group created
      with the same name under the same parent recovers the same contexts.

    Outside inspection mode, calling ``group.break_()`` inside ``with group:`` leaves
    the block early.
    """

    def __init__(self, name: str, parent: Self | None = None, contexts: dict[str, Any] | None = None):
        self.name = name
        self.parent = parent
        # A per-instance exception class, so `break_` can target exactly this group.
        self.Break = type(f"{self.__class__.__name__}Break", (Exception,), {})

        if parent is None:
            # A root logic group: nothing to recover.
            info_dict = {}
            if contexts is None:
                contexts = {}
        else:
            # A child logic group: try to recover its record from the parent.
            info_dict = parent._sub_logics.setdefault(name, {})
            logic_type = self.__class__.__name__
            # Store the type outside of any `assert`. Under `python -O` the assert, and
            # the setdefault inside it, used to be skipped, leaving records without the
            # 'logic_type' key that `sub_logics` requires. The check is raised explicitly
            # for the same reason; AssertionError is kept for backward compatibility.
            registered_type = info_dict.setdefault('logic_type', logic_type)
            if registered_type != logic_type:
                raise AssertionError(f"Logic {registered_type} already registered in {parent.name}!")
            contexts = info_dict.setdefault('contexts', {} if contexts is None else contexts)

        self.contexts: dict[str, Any] = contexts
        self._sub_logics = info_dict.setdefault('sub_logics', {})

    def __repr__(self):
        return f'<{self.__class__.__name__}>({self.name!r})'

    def __enter__(self) -> Self:
        LGM.enter_logic_group(self)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        LGM.exit_logic_group(self)

        if exc_type is None:
            return

        # Swallow this group's own Break exception (raised by `break_`).
        if exc_type is self.Break:
            return True

        # Explicitly re-raise other exceptions
        return False

    def break_(self, scope: LogicGroup | None = None):
        """
        Leave the `with` block of `scope` (default: this group).

        Outside inspection mode this raises ``scope.Break``, which ``scope.__exit__``
        swallows. In inspection mode nothing is raised. Instead, the last leaf of the
        active expression (which must be an ActionNode) is marked with
        ``break_from = scope`` and registered in ``LGM._exit_nodes``, so it can be
        connected once `scope` has exited.

        Raises:
            TooFewChildren: in inspection mode, if the active expression has no children.
            NodeValueError: in inspection mode, if the last leaf is not an ActionNode.
        """
        if scope is None:
            scope = self

        # will not break from scope in inspection mode
        if LGM.inspection_mode:
            active_node: LogicNode | None = LGM.active_expression

            if active_node is not None:
                if not active_node.nodes:
                    raise TooFewChildren()

                last_node = active_node.last_leaf
                # Raised explicitly (not via `assert`) so the check also runs under `python -O`.
                if not isinstance(last_node, ActionNode):
                    raise NodeValueError('An ActionNode is required before breaking a LogicGroup.')
                last_node.break_from = scope
                LGM._exit_nodes.append(last_node)
            return

        raise scope.Break()

    @property
    def sub_logics(self) -> dict[str, Self]:
        """
        Re-instantiate every sub-logic recorded for this group, keyed by name.

        Constructor arguments are inferred from each sub-logic's ``__init__`` signature:
        * values stored in the record are used first;
        * otherwise `name`, `parent` (this group) and `contexts` (a new dict, with a
          warning) are filled in;
        * ``*args``/``**kwargs`` parameters are skipped.

        Raises:
            ValueError: if a recorded class name is not in the registry.
            TypeError: if a required parameter cannot be inferred.
        """
        sub_logic_instances = {}
        for logic_name, info in self._sub_logics.items():
            logic_type = info["logic_type"]

            # Dynamically retrieve the class using meta registry
            logic_class = self.__class__.registry.get(logic_type)
            if logic_class is None:
                raise ValueError(f"Class {logic_type} not found in registry.")

            init_params = inspect.signature(logic_class.__init__).parameters

            # Prepare arguments for the sub-logic initialization
            init_args = {}
            for param_name, param in init_params.items():
                if param_name == "self":
                    continue

                # Variadic parameters are never required and cannot be inferred; treating
                # them as missing used to raise TypeError for e.g. `**kwargs`.
                if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                    continue

                if param_name in info:
                    init_args[param_name] = info[param_name]
                elif param_name == "name":
                    init_args["name"] = logic_name
                elif param_name == "parent":
                    init_args["parent"] = self
                elif param_name == "contexts":
                    LOGGER.warning(f"Contexts dict not found for {logic_name}!")
                    init_args["contexts"] = {}
                elif param.default is inspect.Parameter.empty:
                    # Missing a required argument that cannot be inferred
                    raise TypeError(f"Missing required argument '{param_name}' for {logic_type}.")

            sub_logic_instances[logic_name] = logic_class(**init_args)

        return sub_logic_instances
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
class SkipContextsBlock(object):
    """
    Base class for context managers that can skip the body of their ``with`` block.

    Subclasses override `_entry_check` and may override the `_on_enter` / `_on_exit` hooks.
    * Truthy check: `_on_enter` runs and the block executes normally. `_on_exit` runs
      when the block finishes, or when it raises any exception other than `_Skip`.
    * Falsy check: the body is skipped using the tracing machinery. A trace function
      is installed on the caller's frame that raises `_Skip` on the next trace event.
      `__exit__` swallows `_Skip` and restores the previously active trace function.
      In this case `__enter__` returns None, so ``with ... as x`` binds None.
    """

    class _Skip(Exception):
        # Internal signal used to abort the body of a skipped `with` block.
        pass

    def _entry_check(self) -> Any:
        """
        Decide whether the ``with`` block should run.

        A truthy return value means the block runs (is NOT skipped); a falsy value means
        the block body is skipped. This base implementation returns None, i.e. skips.
        """
        pass

    @final
    def __enter__(self):
        # Truthy entry check: run the block normally.
        if self._entry_check():  # Check if the expression evaluates to True
            self._on_enter()
            return self

        # Skip path. Remember the current trace function so `__exit__` can restore it.
        self._original_trace = self.get_trace()
        # The frame executing the `with` statement (the caller of `__enter__`).
        frame = inspect.currentframe().f_back
        # Tracing must be active (some global trace function set) for the per-frame
        # `f_trace` hook below to be invoked; `empty_trace` does nothing itself.
        sys.settrace(self.empty_trace)
        # On the next trace event in the caller's frame, `err_trace` raises `_Skip`,
        # which aborts the block body.
        frame.f_trace = self.err_trace

    @final
    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Normal completion of an entered block.
        if exc_type is None:
            self._on_exit()
            return

        # The block was skipped on purpose: restore the trace function saved in
        # `__enter__` and swallow the skip signal. `_on_exit` is not called because
        # `_on_enter` was not called either.
        if issubclass(exc_type, self._Skip):
            if hasattr(self, '_original_trace'):
                sys.settrace(self._original_trace)  # Restore the original trace
            else:
                raise Exception('original_trace not found! Debugger broken! This should never happened.')
            return True

        self._on_exit()
        # Propagate any other exception raised in the block
        return False

    def _on_enter(self):
        """Hook called when the block is entered (entry check was truthy). No-op by default."""
        pass

    def _on_exit(self):
        """Hook called when an entered block exits, normally or with a non-skip exception. No-op by default."""
        pass

    @staticmethod
    def get_trace():
        """
        Safely retrieve the current trace function, prioritizing the PyDev debugger's.

        If `pydevd` can be imported and reports an active global debugger, its
        ``trace_dispatch`` is returned; otherwise the result of ``sys.gettrace()``.
        """
        try:
            # Check if PyDev debugger is active
            # noinspection PyUnresolvedReferences
            import pydevd
            debugger = pydevd.GetGlobalDebugger()
            if debugger is not None:
                return debugger.trace_dispatch  # Use PyDev's trace function
        except ImportError:
            pass  # PyDev debugger is not installed or active

        # Fall back to the standard trace function
        return sys.gettrace()

    @classmethod
    def empty_trace(cls, *args, **kwargs) -> None:
        # Global trace function that ignores every event (it only keeps tracing enabled).
        pass

    @classmethod
    def err_trace(cls, frame, event, arg):
        # Per-frame trace function: abort the skipped block on its first trace event.
        raise cls._Skip("Expression evaluated to be False, cannot enter the block.")
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class LogicExpression(SkipContextsBlock):
    """
    A logical or mathematical expression with deferred evaluation.

    The wrapped `expression` is evaluated by :meth:`eval`, and therefore also by
    ``bool()`` and on entering ``with expr:``. The arithmetic, comparison, ``&``, ``|``,
    ``==`` and ``!=`` operators build new LogicExpressions that evaluate their operands
    only when the new expression itself is evaluated.
    """

    def __init__(
            self,
            expression: float | int | bool | Exception | Callable[[], Any],
            dtype: type = None,
            repr: str = None,
    ):
        """
        Initialize the LogicExpression.

        Args:
            expression (Union[Any, Callable[[], Any]]): A callable, a static value, or an
                Exception instance to raise on evaluation.
            dtype (type, optional): The expected type of the evaluated value (float, int, or bool).
            repr (str, optional): A string representation of the expression.
                Defaults to ``str(expression)``.
        """
        self.expression = expression
        self.dtype = dtype
        self.repr = repr if repr is not None else str(expression)

        super().__init__()

    def _entry_check(self) -> Any:
        # Enter the `with` body only if the expression evaluates truthy.
        return self.eval()

    def eval(self, enforce_dtype: bool = False) -> Any:
        """
        Evaluate the expression.

        Static values (float/int/bool/str) are returned as-is. Callables are called with
        no arguments. An Exception instance is raised.

        Args:
            enforce_dtype: If True (and `dtype` is set and not ``Any``), convert the
                result with ``dtype(value)``. Otherwise a type mismatch only logs a warning.

        Raises:
            TypeError: If the expression type is unsupported.
        """
        if isinstance(self.expression, (float, int, bool, str)):
            value = self.expression
        elif callable(self.expression):
            value = self.expression()
        elif isinstance(self.expression, Exception):
            raise self.expression
        else:
            raise TypeError(f"Unsupported expression type: {type(self.expression)}.")

        if self.dtype is Any or self.dtype is None:
            pass  # No type enforcement
        elif enforce_dtype:
            value = self.dtype(value)
        elif not isinstance(value, self.dtype):
            LOGGER.warning(f"Evaluated value {value} does not match dtype {self.dtype.__name__}.")

        return value

    # Logical operators
    @classmethod
    def cast(cls, value: int | float | bool | Exception | Self, dtype: type = None) -> Self:
        """
        Convert a static value, callable, or error into a LogicExpression.

        Args:
            value (Union[int, float, bool, LogicExpression, Callable, Exception]):
                The value to convert. Can be:
                - A static value (int, float, or bool).
                - A callable returning a value.
                - A pre-existing LogicExpression (returned unchanged).
                - An Exception to raise during evaluation.
            dtype (type, optional): The expected type of the resulting LogicExpression.
                If None, it is inferred from a static value, or ``Any`` otherwise.

        Returns:
            LogicExpression: The resulting LogicExpression.

        Raises:
            TypeError: If the value type is unsupported.
        """
        if isinstance(value, LogicExpression):
            return value
        if isinstance(value, (int, float, bool)):
            return LogicExpression(
                expression=value,
                dtype=dtype or type(value),
                repr=str(value)
            )
        if callable(value):
            return LogicExpression(
                expression=value,
                dtype=dtype or Any,
                repr=f"Eval({value})"
            )
        if isinstance(value, Exception):
            return LogicExpression(
                expression=value,
                dtype=dtype or Any,
                repr=f"Raises({type(value).__name__}: {value})"
            )
        raise TypeError(f"Unsupported type for LogicExpression conversion: {type(value)}.")

    def __bool__(self) -> bool:
        return bool(self.eval())

    def __and__(self, other: Self | bool) -> Self:
        other_expr = self.cast(value=other, dtype=bool)
        new_expr = LogicExpression(
            expression=lambda: self.eval() and other_expr.eval(),
            dtype=bool,
            repr=f"({self.repr} and {other_expr.repr})"
        )
        return new_expr

    def _deferred_compare(self, other: Any, op: Callable[[Any, Any], Any], operator_str: str) -> Self:
        """
        Build a deferred equality-style comparison ``self <op> other``.

        `other` may be a LogicExpression (evaluated lazily, after `self`) or any plain
        value (e.g. a str, which `cast` does not support).
        """
        other_is_expr = isinstance(other, LogicExpression)

        def _compare():
            lhs = self.eval()
            rhs = other.eval() if other_is_expr else other
            return op(lhs, rhs)

        other_repr = other.repr if other_is_expr else repr(other)
        return LogicExpression(
            expression=_compare,
            dtype=bool,
            repr=f"({self.repr} {operator_str} {other_repr})"
        )

    def __eq__(self, other: int | float | bool | str | Self) -> Self:
        # Deferred like every other operator. A LogicExpression operand used to be
        # evaluated eagerly, freezing its value at build time.
        return self._deferred_compare(other, operator.eq, "==")

    def __ne__(self, other: int | float | bool | str | Self) -> Self:
        # Without this, Python's default `__ne__` inverted `__eq__` via `__bool__`,
        # evaluating eagerly and returning a plain bool.
        return self._deferred_compare(other, operator.ne, "!=")

    def __or__(self, other: Self | bool) -> Self:
        other_expr = self.cast(value=other, dtype=bool)
        new_expr = LogicExpression(
            expression=lambda: self.eval() or other_expr.eval(),
            dtype=bool,
            repr=f"({self.repr} or {other_expr.repr})"
        )
        return new_expr

    # Math operators
    @classmethod
    def _math_op(cls, self: Self, other: int | float | Self, op: Callable, operator_str: str, dtype: type = None) -> Self:
        """Build the deferred expression ``op(self, other)``; `dtype` defaults to ``self.dtype``."""
        other_expr = LogicExpression.cast(other)

        if dtype is None:
            dtype = self.dtype

        new_expr = LogicExpression(
            expression=lambda: op(self.eval(), other_expr.eval()),
            dtype=dtype,
            repr=f"({self.repr} {operator_str} {other_expr.repr})",
        )
        return new_expr

    def __add__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.add, operator_str="+")

    def __sub__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.sub, operator_str="-")

    def __mul__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.mul, operator_str="*")

    def __truediv__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.truediv, operator_str="/")

    def __floordiv__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.floordiv, operator_str="//")

    def __pow__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.pow, operator_str="**")

    # Comparison operators; __eq__ / __ne__ are special and not implemented as math operators.
    def __lt__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.lt, operator_str="<", dtype=bool)

    def __le__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.le, operator_str="<=", dtype=bool)

    def __gt__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.gt, operator_str=">", dtype=bool)

    def __ge__(self, other: int | float | bool | Self) -> Self:
        return self._math_op(self=self, other=other, op=operator.ge, operator_str=">=", dtype=bool)

    def __repr__(self) -> str:
        return f"LogicExpression(dtype={'Any' if self.dtype is None else self.dtype.__name__}, repr={self.repr})"
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
class ExpressionCollection(LogicGroup):
    """
    A LogicGroup that keeps a piece of data in its contexts under the key 'data'.

    Args:
        data: The data to store. If the group's recovered contexts already hold 'data',
            that stored value wins.
        name: The group's base name. Inside a parent group, the full name is
            ``'<parent name>.<name>'``.
        repr (str, optional): If given, used verbatim as the group name.

    Keyword Args:
        logic_group: The parent LogicGroup. It defaults to the currently active group
            (``LGM.active_logic_group``); pass ``None`` explicitly for a root group.
            Other keyword arguments are ignored.
    """

    def __init__(self, data: Any, name: str, repr: str = None, **kwargs):
        # Prefer the explicitly given parent, otherwise the currently active group. These
        # two branches used to be swapped, so an explicit `logic_group` was ignored and
        # the active group was never picked up.
        if 'logic_group' in kwargs:
            logic_group = kwargs['logic_group']
        else:
            logic_group = LGM.active_logic_group

        if repr is not None:
            group_name = repr
        elif logic_group is None:
            group_name = name
        else:
            group_name = f'{logic_group.name}.{name}'

        super().__init__(
            name=group_name,
            parent=logic_group
        )

        # Keep previously stored data (recovered through the contexts), else store `data`.
        self.data = self.contexts.setdefault('data', data)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
class LogicNode(LogicExpression):
|
|
560
|
+
def __init__(
        self,
        expression: float | int | bool | Exception | Callable[[], Any],
        dtype: type = None,
        repr: str = None,
):
    """
    Initialize a LogicNode: a LogicExpression that is also a node of a decision graph.

    Args:
        expression (Union[Any, Callable[[], Any]]): A callable, a static value, or an
            Exception instance to raise on evaluation.
        dtype (type, optional): The expected type of the evaluated value (float, int, or bool).
        repr (str, optional): A string representation of the expression.
    """
    super().__init__(expression=expression, dtype=dtype, repr=repr)

    # Outgoing edges (branch conditions, in insertion order) and the child behind each one.
    self.edges = []
    self.nodes: dict[Any, LogicNode] = {}
    self.parent: LogicNode | None = None
    # Nodes entered (via `with`) while this node was the active expression.
    self.subordinates = []
    # Names of the LogicGroups that were active when this node was created.
    self.labels = [group.name for group in LGM._active_groups]
|
|
581
|
+
|
|
582
|
+
def _entry_check(self) -> Any:
    """
    Decide whether the body of ``with node:`` is entered.

    While ``LGM.inspection_mode`` is on, this always returns True, so the body is always
    entered. Otherwise the node's own evaluation result decides.

    Returns:
        Any: True in inspection mode, else the evaluation result.
    """
    return True if LGM.inspection_mode else self.eval()
|
|
593
|
+
|
|
594
|
+
def __rshift__(self, expression: Self):
    """
    ``node >> child``: attach `child` via :meth:`append` (default edge condition).

    Returns `child`, so attachments can be chained: ``a >> b >> c``.
    """
    self.append(expression)
    return expression
|
|
598
|
+
|
|
599
|
+
def __call__(self, default=None) -> Any:
    """
    Evaluate the decision tree starting from this node.

    Keyword Args:
        default (Any, optional): Fallback value returned if no matching condition is
            found. Defaults to a fresh ``NoAction(auto_connect=False)`` node.

    Returns:
        final_value (Any): The evaluation result of the last node reached, or `default`.

    Raises:
        TooFewChildren: If evaluation produced an empty decision path.
    """
    if default is None:
        from .node import NoAction
        default = NoAction(auto_connect=False)

    # Inspection mode forces every `with` body to be entered. Turn it off so the real
    # conditions are evaluated, and restore it in `finally` so an exception raised while
    # evaluating the tree cannot leave it disabled.
    ins_mode = LGM.inspection_mode
    if ins_mode:
        LOGGER.info('LGM inspection mode temporally disabled to evaluate correctly.')
        LGM.inspection_mode = False

    try:
        value, path = self.eval_recursively(default=default)
    finally:
        LGM.inspection_mode = ins_mode

    if not path:
        raise TooFewChildren()

    # `eval_recursively` already evaluated each node on the path and returns the result of
    # the last one (or `default`). Return it instead of re-evaluating `path[-1]`, which
    # also ran the leaf a second time.
    return value
|
|
628
|
+
|
|
629
|
+
def __repr__(self):
    # Renders as "<ClassName>('<expression repr>')".
    return '<{}>({!r})'.format(self.__class__.__name__, self.repr)
|
|
631
|
+
|
|
632
|
+
def _on_enter(self):
    """
    Hook run when ``with node:`` is entered; links this node into the tree being built.

    The placement depends on the active expression (the innermost node whose `with`
    block is open):

    * No active expression: the node is passed straight to ``LGM.enter_expression``
      (this also happens for ActionNodes) and gets no parent here.
    * The active node has no subordinates yet: this node becomes its ``True`` branch.
    * The active node's last subordinate has exactly one (bool) branch: this node is
      attached to that subordinate under the opposite boolean edge.
    * The last subordinate has two branches: its last child must be a NoAction. That
      child is popped and this node is attached under the same edge. Presumably this
      lets consecutive ``with`` blocks chain like if/elif — confirm.

    Afterwards, non-ActionNodes are pushed as the active expression.

    Raises:
        TooFewChildren: The last subordinate has no branches.
        EdgeValueError: The last subordinate's single edge is not a bool.
        NodeValueError: The last subordinate's last child is not a NoAction.
        TooManyChildren: The last subordinate has more than two branches.
    """
    active_node: LogicNode = LGM.active_expression

    if active_node is None:
        return LGM.enter_expression(node=self)

    match active_node.subordinates:
        case []:
            # First node opened inside the active node's block: its True branch.
            active_node.append(expression=self, edge_condition=True)

        case [*_, last_node] if not last_node.nodes:
            raise TooFewChildren()

        case [*_, last_node] if len(last_node.nodes) == 1:
            # Fill the missing side of the previous sibling's binary branch.
            edge_condition = last_node.last_edge
            if not isinstance(edge_condition, bool):
                raise EdgeValueError(f'{last_node} Edge condition must be a Boolean!')
            last_node.append(expression=self, edge_condition=not edge_condition)

        case [*_, last_node] if len(last_node.nodes) == 2:
            # Replace the previous sibling's NoAction branch with this node.
            from .node import NoAction
            edge_condition, child = last_node.last_edge, last_node.last_node
            if not isinstance(child, NoAction):
                raise NodeValueError(f'{last_node} second child node must be a NoAction node!')
            last_node.pop(-1)
            last_node.append(expression=self, edge_condition=edge_condition)

        case [*_, last_node] if len(last_node.nodes) > 2:
            raise TooManyChildren()

    # ActionNodes are linked above but never become the active expression here.
    if isinstance(self, ActionNode):
        pass
    else:
        LGM.enter_expression(node=self)
|
|
666
|
+
|
|
667
|
+
def _on_exit(self):
    """
    Hook run when an entered ``with node:`` block exits (normally or with a non-skip exception).

    Completes the node's missing True/False branch, then pops it from LGM's active expressions.
    """
    type(self).fill_binary_branch(node=self)
    LGM.exit_expression(node=self)
|
|
670
|
+
|
|
671
|
+
@classmethod
def fill_binary_branch(cls, node: LogicNode, with_action: ActionNode = None):
    """
    Make sure `node` has both a True and a False branch.

    ActionNodes are left untouched.
    * A node with no branch gets `with_action` on its False edge, and a warning is logged.
    * A node with one bool branch gets `with_action` on the opposite edge.

    Args:
        node (LogicNode): The node to complete.
        with_action (ActionNode, optional): The node used for the missing branch.
            Defaults to a fresh ``NoAction(auto_connect=False)``.

    Raises:
        EdgeValueError: If the single existing edge is not a bool.
        TooManyChildren: If the node already has two or more branches.
            NOTE(review): this includes a node that already has both branches.
    """
    if with_action is None:
        from .node import NoAction
        with_action = NoAction(auto_connect=False)

    if isinstance(node, ActionNode):
        return

    branch_count = len(node.nodes)
    if branch_count == 0:
        LOGGER.warning(f"It is rear that {node} having no True branch. Check the <with> statement code block to see if this is intended.")
        node.append(expression=with_action, edge_condition=False)
    elif branch_count == 1:
        existing_edge = node.last_edge
        if not isinstance(existing_edge, bool):
            raise EdgeValueError(f'{node} Edge condition must be a Boolean!')
        node.append(expression=with_action, edge_condition=not existing_edge)
    else:
        raise TooManyChildren()
|
|
698
|
+
|
|
699
|
+
@classmethod
def traverse(cls, node: Self, G=None, node_map: dict[int, Self] = None, parent: Self = None, edge_condition: Any = None):
    """
    Recursively add `node` and everything below it to a networkx directed graph.

    Graph nodes are keyed by ``id(node)`` and store the node's ``repr`` as their
    ``description`` attribute. Each edge from `parent` is labelled
    ``str(edge_condition)``, which is empty for the else/no-condition sentinel.

    Args:
        node (LogicNode): The node to add.
        G (networkx.DiGraph, optional): The graph to add to; a new DiGraph if omitted.
        node_map (dict, optional): Mapping of ``id(node)`` to node, filled in place.
        parent (LogicNode, optional): The node `node` was reached from, if any.
        edge_condition (Any, optional): The condition on the edge from `parent` to `node`.

    Returns:
        tuple: ``(G, node_map)``.

    Note:
        There is no visited set, so a node reachable from several parents is traversed
        once per parent. Re-adding an existing node or edge to a networkx graph only
        updates its attributes.
    """
    import networkx as nx

    graph = nx.DiGraph() if G is None else G
    id_map = {} if node_map is None else node_map

    this_id = id(node)
    id_map[this_id] = node
    graph.add_node(this_id, description=node.repr)

    if parent is not None:
        graph.add_edge(id(parent), this_id, label=str(edge_condition))

    for condition, child in node.nodes.items():
        cls.traverse(node=child, G=graph, node_map=id_map, parent=node, edge_condition=condition)

    return graph, id_map
|
|
733
|
+
|
|
734
|
+
def append(self, expression: LogicNode, edge_condition: Any = None):
    """
    Add `expression` as a child of this node under `edge_condition`.

    Args:
        expression (LogicNode): The child node; its ``parent`` is set to this node.
        edge_condition (Any, optional): The branch condition. ``None`` (the default)
            means the unconditional ``NO_CONDITION`` edge.

    Raises:
        ValueError: If this node already has a child under `edge_condition`.
    """
    # `None` is not a usable edge; it stands for "no condition". (A second
    # `edge_condition is None` check that followed here could never fire and was removed.)
    if edge_condition is None:
        edge_condition = NO_CONDITION

    if edge_condition in self.nodes:
        raise ValueError(f"Edge {edge_condition} already exists.")

    self.edges.append(edge_condition)
    self.nodes[edge_condition] = expression
    expression.parent = self
|
|
757
|
+
|
|
758
|
+
def pop(self, index: int = -1) -> tuple[Any, LogicNode]:
    """
    Detach the branch at position `index` of ``self.edges`` (the last one by default).

    Returns:
        tuple: ``(edge_condition, child_node)`` of the removed branch. The child's
        ``parent`` attribute is left unchanged.
    """
    removed_edge = self.edges.pop(index)
    return removed_edge, self.nodes.pop(removed_edge)
|
|
762
|
+
|
|
763
|
+
def replace(self, original_node: LogicNode, new_node: LogicNode):
    """
    Put `new_node` under the first edge that leads to `original_node` (matched by identity).

    Only ``self.nodes`` is updated. ``self.edges`` is unchanged and neither node's
    ``parent`` attribute is modified.

    Raises:
        NodeNotFountError: If `original_node` is not a child of this node.
    """
    matching_edges = [cond for cond, child in self.nodes.items() if child is original_node]
    if not matching_edges:
        raise NodeNotFountError()
    self.nodes[matching_edges[0]] = new_node
|
|
771
|
+
|
|
772
|
+
def eval_recursively(self, **kwargs):
    """
    Recursively evaluates the decision tree starting from this node.

    Keyword Args:
        path (list, optional): Decision path accumulated by the caller. This node is
            appended to it before descending. Defaults to a new list.
        default (Any, optional): Fallback value returned when this node's value matches
            none of its edge conditions. It applies to this node only and is not
            forwarded to children.

    Returns:
        tuple: (final_value, decision_path)
            - final_value (Any): The evaluated result of the tree.
            - decision_path (list): Every node visited, from this node down to the last
              node evaluated.

    Raises:
        ValueError: If no matching condition is found and no default value is provided.
    """
    path = kwargs.get('path')
    if path is None:
        path = []
    # Always record this node. The recursive call below passes the running path, and
    # ActionNode.eval_recursively also appends itself, so the result lists every visited
    # node; to_html relies on this to highlight the activated branch.
    path.append(self)

    value = self.eval()

    if not self.nodes:
        return value, path

    for condition, child in self.nodes.items():
        # The first edge equal to the value wins; a NO_CONDITION edge matches any value.
        if condition == value or condition is NO_CONDITION:
            return child.eval_recursively(path=path)

    if 'default' in kwargs:
        default = kwargs['default']
        LOGGER.info(f"No matching condition found for value {value} at '{self.repr}', using default {default}.")
        return default, path

    raise ValueError(f"No matching condition found for value {value} at '{self.repr}'.")
|
|
808
|
+
|
|
809
|
+
def list_labels(self) -> dict[str, list[LogicNode]]:
    """
    Collect the logic-group labels used anywhere in this subtree.

    Returns:
        dict[str, list[LogicNode]]: Maps each label to the nodes carrying it. Labels and
        nodes appear in depth-first pre-order (a node before its children, children in
        edge order).
    """
    grouped: dict[str, list[LogicNode]] = {}
    pending = [self]

    while pending:
        current = pending.pop()
        for group_name in current.labels:
            grouped.setdefault(group_name, []).append(current)
        # Push children right-to-left so the leftmost child is visited next (pre-order).
        pending.extend(reversed(current.nodes.values()))

    return grouped
|
|
825
|
+
|
|
826
|
+
def select_node(self, label: str) -> LogicNode | None:
    """
    Selects the root node of a logic group and validates that the group is chained.

    A group's root is the member whose parent is not itself a member of the group (or
    which has no ``parent`` at all). Membership is compared by identity, so the result
    does not depend on how node classes define ``==``.

    Args:
        label (str): Name of the logic group.

    Returns:
        LogicNode | None: The group's root node, or None if no node carries ``label``.

    Raises:
        ValueError: If more than one distinct member of the group is entered from outside
            the group (i.e. the group is not a single chained subtree).
    """
    labels = self.list_labels()
    if label not in labels:
        return None

    nodes = labels[label]
    member_ids = {id(node) for node in nodes}
    root = None

    for node in nodes:
        parent = getattr(node, 'parent', None)
        if parent is not None and id(parent) in member_ids:
            continue  # Reached from inside the group, so not an entry point.
        # The same node can be listed twice if its labels repeat; that is not a second root.
        if root is not None and root is not node:
            raise ValueError(f"Logic group '{label}' has multiple roots.")
        root = node

    return root
|
|
844
|
+
|
|
845
|
+
def to_html(self, with_group=True, dry_run=True, filename="decision_tree.html", **kwargs):
    """
    Visualizes the decision tree using PyVis and writes it to an HTML file.

    Args:
        with_group (bool): If True, each node's ``labels`` are attached as ``groups`` and
            one highlight button per label is added to the control panel.
        dry_run (bool): If True, shows structure without highlighting an active path.
            If False, evaluates the tree via ``eval_recursively()`` and colors the nodes
            on the returned path. If evaluation raises, the error is logged and the page
            is rendered as in dry-run mode.
        filename (str): Output path of the HTML file, written as UTF-8.

    Keyword Args:
        height (str): Canvas height. Default ``"750px"``.
        width (str): Canvas width. Default ``"100%"``.
        default_color (str): Node color. Default ``"lightblue"``.
        highlight_color (str): Color of nodes in a highlighted group. Default ``"lightgreen"``.
        selected_color (str): Color of nodes on the evaluated path. Default ``"lightyellow"``.
        dimmed_color (str): Color of nodes outside a highlighted group. Default ``"#e0e0e0"``.
        logic_shape (str): PyVis shape for logic nodes. Default ``"box"``.
        action_shape (str): PyVis shape for ``ActionNode`` instances. Default ``"ellipse"``.
    """
    from html import escape as html_escape

    from pyvis.network import Network

    G, node_map = self.traverse(self)

    # Highlight path if not in dry run; a set gives O(1) membership tests below.
    activated_path = set()
    if not dry_run:
        try:
            _, path = self.eval_recursively()
            activated_path = {id(node) for node in path}
        except Exception:
            # Best effort: a failing evaluation must not prevent drawing the structure.
            activated_path = set()
            dry_run = True
            LOGGER.error(f"Failed to evaluate decision tree.\n{traceback.format_exc()}")

    # Visualization using PyVis
    net = Network(
        height=kwargs.get('height', "750px"),
        width=kwargs.get('width', "100%"),
        directed=True,
        notebook=False,
        neighborhood_highlight=True
    )
    default_color = kwargs.get('default_color', "lightblue")
    highlight_color = kwargs.get('highlight_color', "lightgreen")
    activated_color = kwargs.get('selected_color', "lightyellow")
    dimmed_color = kwargs.get('dimmed_color', "#e0e0e0")
    logic_shape = kwargs.get('logic_shape', "box")
    action_shape = kwargs.get('action_shape', "ellipse")

    # Colors assigned at build time, replayed by the page's resetColors() JS function.
    original_colors = {}

    # Add nodes with group information
    for node_id, node in node_map.items():
        label = node.repr
        title = f"Node: {node.repr}"

        node_color = activated_color if node_id in activated_path else default_color
        original_colors[node_id] = node_color
        shape = action_shape if isinstance(node, ActionNode) else logic_shape

        if with_group:
            net.add_node(node_id, label=label, title=title, color=node_color, shape=shape, groups=node.labels)
        else:
            net.add_node(node_id, label=label, title=title, color=node_color, shape=shape)

    # Add edges
    for source, target, data in G.edges(data=True):
        edge_label = data.get("label", "")
        edge_color = "black" if dry_run else ("green" if source in activated_path and target in activated_path else "black")
        net.add_edge(source, target, label=edge_label, title=edge_label, color=edge_color, arrows="to")

    # Configure layout and options
    options = {
        "layout": {
            "hierarchical": {
                "enabled": True,
                "direction": "UD",  # UD = Up-Down (root at top, leaves at bottom)
                "sortMethod": "directed",
                "nodeSpacing": 150,
                "levelSeparation": 200
            }
        },
        "physics": {
            "hierarchicalRepulsion": {
                "centralGravity": 0.0,
                "springLength": 200,
                "springConstant": 0.01,
                "nodeDistance": 200,
                "damping": 0.09
            },
            "minVelocity": 0.75,
            "solver": "hierarchicalRepulsion"
        },
        "nodes": {
            "shape": "box",
            "shapeProperties": {"borderRadius": 10},
            "font": {"size": 14}
        },
        "edges": {
            "color": "black",
            "smooth": True
        }
    }

    net.set_options(json.dumps(options))

    # Generate the base HTML
    html = net.generate_html()

    # Inject custom controls and JavaScript
    buttons_html = """
    <div style="position: absolute; top: 10px; left: 10px; z-index: 1000;
                background: rgba(255, 255, 255, 0.9); padding: 12px;
                border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.3);
                font-family: Arial, sans-serif;">

        <h4 style="margin: 0 0 10px; font-size: 16px; text-align: center; color: #333;">
            Decision Tree Controls
        </h4>

        <button onclick="resetColors()" class="control-btn">Reset</button>
    """

    if with_group:
        groups = {group for node in node_map.values() for group in node.labels}
        for group in sorted(groups):
            # json.dumps yields a valid JS string literal; html_escape makes it safe inside the
            # double-quoted onclick attribute (the browser decodes &quot; back before running JS),
            # so labels containing quotes or markup cannot break the page.
            js_argument = html_escape(json.dumps(str(group)))
            buttons_html += f'<button onclick="highlightGroup({js_argument})" class="control-btn">{html_escape(str(group))}</button>'
    buttons_html += "</div>"

    js_code = f"""
    <script>
        function resetColors() {{
            // Reset all nodes to their original color and opacity
            nodes.forEach(function(node) {{
                nodes.update([{{
                    id: node.id,
                    color: originalColors[node.id],  // Reset to original color
                    opacity: 1
                }}]);
            }});

            // Reset all edges to default color and opacity
            edges.forEach(function(edge) {{
                edges.update([{{
                    id: edge.id,
                    color: "black",
                    opacity: 1
                }}]);
            }});
        }}

        function highlightGroup(group) {{
            // Dim all nodes and edges first
            nodes.update([...nodes.getIds().map(id => ({{
                id: id,
                color: {json.dumps(dimmed_color)},
                opacity: 0.3
            }}))]);

            edges.update([...edges.getIds().map(id => ({{
                id: id,
                color: "gray",
                opacity: 0.2
            }}))]);

            // Highlight nodes in the selected group
            const groupNodes = nodes.get({{
                filter: node => node.groups.includes(group)
            }});

            nodes.update([...groupNodes.map(node => ({{
                id: node.id,
                color: {json.dumps(highlight_color)},
                opacity: 1
            }}))]);

            // Highlight connected edges
            const connectedEdges = edges.get({{
                filter: edge =>
                    groupNodes.some(n => n.id === edge.from) ||
                    groupNodes.some(n => n.id === edge.to)
            }});

            edges.update([...connectedEdges.map(edge => ({{
                id: edge.id,
                color: "black",
                opacity: 1
            }}))]);
        }}

        // Store the original node colors for reset functionality
        const originalColors = {json.dumps(original_colors)};
    </script>
    """

    # Inject better styles for buttons
    css_styles = """
    <style>
        .control-btn {
            background-color: #007BFF;
            color: white;
            border: none;
            padding: 8px 14px;
            margin: 5px;
            font-size: 14px;
            border-radius: 5px;
            cursor: pointer;
            transition: background 0.3s ease;
        }

        .control-btn:hover {
            background-color: #0056b3;
        }

        .control-btn:active {
            background-color: #003f7f;
        }
    </style>
    """

    # Insert custom elements into the HTML
    html = html.replace("</head>", f"{css_styles}</head>")
    html = html.replace("</body>", f"{buttons_html}{js_code}</body>")

    # Save the modified HTML. Use an explicit encoding so node reprs containing non-ASCII
    # text do not fail on platforms whose default encoding is not UTF-8.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(html)

    LOGGER.info(f"Decision tree saved to {filename}")
|
|
1061
|
+
|
|
1062
|
+
@property
def children(self) -> Iterable[tuple[Any, LogicNode]]:
    """Iterator over this node's outgoing branches as ``(edge_condition, child)`` pairs."""
    branch_pairs = self.nodes.items()
    return iter(branch_pairs)
|
|
1066
|
+
|
|
1067
|
+
@property
def leaves(self) -> Iterable[LogicNode]:
    """
    Yield every childless node of this subtree, left to right in depth-first order.

    A node with no children yields only itself.
    """
    branches = self.nodes
    if branches:
        # Delegate to each child's own ``leaves`` so the traversal recurses per subtree.
        for subtree in branches.values():
            yield from subtree.leaves
        return
    yield self
|
|
1075
|
+
|
|
1076
|
+
@property
def last_edge(self) -> Any:
    """The most recently appended edge condition; raises IndexError when there is none."""
    edge_list = self.edges
    return edge_list[len(edge_list) - 1]
|
|
1079
|
+
|
|
1080
|
+
@property
def last_node(self) -> LogicNode:
    """The child reached through the most recently appended edge."""
    newest_edge = self.last_edge
    return self.nodes[newest_edge]
|
|
1083
|
+
|
|
1084
|
+
@property
def last_leaf(self) -> LogicNode:
    """
    Follow the most recently appended edge repeatedly until reaching a node with no
    children, and return that node (``self`` if it has no children).
    """
    cursor = self
    while cursor.nodes:
        cursor = cursor.last_node
    return cursor
|
|
1089
|
+
|
|
1090
|
+
@property
def last_leaf_expression(self) -> LogicNode:
    """
    The deepest node on the last-edge chain that is not an ActionNode.

    If ``last_leaf`` is an ActionNode, its ``parent`` is returned instead; otherwise the
    leaf itself is returned.
    """
    tail = self.last_leaf
    return tail.parent if isinstance(tail, ActionNode) else tail
|
|
1096
|
+
|
|
1097
|
+
|
|
1098
|
+
class ActionNode(LogicNode):
    """
    Terminal node of a decision tree that carries an action.

    ``append`` always raises ``TooManyChildren``, so an ActionNode cannot be given children
    through the normal API. ``eval_recursively`` records the node in the decision path and
    runs its action.
    """

    def __init__(
            self,
            action: float | int | bool | None | Exception | Callable[[], Any],
            dtype: type = None,
            repr: str = None,
            auto_connect: bool = True
    ):
        """
        Initialize the ActionNode.

        Args:
            action: What to do when this node is reached by ``eval_recursively``.
                A callable is invoked with no arguments, and its return value is discarded.
                An ``Exception`` instance is raised. Any other value (including ``None``)
                is stored on ``self.action`` but not invoked.
            dtype (type, optional): Forwarded to ``LogicNode.__init__``.
            repr (str, optional): A string representation of the node, forwarded to
                ``LogicNode.__init__``.
            auto_connect (bool): If True, immediately call the base ``LogicNode._on_enter``
                (not this class's override). Presumably this attaches the node to the
                currently active decision-tree context. TODO confirm against LogicNode.
        """
        super().__init__(expression=True, dtype=dtype, repr=repr)
        self.action = action

        if auto_connect:
            super()._on_enter()

    def _on_enter(self):
        # Overrides LogicNode._on_enter to only emit a warning. Presumably this runs when an
        # ActionNode is used as a context manager. TODO confirm how __enter__ dispatches.
        LOGGER.warning(f'{self.__class__.__name__} should not be used in a with clause')

    def _on_exit(self):
        # Intentionally a no-op; counterpart of the overridden _on_enter.
        pass

    def eval_recursively(self, path=None):
        """
        Evaluate this node, run its action, and return ``(value, path)``.

        Args:
            path (list, optional): Decision path accumulated so far; this node is appended
                to it. A new list is created when omitted.

        Returns:
            tuple: ``(value, path)``. ``value`` is ``self.eval()``, or the result of a
            matching child if any child was attached. ``path`` is the decision path.

        Raises:
            Exception: ``self.action`` itself, when the action is an exception instance.
        """
        if path is None:
            path = []
        path.append(self)

        value = self.eval()

        action = self.action
        if isinstance(action, Exception):
            raise action
        if callable(action):
            # Only callables are invoked. Plain values (numbers, booleans) are permitted by
            # the constructor's type hint and would raise TypeError if "called".
            action()

        # append() refuses children, so this only matters if self.nodes was populated by
        # other means.
        for condition, child in self.nodes.items():
            if condition == value or condition is NO_CONDITION:
                return child.eval_recursively(path=path)

        return value, path

    def append(self, expression: Self, edge_condition: Any = None):
        """
        ActionNodes cannot have children.

        Raises:
            TooManyChildren: Always.
        """
        raise TooManyChildren("Cannot append child to an ActionNode!")
|