mycorrhizal-0.1.0-py3-none-any.whl → mycorrhizal-0.2.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two published versions.
Files changed (45)
  1. mycorrhizal/_version.py +1 -0
  2. mycorrhizal/common/__init__.py +15 -3
  3. mycorrhizal/common/cache.py +114 -0
  4. mycorrhizal/common/compilation.py +263 -0
  5. mycorrhizal/common/interface_detection.py +159 -0
  6. mycorrhizal/common/interfaces.py +3 -50
  7. mycorrhizal/common/mermaid.py +124 -0
  8. mycorrhizal/common/wrappers.py +1 -1
  9. mycorrhizal/hypha/core/builder.py +56 -8
  10. mycorrhizal/hypha/core/runtime.py +242 -107
  11. mycorrhizal/hypha/core/specs.py +19 -3
  12. mycorrhizal/mycelium/__init__.py +174 -0
  13. mycorrhizal/mycelium/core.py +619 -0
  14. mycorrhizal/mycelium/exceptions.py +30 -0
  15. mycorrhizal/mycelium/hypha_bridge.py +1143 -0
  16. mycorrhizal/mycelium/instance.py +440 -0
  17. mycorrhizal/mycelium/pn_context.py +276 -0
  18. mycorrhizal/mycelium/runner.py +165 -0
  19. mycorrhizal/mycelium/spores_integration.py +655 -0
  20. mycorrhizal/mycelium/tree_builder.py +102 -0
  21. mycorrhizal/mycelium/tree_spec.py +197 -0
  22. mycorrhizal/rhizomorph/README.md +82 -33
  23. mycorrhizal/rhizomorph/core.py +308 -82
  24. mycorrhizal/septum/TRANSITION_REFERENCE.md +385 -0
  25. mycorrhizal/{enoki → septum}/core.py +326 -100
  26. mycorrhizal/{enoki → septum}/testing_utils.py +7 -7
  27. mycorrhizal/{enoki → septum}/util.py +44 -21
  28. mycorrhizal/spores/__init__.py +72 -19
  29. mycorrhizal/spores/core.py +907 -75
  30. mycorrhizal/spores/dsl/__init__.py +8 -8
  31. mycorrhizal/spores/dsl/hypha.py +3 -15
  32. mycorrhizal/spores/dsl/rhizomorph.py +3 -11
  33. mycorrhizal/spores/dsl/{enoki.py → septum.py} +26 -77
  34. mycorrhizal/spores/encoder/json.py +21 -12
  35. mycorrhizal/spores/extraction.py +14 -11
  36. mycorrhizal/spores/models.py +75 -20
  37. mycorrhizal/spores/transport/__init__.py +9 -2
  38. mycorrhizal/spores/transport/base.py +36 -17
  39. mycorrhizal/spores/transport/file.py +126 -0
  40. mycorrhizal-0.2.0.dist-info/METADATA +335 -0
  41. mycorrhizal-0.2.0.dist-info/RECORD +54 -0
  42. mycorrhizal-0.1.0.dist-info/METADATA +0 -198
  43. mycorrhizal-0.1.0.dist-info/RECORD +0 -37
  44. /mycorrhizal/{enoki → septum}/__init__.py +0 -0
  45. {mycorrhizal-0.1.0.dist-info → mycorrhizal-0.2.0.dist-info}/WHEEL +0 -0
@@ -52,11 +52,13 @@ Multi-file Composition:
 from __future__ import annotations
 
 import asyncio
+import contextvars
 import inspect
 import logging
 import traceback
 from dataclasses import dataclass, field
 from enum import Enum
+from types import ModuleType, SimpleNamespace
 from typing import (
     Any,
     Callable,
@@ -64,60 +66,99 @@ from typing import (
     Generator,
     Generic,
     List,
+    Literal,
     Optional,
     Tuple,
+    Type,
     TypeVar,
     Union,
     Set,
     Protocol,
+    overload,
+    get_type_hints,
+    Sequence as SequenceT,
 )
-from typing import Sequence as SequenceT
-from types import SimpleNamespace
 
+from mycorrhizal.common.wrappers import create_view_from_protocol
+from mycorrhizal.common.compilation import (
+    _get_compiled_metadata,
+    _clear_compilation_cache,
+    CompiledMetadata,
+)
 from mycorrhizal.common.timebase import *
 
 logger = logging.getLogger(__name__)
 
+# Context variable for trace logger
+_trace_logger_ctx: contextvars.ContextVar[Optional[logging.Logger]] = contextvars.ContextVar(
+    "_trace_logger_ctx", default=None
+)
+
 BB = TypeVar("BB")
+F = TypeVar("F", bound=Callable[..., Any])
 
 
 # ======================================================================================
-# Interface Integration Helper
+# Interface View Caching
 # ======================================================================================
 
+# Cache for interface views to avoid repeated creation
+_interface_view_cache: Dict[Tuple[int, Type], Any] = {}
+
+
+def _clear_interface_view_cache() -> None:
+    """Clear the interface view cache. Useful for testing."""
+    global _interface_view_cache
+    _interface_view_cache.clear()
+    # Also clear the compilation cache from common module
+    _clear_compilation_cache()
+
 
 def _create_interface_view_if_needed(bb: Any, func: Callable) -> Any:
     """
     Create a constrained view if the function has an interface type hint on its
-    first parameter (typically named 'bb').
+    blackboard parameter.
 
     This enables type-safe, constrained access to blackboard state based on
     interface definitions created with @blackboard_interface.
 
+    The function signature can use an interface type:
+        async def my_action(bb: MyInterface) -> Status:
+            # bb is automatically a constrained view
+            return Status.SUCCESS
+
     Args:
         bb: The blackboard instance
        func: The function to check for interface type hints
 
     Returns:
        Either the original blackboard or a constrained view based on interface metadata
-    """
-    from typing import get_type_hints
-
-    try:
-        sig = inspect.signature(func)
-        params = list(sig.parameters.values())
 
-        # Check first parameter (usually 'bb')
-        if params and params[0].name == 'bb':
-            bb_type = get_type_hints(func).get('bb')
+    Raises:
+        TypeError: If func is not callable or type hints are malformed
+        AttributeError: If type hints reference undefined types
+    """
+    # Get compiled metadata (uses EAFP pattern internally)
+    # Raises specific exceptions if compilation fails
+    metadata = _get_compiled_metadata(func)
+
+    # If handler has interface type hint, create constrained view
+    if metadata.has_interface and metadata.interface_type:
+        # EAFP: Try to get view from cache, create if not present
+        cache_key = (id(bb), metadata.interface_type)
+        try:
+            return _interface_view_cache[cache_key]
+        except KeyError:
+            # Create view with pre-extracted interface metadata
+            view = create_view_from_protocol(
+                bb,
+                metadata.interface_type,
+                readonly_fields=metadata.readonly_fields
+            )
 
-            # If type hint exists and has interface metadata
-            if bb_type and hasattr(bb_type, '_readonly_fields'):
-                from mycorrhizal.common.wrappers import create_view_from_protocol
-                return create_view_from_protocol(bb, bb_type)
-    except Exception:
-        # If anything goes wrong with type inspection, fall back to original bb
-        pass
+            # Cache for reuse
+            _interface_view_cache[cache_key] = view
+            return view
 
     return bb
@@ -136,17 +177,23 @@ def _supports_timebase(func: Callable) -> bool:
     return False
 
 
-async def _call_node_function(func: Callable, bb: Any, tb: Timebase) -> Any:
+async def _call_node_function(func: Callable, bb: Any, tb: Timebase, supports_timebase: bool) -> Any:
     """
     Call a node function with appropriate parameters based on its signature.
 
     If the function has an interface type hint on its 'bb' parameter, a
     constrained view will be created automatically to enforce access control.
+
+    Args:
+        func: The function to call
+        bb: The blackboard
+        tb: The timebase
+        supports_timebase: Cached flag indicating if func accepts 'tb' parameter
     """
     # Create interface view if function has interface type hint
     bb_to_pass = _create_interface_view_if_needed(bb, func)
 
-    if _supports_timebase(func):
+    if supports_timebase:
         if inspect.iscoroutinefunction(func):
             return await func(bb=bb_to_pass, tb=tb)
         else:
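`_call_node_function` no longer inspects the handler's signature on each tick; callers now pass a `supports_timebase` flag that `Action`, `Condition`, and `Match` compute once at construction. A rough sketch of that hoisting, using assumed names (`accepts_tb`, `CachedCall`, and `work` are illustrative only, not package APIs):

    import asyncio
    import inspect
    from typing import Any, Callable

    def accepts_tb(func: Callable[..., Any]) -> bool:
        # Signature inspection happens here, once, not on the per-tick hot path.
        return "tb" in inspect.signature(func).parameters

    class CachedCall:
        def __init__(self, func: Callable[..., Any]) -> None:
            self._func = func
            self._supports_tb = accepts_tb(func)   # cached at construction time

        async def __call__(self, bb: Any, tb: Any) -> Any:
            kwargs = {"bb": bb, "tb": tb} if self._supports_tb else {"bb": bb}
            result = self._func(**kwargs)
            return await result if inspect.isawaitable(result) else result

    async def work(bb: dict, tb: float) -> bool:   # illustrative handler
        return True

    print(asyncio.run(CachedCall(work)(bb={}, tb=0.0)))   # True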
@@ -202,6 +249,24 @@ def _name_of(obj: Any) -> str:
     return f"{obj.__class__.__name__}@{id(obj):x}"
 
 
+def _fully_qualified_name(func: Callable[..., Any]) -> str:
+    """
+    Get the fully qualified name of a function.
+
+    Returns module.function_name if the function has a module,
+    otherwise returns just the function name.
+    """
+    name = func.__name__
+    module = getattr(func, "__module__", None)
+    if module:
+        # Handle nested functions by trying to get qualname
+        qualname = getattr(func, "__qualname__", None)
+        if qualname and qualname != name:
+            return f"{module}.{qualname}"
+        return f"{module}.{name}"
+    return name
+
+
 # ======================================================================================
 # Recursion Detection
 # ======================================================================================
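`_fully_qualified_name` prefers `module.__qualname__`, so nested handlers remain distinguishable in trace output. A quick illustration with throwaway functions; `fq_name` mirrors the helper's logic but is not the package function:

    def fq_name(func) -> str:
        # Mirrors _fully_qualified_name: prefer module.qualname, else module.name.
        module = getattr(func, "__module__", None)
        qualname = getattr(func, "__qualname__", None) or func.__name__
        return f"{module}.{qualname}" if module else func.__name__

    def patrol():               # module-level handler
        pass

    def outer():
        def inner():            # nested handler
            pass
        return inner

    print(fq_name(patrol))      # e.g. __main__.patrol
    print(fq_name(outer()))     # e.g. __main__.outer.<locals>.inner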
@@ -315,11 +380,15 @@ class Action(Node[BB]):
     ) -> None:
         super().__init__(name=_name_of(func), exception_policy=exception_policy)
         self._func = func
+        # Cache whether function supports timebase parameter (checked during construction)
+        self._supports_timebase = _supports_timebase(func)
+        # Cache fully qualified name for tracing
+        self._fq_name = _fully_qualified_name(func)
 
     async def tick(self, bb: BB, tb: Timebase) -> Status:
         await self._ensure_entered(bb, tb)
         try:
-            result = await _call_node_function(self._func, bb, tb)
+            result = await _call_node_function(self._func, bb, tb, self._supports_timebase)
         except asyncio.CancelledError:
             return await self._finish(bb, Status.CANCELLED, tb)
         except Exception as e:
@@ -330,13 +399,20 @@ class Action(Node[BB]):
                 raise
             return await self._finish(bb, Status.ERROR, tb)
 
+        # Determine final status
         if isinstance(result, Status):
-            return await self._finish(bb, result, tb)
-        if isinstance(result, bool):
-            return await self._finish(
-                bb, Status.SUCCESS if result else Status.FAILURE, tb
-            )
-        return await self._finish(bb, Status.SUCCESS, tb)
+            final_status = result
+        elif isinstance(result, bool):
+            final_status = Status.SUCCESS if result else Status.FAILURE
+        else:
+            final_status = Status.SUCCESS
+
+        # Log trace if enabled
+        trace_logger = _trace_logger_ctx.get()
+        if trace_logger is not None:
+            trace_logger.info(f"action: {self._fq_name} | {final_status.name}")
+
+        return await self._finish(bb, final_status, tb)
 
 
 class Condition(Action[BB]):
@@ -345,7 +421,7 @@ class Condition(Action[BB]):
     async def tick(self, bb: BB, tb: Timebase) -> Status:
         await self._ensure_entered(bb, tb)
         try:
-            result = await _call_node_function(self._func, bb, tb)
+            result = await _call_node_function(self._func, bb, tb, self._supports_timebase)
         except asyncio.CancelledError:
             return await self._finish(bb, Status.CANCELLED, tb)
         except Exception as e:
@@ -356,11 +432,18 @@ class Condition(Action[BB]):
                 raise
             return await self._finish(bb, Status.ERROR, tb)
 
+        # Determine final status
         if isinstance(result, Status):
-            return await self._finish(bb, result, tb)
-        return await self._finish(
-            bb, Status.SUCCESS if bool(result) else Status.FAILURE, tb
-        )
+            final_status = result
+        else:
+            final_status = Status.SUCCESS if bool(result) else Status.FAILURE
+
+        # Log trace if enabled
+        trace_logger = _trace_logger_ctx.get()
+        if trace_logger is not None:
+            trace_logger.info(f"condition: {self._fq_name} | {final_status.name}")
+
+        return await self._finish(bb, final_status, tb)
 
 
 # ======================================================================================
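Both `Action.tick` and `Condition.tick` now route the handler's return value through a single normalization step before tracing and finishing: a `Status` passes through, a `bool` maps to SUCCESS or FAILURE, and (for actions) anything else counts as SUCCESS. A small sketch of that mapping, using a stand-in `Status` enum (the package's enum also has RUNNING, CANCELLED, and ERROR members):

    from enum import Enum
    from typing import Any

    class Status(Enum):          # stand-in; the package's enum has more members
        SUCCESS = "SUCCESS"
        FAILURE = "FAILURE"

    def normalize_action_result(result: Any) -> Status:
        # Mirrors the final-status logic in Action.tick (0.2.0).
        if isinstance(result, Status):
            return result
        if isinstance(result, bool):
            return Status.SUCCESS if result else Status.FAILURE
        return Status.SUCCESS

    assert normalize_action_result(Status.FAILURE) is Status.FAILURE
    assert normalize_action_result(False) is Status.FAILURE
    assert normalize_action_result(None) is Status.SUCCESS     # bare return
    assert normalize_action_result({"data": 1}) is Status.SUCCESS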
@@ -799,6 +882,8 @@ class Match(Node[BB]):
         for _, child in self._cases:
             child.parent = self
         self._matched_idx: Optional[int] = None
+        # Cache whether key_fn supports timebase parameter
+        self._key_fn_supports_timebase = _supports_timebase(key_fn)
 
     def reset(self) -> None:
         super().reset()
@@ -818,7 +903,7 @@
 
     async def tick(self, bb: BB, tb: Timebase) -> Status:
         await self._ensure_entered(bb, tb)
-
+
         if self._matched_idx is not None:
             _, child = self._cases[self._matched_idx]
             st = await child.tick(bb, tb)
@@ -826,9 +911,9 @@
                 return Status.RUNNING
             self._matched_idx = None
             return await self._finish(bb, st, tb)
-
-        value = await _call_node_function(self._key_fn, bb, tb)
-
+
+        value = await _call_node_function(self._key_fn, bb, tb, self._key_fn_supports_timebase)
+
         for i, (matcher, child) in enumerate(self._cases):
             if self._matches(matcher, value):
                 st = await child.tick(bb, tb)
@@ -836,7 +921,7 @@
                     self._matched_idx = i
                     return Status.RUNNING
                 return await self._finish(bb, st, tb)
-
+
         return await self._finish(bb, Status.FAILURE, tb)
 
 
@@ -971,7 +1056,6 @@ class NodeSpec:
 
     def to_node(
         self,
-        owner: Optional[Any] = None,
         exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
     ) -> Node[Any]:
         match self.kind:
@@ -980,12 +1064,11 @@
             case NodeSpecKind.CONDITION:
                 return Condition(self.payload, exception_policy=exception_policy)
             case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
-                owner_effective = getattr(self, "owner", None) or owner
                 factory = self.payload["factory"]
-                expanded = _bt_expand_children(factory, owner_effective)
+                expanded = _bt_expand_children(factory)
                 self.children = expanded
                 built = [
-                    ch.to_node(owner_effective, exception_policy) for ch in expanded
+                    ch.to_node(exception_policy) for ch in expanded
                 ]
                 match self.kind:
                     case NodeSpecKind.SEQUENCE:
@@ -1013,7 +1096,7 @@
 
             case NodeSpecKind.DECORATOR:
                 assert len(self.children) == 1, "Decorator must wrap exactly one child"
-                child_node = self.children[0].to_node(owner, exception_policy)
+                child_node = self.children[0].to_node(exception_policy)
                 builder = self.payload
                 return builder(child_node)
 
@@ -1025,7 +1108,7 @@
                 key_fn = self.payload["key_fn"]
                 case_specs: List[CaseSpec] = self.payload["cases"]
                 cases = [
-                    (cs.matcher, cs.child.to_node(owner, exception_policy))
+                    (cs.matcher, cs.child.to_node(exception_policy))
                    for cs in case_specs
                 ]
                 return Match(
@@ -1039,8 +1122,8 @@
                 cond_spec = self.payload["condition"]
                 child_spec = self.children[0]
                 return DoWhile(
-                    cond_spec.to_node(owner, exception_policy),
-                    child_spec.to_node(owner, exception_policy),
+                    cond_spec.to_node(exception_policy),
+                    child_spec.to_node(exception_policy),
                     name=self.name,
                     exception_policy=exception_policy,
                 )
@@ -1051,7 +1134,6 @@
 
 def _bt_expand_children(
     factory: Callable[..., Generator[Any, None, None]],
-    owner: Optional[Any],
     expansion_stack: Optional[Set[str]] = None,
 ) -> List[NodeSpec]:
     """
@@ -1059,7 +1141,6 @@ def _bt_expand_children(
 
     Args:
         factory: The generator function that yields child specs
-        owner: The namespace object for resolving N references
         expansion_stack: Stack of factory names to detect recursion
     """
     if expansion_stack is None:
@@ -1080,10 +1161,7 @@
     expansion_stack = expansion_stack.copy()
     expansion_stack.add(factory_name)
 
-    try:
-        gen = factory(owner)
-    except TypeError:
-        gen = factory()
+    gen = factory()
 
     if not inspect.isgenerator(gen):
         raise TypeError(
@@ -1119,8 +1197,7 @@
             NodeSpecKind.PARALLEL,
         ) and hasattr(spec, "_expansion_stack"):
             child_factory = spec.payload["factory"]
-            child_owner = getattr(spec, "owner", None) or owner
-            _bt_expand_children(child_factory, child_owner, spec._expansion_stack) # type: ignore
+            _bt_expand_children(child_factory, spec._expansion_stack) # type: ignore
 
     return out
 
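With the `owner` parameter gone, composite factories are plain zero-argument generator functions, and expansion simply calls `factory()` and iterates the yielded children. A minimal sketch of that contract; the string children stand in for real NodeSpec objects:

    import inspect
    from typing import Any, Generator, List

    def expand(factory) -> List[Any]:
        # Mirrors the core of _bt_expand_children: call with no arguments,
        # require a generator, collect the yielded children.
        gen = factory()
        if not inspect.isgenerator(gen):
            raise TypeError(f"{factory.__name__} must be a generator function")
        return list(gen)

    def root() -> Generator[Any, None, None]:
        # Children are yielded by direct reference; no owner/N namespace involved.
        yield "check_battery"
        yield "dock_and_charge"

    print(expand(root))    # ['check_battery', 'dock_and_charge']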
@@ -1309,12 +1386,78 @@ class _BT:
             self._tracking_stack[-1].append((fn.__name__, fn))
         return fn
 
-    def sequence(self, *, memory: bool = True):
-        """Decorator to mark a generator function as a sequence composite."""
+    def sequence(
+        self, *args: Union[F, NodeSpec, Callable[[Any], Any]], memory: bool = True
+    ) -> Union[F, Callable[[F], F], NodeSpec]:
+        """Decorator to mark a generator function as a sequence composite.
+
+        Can be used in three ways:
+        1. Decorator without parentheses:
+            @bt.sequence
+            def root():
+                ...
+
+        2. Decorator with parameters:
+            @bt.sequence(memory=False)
+            def root():
+                ...
+
+        3. Direct call with child nodes:
+            bt.sequence(action1, action2, action3)
+        """
+        # Case 3: Direct call with children - bt.sequence(node1, node2, ...)
+        # This is detected when we have multiple args, or a single arg that's not a generator function
+        if len(args) == 0:
+            # Case 2a: bt.sequence() with no arguments - return decorator
+            return self._sequence_impl(memory=memory)
+
+        if len(args) > 1:
+            # Multiple children - create sequence directly
+            return self._sequence_from_children(args, memory)
+
+        # Single argument - check if it's a generator function (decorator case) or a node (direct call)
+        single_arg = args[0]
+
+        # Check if it's a generator function (decorator form)
+        if inspect.isfunction(single_arg) and inspect.isgeneratorfunction(single_arg):
+            # Case 1: @bt.sequence def root(): ...
+            spec = NodeSpec(
+                kind=NodeSpecKind.SEQUENCE,
+                name=_name_of(single_arg),
+                payload={"factory": single_arg, "memory": memory},
+            )
+            single_arg.node_spec = spec # type: ignore
+            if self._tracking_stack:
+                self._tracking_stack[-1].append((single_arg.__name__, single_arg))
+            return single_arg
+
+        # Case 3b: Single child node - bt.sequence(action1)
+        return self._sequence_from_children(args, memory)
+
+    def _sequence_from_children(self, children: Tuple[Any, ...], memory: bool) -> NodeSpec:
+        """Create a sequence NodeSpec from child nodes."""
+        # Create a uniquely named factory to avoid false recursion detection
+        child_names = ', '.join(_name_of(c) for c in children)
+
+        def _sequence_factory_direct() -> Generator[Any, None, None]:
+            for child in children:
+                yield child
+
+        # Set a unique name for the factory function
+        _sequence_factory_direct.__name__ = f"_sequence_factory_direct_{id(children)}"
+        _sequence_factory_direct.__qualname__ = f"_sequence_factory_direct_{id(children)}"
 
-        def deco(
-            factory: Callable[..., Generator[Any, None, None]],
-        ) -> Callable[..., Generator[Any, None, None]]:
+        name = f"Sequence({child_names})" if children else "Sequence"
+
+        return NodeSpec(
+            kind=NodeSpecKind.SEQUENCE,
+            name=name,
+            payload={"factory": _sequence_factory_direct, "memory": memory},
+        )
+
+    def _sequence_impl(self, memory: bool) -> Callable[[F], F]:
+        """Implementation of sequence decorator."""
+        def deco(factory: F) -> F:
             spec = NodeSpec(
                 kind=NodeSpecKind.SEQUENCE,
                 name=_name_of(factory),
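The reworked `sequence` (and, below, `selector`) entry point dispatches on its positional arguments: none means "return the parametrized decorator", a single generator function means bare decoration, and anything else is treated as direct children. A stripped-down sketch of that dispatch, independent of the real NodeSpec machinery:

    import inspect
    from typing import Any

    def sequence(*args: Any, memory: bool = True) -> Any:
        """Toy dispatcher mirroring the three calling conventions."""
        if len(args) == 0:
            # Form 2: sequence(memory=...) returns the actual decorator.
            return lambda factory: ("SEQUENCE", factory.__name__, memory)
        if len(args) == 1 and inspect.isgeneratorfunction(args[0]):
            # Form 1: bare @sequence on a generator function.
            return ("SEQUENCE", args[0].__name__, memory)
        # Form 3: sequence(child1, child2, ...) builds directly from children.
        return ("SEQUENCE", args, memory)

    @sequence
    def patrol():
        yield "scan"

    @sequence(memory=False)
    def fallback():
        yield "retreat"

    print(patrol)                       # ('SEQUENCE', 'patrol', True)
    print(fallback)                     # ('SEQUENCE', 'fallback', False)
    print(sequence("scan", "report"))   # ('SEQUENCE', ('scan', 'report'), True)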
@@ -1327,16 +1470,82 @@ class _BT:
 
         return deco
 
-    def selector(self, *, memory: bool = True, reactive: bool = False):
-        """Decorator to mark a generator function as a selector composite."""
+    def selector(
+        self, *args: Union[F, NodeSpec, Callable[[Any], Any]], memory: bool = True
+    ) -> Union[F, Callable[[F], F], NodeSpec]:
+        """Decorator to mark a generator function as a selector composite.
+
+        Can be used in three ways:
+        1. Decorator without parentheses:
+            @bt.selector
+            def root():
+                ...
+
+        2. Decorator with parameters:
+            @bt.selector(memory=False)
+            def root():
+                ...
+
+        3. Direct call with child nodes:
+            bt.selector(option1, option2, option3)
+        """
+        # Case 3: Direct call with children - bt.selector(node1, node2, ...)
+        # This is detected when we have multiple args, or a single arg that's not a generator function
+        if len(args) == 0:
+            # Case 2a: bt.selector() with no arguments - return decorator
+            return self._selector_impl(memory=memory)
+
+        if len(args) > 1:
+            # Multiple children - create selector directly
+            return self._selector_from_children(args, memory)
+
+        # Single argument - check if it's a generator function (decorator case) or a node (direct call)
+        single_arg = args[0]
+
+        # Check if it's a generator function (decorator form)
+        if inspect.isfunction(single_arg) and inspect.isgeneratorfunction(single_arg):
+            # Case 1: @bt.selector def root(): ...
+            spec = NodeSpec(
+                kind=NodeSpecKind.SELECTOR,
+                name=_name_of(single_arg),
+                payload={"factory": single_arg, "memory": memory},
+            )
+            single_arg.node_spec = spec # type: ignore
+            if self._tracking_stack:
+                self._tracking_stack[-1].append((single_arg.__name__, single_arg))
+            return single_arg
+
+        # Case 3b: Single child node - bt.selector(action1)
+        return self._selector_from_children(args, memory)
+
+    def _selector_from_children(self, children: Tuple[Any, ...], memory: bool) -> NodeSpec:
+        """Create a selector NodeSpec from child nodes."""
+        # Create a uniquely named factory to avoid false recursion detection
+        child_names = ', '.join(_name_of(c) for c in children)
+
+        def _selector_factory_direct() -> Generator[Any, None, None]:
+            for child in children:
+                yield child
+
+        # Set a unique name for the factory function
+        _selector_factory_direct.__name__ = f"_selector_factory_direct_{id(children)}"
+        _selector_factory_direct.__qualname__ = f"_selector_factory_direct_{id(children)}"
+
+        name = f"Selector({child_names})" if children else "Selector"
+
+        return NodeSpec(
+            kind=NodeSpecKind.SELECTOR,
+            name=name,
+            payload={"factory": _selector_factory_direct, "memory": memory},
+        )
 
-        def deco(
-            factory: Callable[..., Generator[Any, None, None]],
-        ) -> Callable[..., Generator[Any, None, None]]:
+    def _selector_impl(self, memory: bool) -> Callable[[F], F]:
+        """Implementation of selector decorator."""
+        def deco(factory: F) -> F:
             spec = NodeSpec(
                 kind=NodeSpecKind.SELECTOR,
                 name=_name_of(factory),
-                payload={"factory": factory, "memory": memory, "reactive": reactive},
+                payload={"factory": factory, "memory": memory},
             )
             factory.node_spec = spec # type: ignore
             if self._tracking_stack:
@@ -1347,12 +1556,10 @@ class _BT:
 
     def parallel(
         self, *, success_threshold: int, failure_threshold: Optional[int] = None
-    ):
+    ) -> Callable[[F], F]:
         """Decorator to mark a generator function as a parallel composite."""
 
-        def deco(
-            factory: Callable[..., Generator[Any, None, None]],
-        ) -> Callable[..., Generator[Any, None, None]]:
+        def deco(factory: F) -> F:
             spec = NodeSpec(
                 kind=NodeSpecKind.PARALLEL,
                 name=_name_of(factory),
@@ -1511,8 +1718,8 @@ class _BT:
             ...
 
         @bt.sequence()
-        def root(N):
-            yield N.my_action
+        def root():
+            yield my_action
        """
        created_nodes = []
        self._tracking_stack.append(created_nodes)
@@ -1526,7 +1733,7 @@
             name: node for name, node in created_nodes if hasattr(node, "node_spec")
         }
         namespace = SimpleNamespace(**nodes)
-
+
         # Store the tree's name for use in subtree references
         namespace._tree_name = fn.__name__
 
@@ -1542,10 +1749,13 @@
 
         return namespace
 
-    def root(
-        self, fn: Callable[..., Generator[Any, None, None]]
-    ) -> Callable[..., Generator[Any, None, None]]:
+    def root(self, fn: F) -> F:
         """Mark a composite as the root of the tree."""
+        if not hasattr(fn, "node_spec"):
+            raise TypeError(
+                f"@bt.root can only be used on composites (sequences, selectors, etc.), "
+                f"got {fn!r}"
+            )
         fn.node_spec.is_root = True # type: ignore
         return fn
 
@@ -1609,9 +1819,8 @@ def _generate_mermaid(tree: SimpleNamespace) -> str:
     def ensure_children(spec: NodeSpec) -> List[NodeSpec]:
         match spec.kind:
             case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
-                owner = getattr(spec, "owner", None) or tree
                 factory = spec.payload["factory"]
-                spec.children = _bt_expand_children(factory, owner)
+                spec.children = _bt_expand_children(factory)
             case NodeSpecKind.SUBTREE:
                 subtree_root = spec.payload["root"]
                 spec.children = [subtree_root]
@@ -1677,6 +1886,7 @@ class Runner(Generic[BB]):
         bb: Blackboard containing shared state
         tb: Optional timebase for time management (defaults to MonotonicClock)
         exception_policy: How to handle exceptions during tree execution
+        trace: Optional logger instance for tracing action/condition execution
 
     Methods:
         tick(): Execute one tick of the behavior tree
@@ -1693,6 +1903,14 @@ class Runner(Generic[BB]):
 
         runner = Runner(MyTree, bb=blackboard)
         result = await runner.tick_until_complete()
+
+    Tracing:
+        import logging
+        trace_logger = logging.getLogger("bt.trace")
+        runner = Runner(MyTree, bb=blackboard, trace=trace_logger)
+        result = await runner.tick_until_complete()
+        # Logs: "action: module.do_work | SUCCESS"
+        #       "condition: module.check_condition | SUCCESS"
     """
     def __init__(
         self,
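The `trace` hook is an ordinary `logging.Logger`, so the trace lines only appear if that logger is given a handler and an appropriate level. A plausible setup under those standard-library assumptions; the `Runner(...)` call in the comment refers to the signature shown in this diff, and `MyTree`/`bb` are placeholders:

    import logging

    # A trace logger only emits if it has a handler and an INFO-or-lower level.
    trace_logger = logging.getLogger("bt.trace")
    trace_logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s"))
    trace_logger.addHandler(handler)
    trace_logger.propagate = False   # keep trace lines out of the root logger

    # With 0.2.0 this would be wired up as: Runner(MyTree, bb=bb, trace=trace_logger)
    trace_logger.info("action: my_module.do_work | SUCCESS")   # sample line format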
@@ -1700,23 +1918,31 @@
         bb: BB,
         tb: Optional[Timebase] = None,
         exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
+        trace: Optional[logging.Logger] = None,
     ) -> None:
         self.tree = tree
         self.bb: BB = bb
         self.tb = tb or MonotonicClock()
         self.exception_policy = exception_policy
+        self.trace = trace
 
         if not hasattr(tree, "root"):
             raise ValueError("Tree namespace must have a 'root' attribute")
 
         self.root: Node[BB] = tree.root.to_node(
-            owner=tree, exception_policy=exception_policy
+            exception_policy=exception_policy
         )
 
     async def tick(self) -> Status:
-        result = await self.root.tick(self.bb, self.tb)
-        self.tb.advance()
-        return result
+        # Set trace logger in context for this tick
+        token = _trace_logger_ctx.set(self.trace)
+        try:
+            result = await self.root.tick(self.bb, self.tb)
+            self.tb.advance()
+            return result
+        finally:
+            # Clear the context variable
+            _trace_logger_ctx.reset(token)
 
     async def tick_until_complete(self, timeout: Optional[float] = None) -> Status:
         start = self.tb.now()
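`Runner.tick` publishes the trace logger through a `ContextVar` for the span of one tick and restores the previous value with the reset token, so concurrent runners in separate asyncio tasks do not observe each other's logger. A small self-contained sketch of the same set/reset discipline (the behavior-tree parts are omitted):

    import asyncio
    import contextvars
    from typing import Optional

    _current_tag: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
        "_current_tag", default=None
    )

    async def tick(tag: Optional[str]) -> str:
        token = _current_tag.set(tag)     # publish for the duration of this tick
        try:
            await asyncio.sleep(0)        # other tasks run here with their own value
            return f"seen={_current_tag.get()}"
        finally:
            _current_tag.reset(token)     # always restore the previous value

    async def main() -> None:
        # Two concurrent "runners" never observe each other's tag.
        print(await asyncio.gather(tick("runner-A"), tick("runner-B")))
        print("after:", _current_tag.get())   # back to the default: None

    asyncio.run(main())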