mycorrhizal 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. mycorrhizal/_version.py +1 -1
  2. mycorrhizal/common/__init__.py +15 -3
  3. mycorrhizal/common/cache.py +114 -0
  4. mycorrhizal/common/compilation.py +263 -0
  5. mycorrhizal/common/interface_detection.py +159 -0
  6. mycorrhizal/common/interfaces.py +3 -50
  7. mycorrhizal/common/mermaid.py +124 -0
  8. mycorrhizal/common/wrappers.py +1 -1
  9. mycorrhizal/hypha/core/builder.py +11 -1
  10. mycorrhizal/hypha/core/runtime.py +242 -107
  11. mycorrhizal/mycelium/__init__.py +174 -0
  12. mycorrhizal/mycelium/core.py +619 -0
  13. mycorrhizal/mycelium/exceptions.py +30 -0
  14. mycorrhizal/mycelium/hypha_bridge.py +1143 -0
  15. mycorrhizal/mycelium/instance.py +440 -0
  16. mycorrhizal/mycelium/pn_context.py +276 -0
  17. mycorrhizal/mycelium/runner.py +165 -0
  18. mycorrhizal/mycelium/spores_integration.py +655 -0
  19. mycorrhizal/mycelium/tree_builder.py +102 -0
  20. mycorrhizal/mycelium/tree_spec.py +197 -0
  21. mycorrhizal/rhizomorph/README.md +82 -33
  22. mycorrhizal/rhizomorph/core.py +287 -119
  23. mycorrhizal/septum/TRANSITION_REFERENCE.md +385 -0
  24. mycorrhizal/{enoki → septum}/core.py +326 -100
  25. mycorrhizal/{enoki → septum}/testing_utils.py +7 -7
  26. mycorrhizal/{enoki → septum}/util.py +44 -21
  27. mycorrhizal/spores/__init__.py +3 -3
  28. mycorrhizal/spores/core.py +149 -28
  29. mycorrhizal/spores/dsl/__init__.py +8 -8
  30. mycorrhizal/spores/dsl/hypha.py +3 -15
  31. mycorrhizal/spores/dsl/rhizomorph.py +3 -11
  32. mycorrhizal/spores/dsl/{enoki.py → septum.py} +26 -77
  33. mycorrhizal/spores/encoder/json.py +21 -12
  34. mycorrhizal/spores/extraction.py +14 -11
  35. mycorrhizal/spores/models.py +53 -20
  36. mycorrhizal-0.2.1.dist-info/METADATA +335 -0
  37. mycorrhizal-0.2.1.dist-info/RECORD +54 -0
  38. mycorrhizal-0.1.2.dist-info/METADATA +0 -198
  39. mycorrhizal-0.1.2.dist-info/RECORD +0 -39
  40. mycorrhizal/{enoki → septum}/__init__.py +0 -0
  41. {mycorrhizal-0.1.2.dist-info → mycorrhizal-0.2.1.dist-info}/WHEEL +0 -0
@@ -52,11 +52,13 @@ Multi-file Composition:
52
52
  from __future__ import annotations
53
53
 
54
54
  import asyncio
55
+ import contextvars
55
56
  import inspect
56
57
  import logging
57
58
  import traceback
58
59
  from dataclasses import dataclass, field
59
60
  from enum import Enum
61
+ from types import ModuleType, SimpleNamespace
60
62
  from typing import (
61
63
  Any,
62
64
  Callable,
@@ -64,60 +66,99 @@ from typing import (
64
66
  Generator,
65
67
  Generic,
66
68
  List,
69
+ Literal,
67
70
  Optional,
68
71
  Tuple,
72
+ Type,
69
73
  TypeVar,
70
74
  Union,
71
75
  Set,
72
76
  Protocol,
77
+ overload,
78
+ get_type_hints,
79
+ Sequence as SequenceT,
73
80
  )
74
- from typing import Sequence as SequenceT
75
- from types import SimpleNamespace
76
81
 
82
+ from mycorrhizal.common.wrappers import create_view_from_protocol
83
+ from mycorrhizal.common.compilation import (
84
+ _get_compiled_metadata,
85
+ _clear_compilation_cache,
86
+ CompiledMetadata,
87
+ )
77
88
  from mycorrhizal.common.timebase import *
78
89
 
79
90
  logger = logging.getLogger(__name__)
80
91
 
92
+ # Context variable for trace logger
93
+ _trace_logger_ctx: contextvars.ContextVar[Optional[logging.Logger]] = contextvars.ContextVar(
94
+ "_trace_logger_ctx", default=None
95
+ )
96
+
81
97
  BB = TypeVar("BB")
98
+ F = TypeVar("F", bound=Callable[..., Any])
82
99
 
83
100
 
84
101
  # ======================================================================================
85
- # Interface Integration Helper
102
+ # Interface View Caching
86
103
  # ======================================================================================
87
104
 
105
+ # Cache for interface views to avoid repeated creation
106
+ _interface_view_cache: Dict[Tuple[int, Type], Any] = {}
107
+
108
+
109
+ def _clear_interface_view_cache() -> None:
110
+ """Clear the interface view cache. Useful for testing."""
111
+ global _interface_view_cache
112
+ _interface_view_cache.clear()
113
+ # Also clear the compilation cache from common module
114
+ _clear_compilation_cache()
115
+
88
116
 
89
117
  def _create_interface_view_if_needed(bb: Any, func: Callable) -> Any:
90
118
  """
91
119
  Create a constrained view if the function has an interface type hint on its
92
- first parameter (typically named 'bb').
120
+ blackboard parameter.
93
121
 
94
122
  This enables type-safe, constrained access to blackboard state based on
95
123
  interface definitions created with @blackboard_interface.
96
124
 
125
+ The function signature can use an interface type:
126
+ async def my_action(bb: MyInterface) -> Status:
127
+ # bb is automatically a constrained view
128
+ return Status.SUCCESS
129
+
97
130
  Args:
98
131
  bb: The blackboard instance
99
132
  func: The function to check for interface type hints
100
133
 
101
134
  Returns:
102
135
  Either the original blackboard or a constrained view based on interface metadata
103
- """
104
- from typing import get_type_hints
105
-
106
- try:
107
- sig = inspect.signature(func)
108
- params = list(sig.parameters.values())
109
136
 
110
- # Check first parameter (usually 'bb')
111
- if params and params[0].name == 'bb':
112
- bb_type = get_type_hints(func).get('bb')
137
+ Raises:
138
+ TypeError: If func is not callable or type hints are malformed
139
+ AttributeError: If type hints reference undefined types
140
+ """
141
+ # Get compiled metadata (uses EAFP pattern internally)
142
+ # Raises specific exceptions if compilation fails
143
+ metadata = _get_compiled_metadata(func)
144
+
145
+ # If handler has interface type hint, create constrained view
146
+ if metadata.has_interface and metadata.interface_type:
147
+ # EAFP: Try to get view from cache, create if not present
148
+ cache_key = (id(bb), metadata.interface_type)
149
+ try:
150
+ return _interface_view_cache[cache_key]
151
+ except KeyError:
152
+ # Create view with pre-extracted interface metadata
153
+ view = create_view_from_protocol(
154
+ bb,
155
+ metadata.interface_type,
156
+ readonly_fields=metadata.readonly_fields
157
+ )
113
158
 
114
- # If type hint exists and has interface metadata
115
- if bb_type and hasattr(bb_type, '_readonly_fields'):
116
- from mycorrhizal.common.wrappers import create_view_from_protocol
117
- return create_view_from_protocol(bb, bb_type)
118
- except Exception:
119
- # If anything goes wrong with type inspection, fall back to original bb
120
- pass
159
+ # Cache for reuse
160
+ _interface_view_cache[cache_key] = view
161
+ return view
121
162
 
122
163
  return bb
123
164
 
@@ -136,17 +177,23 @@ def _supports_timebase(func: Callable) -> bool:
136
177
  return False
137
178
 
138
179
 
139
- async def _call_node_function(func: Callable, bb: Any, tb: Timebase) -> Any:
180
+ async def _call_node_function(func: Callable, bb: Any, tb: Timebase, supports_timebase: bool) -> Any:
140
181
  """
141
182
  Call a node function with appropriate parameters based on its signature.
142
183
 
143
184
  If the function has an interface type hint on its 'bb' parameter, a
144
185
  constrained view will be created automatically to enforce access control.
186
+
187
+ Args:
188
+ func: The function to call
189
+ bb: The blackboard
190
+ tb: The timebase
191
+ supports_timebase: Cached flag indicating if func accepts 'tb' parameter
145
192
  """
146
193
  # Create interface view if function has interface type hint
147
194
  bb_to_pass = _create_interface_view_if_needed(bb, func)
148
195
 
149
- if _supports_timebase(func):
196
+ if supports_timebase:
150
197
  if inspect.iscoroutinefunction(func):
151
198
  return await func(bb=bb_to_pass, tb=tb)
152
199
  else:
@@ -202,6 +249,24 @@ def _name_of(obj: Any) -> str:
202
249
  return f"{obj.__class__.__name__}@{id(obj):x}"
203
250
 
204
251
 
252
+ def _fully_qualified_name(func: Callable[..., Any]) -> str:
253
+ """
254
+ Get the fully qualified name of a function.
255
+
256
+ Returns module.function_name if the function has a module,
257
+ otherwise returns just the function name.
258
+ """
259
+ name = func.__name__
260
+ module = getattr(func, "__module__", None)
261
+ if module:
262
+ # Handle nested functions by trying to get qualname
263
+ qualname = getattr(func, "__qualname__", None)
264
+ if qualname and qualname != name:
265
+ return f"{module}.{qualname}"
266
+ return f"{module}.{name}"
267
+ return name
268
+
269
+
205
270
  # ======================================================================================
206
271
  # Recursion Detection
207
272
  # ======================================================================================
@@ -315,11 +380,15 @@ class Action(Node[BB]):
315
380
  ) -> None:
316
381
  super().__init__(name=_name_of(func), exception_policy=exception_policy)
317
382
  self._func = func
383
+ # Cache whether function supports timebase parameter (checked during construction)
384
+ self._supports_timebase = _supports_timebase(func)
385
+ # Cache fully qualified name for tracing
386
+ self._fq_name = _fully_qualified_name(func)
318
387
 
319
388
  async def tick(self, bb: BB, tb: Timebase) -> Status:
320
389
  await self._ensure_entered(bb, tb)
321
390
  try:
322
- result = await _call_node_function(self._func, bb, tb)
391
+ result = await _call_node_function(self._func, bb, tb, self._supports_timebase)
323
392
  except asyncio.CancelledError:
324
393
  return await self._finish(bb, Status.CANCELLED, tb)
325
394
  except Exception as e:
@@ -330,13 +399,20 @@ class Action(Node[BB]):
330
399
  raise
331
400
  return await self._finish(bb, Status.ERROR, tb)
332
401
 
402
+ # Determine final status
333
403
  if isinstance(result, Status):
334
- return await self._finish(bb, result, tb)
335
- if isinstance(result, bool):
336
- return await self._finish(
337
- bb, Status.SUCCESS if result else Status.FAILURE, tb
338
- )
339
- return await self._finish(bb, Status.SUCCESS, tb)
404
+ final_status = result
405
+ elif isinstance(result, bool):
406
+ final_status = Status.SUCCESS if result else Status.FAILURE
407
+ else:
408
+ final_status = Status.SUCCESS
409
+
410
+ # Log trace if enabled
411
+ trace_logger = _trace_logger_ctx.get()
412
+ if trace_logger is not None:
413
+ trace_logger.debug(f"action: {self._fq_name} | {final_status.name}")
414
+
415
+ return await self._finish(bb, final_status, tb)
340
416
 
341
417
 
342
418
  class Condition(Action[BB]):
@@ -345,7 +421,7 @@ class Condition(Action[BB]):
345
421
  async def tick(self, bb: BB, tb: Timebase) -> Status:
346
422
  await self._ensure_entered(bb, tb)
347
423
  try:
348
- result = await _call_node_function(self._func, bb, tb)
424
+ result = await _call_node_function(self._func, bb, tb, self._supports_timebase)
349
425
  except asyncio.CancelledError:
350
426
  return await self._finish(bb, Status.CANCELLED, tb)
351
427
  except Exception as e:
@@ -356,11 +432,18 @@ class Condition(Action[BB]):
356
432
  raise
357
433
  return await self._finish(bb, Status.ERROR, tb)
358
434
 
435
+ # Determine final status
359
436
  if isinstance(result, Status):
360
- return await self._finish(bb, result, tb)
361
- return await self._finish(
362
- bb, Status.SUCCESS if bool(result) else Status.FAILURE, tb
363
- )
437
+ final_status = result
438
+ else:
439
+ final_status = Status.SUCCESS if bool(result) else Status.FAILURE
440
+
441
+ # Log trace if enabled
442
+ trace_logger = _trace_logger_ctx.get()
443
+ if trace_logger is not None:
444
+ trace_logger.debug(f"condition: {self._fq_name} | {final_status.name}")
445
+
446
+ return await self._finish(bb, final_status, tb)
364
447
 
365
448
 
366
449
  # ======================================================================================
@@ -799,6 +882,8 @@ class Match(Node[BB]):
799
882
  for _, child in self._cases:
800
883
  child.parent = self
801
884
  self._matched_idx: Optional[int] = None
885
+ # Cache whether key_fn supports timebase parameter
886
+ self._key_fn_supports_timebase = _supports_timebase(key_fn)
802
887
 
803
888
  def reset(self) -> None:
804
889
  super().reset()
@@ -818,7 +903,7 @@ class Match(Node[BB]):
818
903
 
819
904
  async def tick(self, bb: BB, tb: Timebase) -> Status:
820
905
  await self._ensure_entered(bb, tb)
821
-
906
+
822
907
  if self._matched_idx is not None:
823
908
  _, child = self._cases[self._matched_idx]
824
909
  st = await child.tick(bb, tb)
@@ -826,9 +911,9 @@ class Match(Node[BB]):
826
911
  return Status.RUNNING
827
912
  self._matched_idx = None
828
913
  return await self._finish(bb, st, tb)
829
-
830
- value = await _call_node_function(self._key_fn, bb, tb)
831
-
914
+
915
+ value = await _call_node_function(self._key_fn, bb, tb, self._key_fn_supports_timebase)
916
+
832
917
  for i, (matcher, child) in enumerate(self._cases):
833
918
  if self._matches(matcher, value):
834
919
  st = await child.tick(bb, tb)
@@ -836,7 +921,7 @@ class Match(Node[BB]):
836
921
  self._matched_idx = i
837
922
  return Status.RUNNING
838
923
  return await self._finish(bb, st, tb)
839
-
924
+
840
925
  return await self._finish(bb, Status.FAILURE, tb)
841
926
 
842
927
 
@@ -971,7 +1056,6 @@ class NodeSpec:
971
1056
 
972
1057
  def to_node(
973
1058
  self,
974
- owner: Optional[Any] = None,
975
1059
  exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
976
1060
  ) -> Node[Any]:
977
1061
  match self.kind:
@@ -980,12 +1064,11 @@ class NodeSpec:
980
1064
  case NodeSpecKind.CONDITION:
981
1065
  return Condition(self.payload, exception_policy=exception_policy)
982
1066
  case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
983
- owner_effective = getattr(self, "owner", None) or owner
984
1067
  factory = self.payload["factory"]
985
- expanded = _bt_expand_children(factory, owner_effective)
1068
+ expanded = _bt_expand_children(factory)
986
1069
  self.children = expanded
987
1070
  built = [
988
- ch.to_node(owner_effective, exception_policy) for ch in expanded
1071
+ ch.to_node(exception_policy) for ch in expanded
989
1072
  ]
990
1073
  match self.kind:
991
1074
  case NodeSpecKind.SEQUENCE:
@@ -1013,7 +1096,7 @@ class NodeSpec:
1013
1096
 
1014
1097
  case NodeSpecKind.DECORATOR:
1015
1098
  assert len(self.children) == 1, "Decorator must wrap exactly one child"
1016
- child_node = self.children[0].to_node(owner, exception_policy)
1099
+ child_node = self.children[0].to_node(exception_policy)
1017
1100
  builder = self.payload
1018
1101
  return builder(child_node)
1019
1102
 
@@ -1025,7 +1108,7 @@ class NodeSpec:
1025
1108
  key_fn = self.payload["key_fn"]
1026
1109
  case_specs: List[CaseSpec] = self.payload["cases"]
1027
1110
  cases = [
1028
- (cs.matcher, cs.child.to_node(owner, exception_policy))
1111
+ (cs.matcher, cs.child.to_node(exception_policy))
1029
1112
  for cs in case_specs
1030
1113
  ]
1031
1114
  return Match(
@@ -1039,8 +1122,8 @@ class NodeSpec:
1039
1122
  cond_spec = self.payload["condition"]
1040
1123
  child_spec = self.children[0]
1041
1124
  return DoWhile(
1042
- cond_spec.to_node(owner, exception_policy),
1043
- child_spec.to_node(owner, exception_policy),
1125
+ cond_spec.to_node(exception_policy),
1126
+ child_spec.to_node(exception_policy),
1044
1127
  name=self.name,
1045
1128
  exception_policy=exception_policy,
1046
1129
  )
@@ -1051,7 +1134,6 @@ class NodeSpec:
1051
1134
 
1052
1135
  def _bt_expand_children(
1053
1136
  factory: Callable[..., Generator[Any, None, None]],
1054
- owner: Optional[Any],
1055
1137
  expansion_stack: Optional[Set[str]] = None,
1056
1138
  ) -> List[NodeSpec]:
1057
1139
  """
@@ -1059,7 +1141,6 @@ def _bt_expand_children(
1059
1141
 
1060
1142
  Args:
1061
1143
  factory: The generator function that yields child specs
1062
- owner: The namespace object for resolving N references
1063
1144
  expansion_stack: Stack of factory names to detect recursion
1064
1145
  """
1065
1146
  if expansion_stack is None:
@@ -1080,10 +1161,7 @@ def _bt_expand_children(
1080
1161
  expansion_stack = expansion_stack.copy()
1081
1162
  expansion_stack.add(factory_name)
1082
1163
 
1083
- try:
1084
- gen = factory(owner)
1085
- except TypeError:
1086
- gen = factory()
1164
+ gen = factory()
1087
1165
 
1088
1166
  if not inspect.isgenerator(gen):
1089
1167
  raise TypeError(
@@ -1119,8 +1197,7 @@ def _bt_expand_children(
1119
1197
  NodeSpecKind.PARALLEL,
1120
1198
  ) and hasattr(spec, "_expansion_stack"):
1121
1199
  child_factory = spec.payload["factory"]
1122
- child_owner = getattr(spec, "owner", None) or owner
1123
- _bt_expand_children(child_factory, child_owner, spec._expansion_stack) # type: ignore
1200
+ _bt_expand_children(child_factory, spec._expansion_stack) # type: ignore
1124
1201
 
1125
1202
  return out
1126
1203
 
@@ -1309,39 +1386,78 @@ class _BT:
1309
1386
  self._tracking_stack[-1].append((fn.__name__, fn))
1310
1387
  return fn
1311
1388
 
1312
- def sequence(self, func_or_none=None, *, memory: bool = True):
1389
+ def sequence(
1390
+ self, *args: Union[F, NodeSpec, Callable[[Any], Any]], memory: bool = True
1391
+ ) -> Union[F, Callable[[F], F], NodeSpec]:
1313
1392
  """Decorator to mark a generator function as a sequence composite.
1314
1393
 
1315
- Can be used with or without parentheses:
1316
- @bt.sequence
1317
- def root():
1318
- ...
1394
+ Can be used in three ways:
1395
+ 1. Decorator without parentheses:
1396
+ @bt.sequence
1397
+ def root():
1398
+ ...
1319
1399
 
1320
- @bt.sequence(memory=False)
1321
- def root():
1322
- ...
1400
+ 2. Decorator with parameters:
1401
+ @bt.sequence(memory=False)
1402
+ def root():
1403
+ ...
1404
+
1405
+ 3. Direct call with child nodes:
1406
+ bt.sequence(action1, action2, action3)
1323
1407
  """
1324
- # Case 1: Used as @bt.sequence (no parens)
1325
- if callable(func_or_none):
1326
- # Apply decorator directly with defaults
1408
+ # Case 3: Direct call with children - bt.sequence(node1, node2, ...)
1409
+ # This is detected when we have multiple args, or a single arg that's not a generator function
1410
+ if len(args) == 0:
1411
+ # Case 2a: bt.sequence() with no arguments - return decorator
1412
+ return self._sequence_impl(memory=memory)
1413
+
1414
+ if len(args) > 1:
1415
+ # Multiple children - create sequence directly
1416
+ return self._sequence_from_children(args, memory)
1417
+
1418
+ # Single argument - check if it's a generator function (decorator case) or a node (direct call)
1419
+ single_arg = args[0]
1420
+
1421
+ # Check if it's a generator function (decorator form)
1422
+ if inspect.isfunction(single_arg) and inspect.isgeneratorfunction(single_arg):
1423
+ # Case 1: @bt.sequence def root(): ...
1327
1424
  spec = NodeSpec(
1328
1425
  kind=NodeSpecKind.SEQUENCE,
1329
- name=_name_of(func_or_none),
1330
- payload={"factory": func_or_none, "memory": memory},
1426
+ name=_name_of(single_arg),
1427
+ payload={"factory": single_arg, "memory": memory},
1331
1428
  )
1332
- func_or_none.node_spec = spec # type: ignore
1429
+ single_arg.node_spec = spec # type: ignore
1333
1430
  if self._tracking_stack:
1334
- self._tracking_stack[-1].append((func_or_none.__name__, func_or_none))
1335
- return func_or_none
1431
+ self._tracking_stack[-1].append((single_arg.__name__, single_arg))
1432
+ return single_arg
1433
+
1434
+ # Case 3b: Single child node - bt.sequence(action1)
1435
+ return self._sequence_from_children(args, memory)
1336
1436
 
1337
- # Case 2: Used as @bt.sequence() or @bt.sequence(memory=False)
1338
- return self._sequence_impl(memory=memory)
1437
+ def _sequence_from_children(self, children: Tuple[Any, ...], memory: bool) -> NodeSpec:
1438
+ """Create a sequence NodeSpec from child nodes."""
1439
+ # Create a uniquely named factory to avoid false recursion detection
1440
+ child_names = ', '.join(_name_of(c) for c in children)
1441
+
1442
+ def _sequence_factory_direct() -> Generator[Any, None, None]:
1443
+ for child in children:
1444
+ yield child
1339
1445
 
1340
- def _sequence_impl(self, memory: bool):
1446
+ # Set a unique name for the factory function
1447
+ _sequence_factory_direct.__name__ = f"_sequence_factory_direct_{id(children)}"
1448
+ _sequence_factory_direct.__qualname__ = f"_sequence_factory_direct_{id(children)}"
1449
+
1450
+ name = f"Sequence({child_names})" if children else "Sequence"
1451
+
1452
+ return NodeSpec(
1453
+ kind=NodeSpecKind.SEQUENCE,
1454
+ name=name,
1455
+ payload={"factory": _sequence_factory_direct, "memory": memory},
1456
+ )
1457
+
1458
+ def _sequence_impl(self, memory: bool) -> Callable[[F], F]:
1341
1459
  """Implementation of sequence decorator."""
1342
- def deco(
1343
- factory: Callable[..., Generator[Any, None, None]],
1344
- ) -> Callable[..., Generator[Any, None, None]]:
1460
+ def deco(factory: F) -> F:
1345
1461
  spec = NodeSpec(
1346
1462
  kind=NodeSpecKind.SEQUENCE,
1347
1463
  name=_name_of(factory),
@@ -1354,47 +1470,82 @@ class _BT:
1354
1470
 
1355
1471
  return deco
1356
1472
 
1357
- def selector(self, func_or_none=None, *, memory: bool = True, reactive: bool = False):
1473
+ def selector(
1474
+ self, *args: Union[F, NodeSpec, Callable[[Any], Any]], memory: bool = True
1475
+ ) -> Union[F, Callable[[F], F], NodeSpec]:
1358
1476
  """Decorator to mark a generator function as a selector composite.
1359
1477
 
1360
- Can be used with or without parentheses:
1361
- @bt.selector
1362
- def root():
1363
- ...
1478
+ Can be used in three ways:
1479
+ 1. Decorator without parentheses:
1480
+ @bt.selector
1481
+ def root():
1482
+ ...
1364
1483
 
1365
- @bt.selector(memory=False)
1366
- def root():
1367
- ...
1484
+ 2. Decorator with parameters:
1485
+ @bt.selector(memory=False)
1486
+ def root():
1487
+ ...
1368
1488
 
1369
- @bt.selector(reactive=True)
1370
- def root():
1371
- ...
1489
+ 3. Direct call with child nodes:
1490
+ bt.selector(option1, option2, option3)
1372
1491
  """
1373
- # Case 1: Used as @bt.selector (no parens)
1374
- if callable(func_or_none):
1375
- # Apply decorator directly with defaults
1492
+ # Case 3: Direct call with children - bt.selector(node1, node2, ...)
1493
+ # This is detected when we have multiple args, or a single arg that's not a generator function
1494
+ if len(args) == 0:
1495
+ # Case 2a: bt.selector() with no arguments - return decorator
1496
+ return self._selector_impl(memory=memory)
1497
+
1498
+ if len(args) > 1:
1499
+ # Multiple children - create selector directly
1500
+ return self._selector_from_children(args, memory)
1501
+
1502
+ # Single argument - check if it's a generator function (decorator case) or a node (direct call)
1503
+ single_arg = args[0]
1504
+
1505
+ # Check if it's a generator function (decorator form)
1506
+ if inspect.isfunction(single_arg) and inspect.isgeneratorfunction(single_arg):
1507
+ # Case 1: @bt.selector def root(): ...
1376
1508
  spec = NodeSpec(
1377
1509
  kind=NodeSpecKind.SELECTOR,
1378
- name=_name_of(func_or_none),
1379
- payload={"factory": func_or_none, "memory": memory, "reactive": reactive},
1510
+ name=_name_of(single_arg),
1511
+ payload={"factory": single_arg, "memory": memory},
1380
1512
  )
1381
- func_or_none.node_spec = spec # type: ignore
1513
+ single_arg.node_spec = spec # type: ignore
1382
1514
  if self._tracking_stack:
1383
- self._tracking_stack[-1].append((func_or_none.__name__, func_or_none))
1384
- return func_or_none
1515
+ self._tracking_stack[-1].append((single_arg.__name__, single_arg))
1516
+ return single_arg
1517
+
1518
+ # Case 3b: Single child node - bt.selector(action1)
1519
+ return self._selector_from_children(args, memory)
1520
+
1521
+ def _selector_from_children(self, children: Tuple[Any, ...], memory: bool) -> NodeSpec:
1522
+ """Create a selector NodeSpec from child nodes."""
1523
+ # Create a uniquely named factory to avoid false recursion detection
1524
+ child_names = ', '.join(_name_of(c) for c in children)
1525
+
1526
+ def _selector_factory_direct() -> Generator[Any, None, None]:
1527
+ for child in children:
1528
+ yield child
1385
1529
 
1386
- # Case 2: Used as @bt.selector() or @bt.selector(memory=False)
1387
- return self._selector_impl(memory=memory, reactive=reactive)
1530
+ # Set a unique name for the factory function
1531
+ _selector_factory_direct.__name__ = f"_selector_factory_direct_{id(children)}"
1532
+ _selector_factory_direct.__qualname__ = f"_selector_factory_direct_{id(children)}"
1388
1533
 
1389
- def _selector_impl(self, memory: bool, reactive: bool):
1534
+ name = f"Selector({child_names})" if children else "Selector"
1535
+
1536
+ return NodeSpec(
1537
+ kind=NodeSpecKind.SELECTOR,
1538
+ name=name,
1539
+ payload={"factory": _selector_factory_direct, "memory": memory},
1540
+ )
1541
+
1542
+ def _selector_impl(self, memory: bool) -> Callable[[F], F]:
1390
1543
  """Implementation of selector decorator."""
1391
- def deco(
1392
- factory: Callable[..., Generator[Any, None, None]],
1393
- ) -> Callable[..., Generator[Any, None, None]]:
1544
+ def deco(factory: F) -> F:
1394
1545
  spec = NodeSpec(
1395
1546
  kind=NodeSpecKind.SELECTOR,
1396
1547
  name=_name_of(factory),
1397
- payload={"factory": factory, "memory": memory, "reactive": reactive},
1548
+ payload={"factory": factory, "memory": memory},
1398
1549
  )
1399
1550
  factory.node_spec = spec # type: ignore
1400
1551
  if self._tracking_stack:
@@ -1405,12 +1556,10 @@ class _BT:
1405
1556
 
1406
1557
  def parallel(
1407
1558
  self, *, success_threshold: int, failure_threshold: Optional[int] = None
1408
- ):
1559
+ ) -> Callable[[F], F]:
1409
1560
  """Decorator to mark a generator function as a parallel composite."""
1410
1561
 
1411
- def deco(
1412
- factory: Callable[..., Generator[Any, None, None]],
1413
- ) -> Callable[..., Generator[Any, None, None]]:
1562
+ def deco(factory: F) -> F:
1414
1563
  spec = NodeSpec(
1415
1564
  kind=NodeSpecKind.PARALLEL,
1416
1565
  name=_name_of(factory),
@@ -1569,8 +1718,8 @@ class _BT:
1569
1718
  ...
1570
1719
 
1571
1720
  @bt.sequence()
1572
- def root(N):
1573
- yield N.my_action
1721
+ def root():
1722
+ yield my_action
1574
1723
  """
1575
1724
  created_nodes = []
1576
1725
  self._tracking_stack.append(created_nodes)
@@ -1584,7 +1733,7 @@ class _BT:
1584
1733
  name: node for name, node in created_nodes if hasattr(node, "node_spec")
1585
1734
  }
1586
1735
  namespace = SimpleNamespace(**nodes)
1587
-
1736
+
1588
1737
  # Store the tree's name for use in subtree references
1589
1738
  namespace._tree_name = fn.__name__
1590
1739
 
@@ -1600,10 +1749,13 @@ class _BT:
1600
1749
 
1601
1750
  return namespace
1602
1751
 
1603
- def root(
1604
- self, fn: Callable[..., Generator[Any, None, None]]
1605
- ) -> Callable[..., Generator[Any, None, None]]:
1752
+ def root(self, fn: F) -> F:
1606
1753
  """Mark a composite as the root of the tree."""
1754
+ if not hasattr(fn, "node_spec"):
1755
+ raise TypeError(
1756
+ f"@bt.root can only be used on composites (sequences, selectors, etc.), "
1757
+ f"got {fn!r}"
1758
+ )
1607
1759
  fn.node_spec.is_root = True # type: ignore
1608
1760
  return fn
1609
1761
 
@@ -1667,9 +1819,8 @@ def _generate_mermaid(tree: SimpleNamespace) -> str:
1667
1819
  def ensure_children(spec: NodeSpec) -> List[NodeSpec]:
1668
1820
  match spec.kind:
1669
1821
  case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
1670
- owner = getattr(spec, "owner", None) or tree
1671
1822
  factory = spec.payload["factory"]
1672
- spec.children = _bt_expand_children(factory, owner)
1823
+ spec.children = _bt_expand_children(factory)
1673
1824
  case NodeSpecKind.SUBTREE:
1674
1825
  subtree_root = spec.payload["root"]
1675
1826
  spec.children = [subtree_root]
@@ -1735,6 +1886,7 @@ class Runner(Generic[BB]):
1735
1886
  bb: Blackboard containing shared state
1736
1887
  tb: Optional timebase for time management (defaults to MonotonicClock)
1737
1888
  exception_policy: How to handle exceptions during tree execution
1889
+ trace: Optional logger instance for tracing action/condition execution
1738
1890
 
1739
1891
  Methods:
1740
1892
  tick(): Execute one tick of the behavior tree
@@ -1751,6 +1903,14 @@ class Runner(Generic[BB]):
1751
1903
 
1752
1904
  runner = Runner(MyTree, bb=blackboard)
1753
1905
  result = await runner.tick_until_complete()
1906
+
1907
+ Tracing:
1908
+ import logging
1909
+ trace_logger = logging.getLogger("bt.trace")
1910
+ runner = Runner(MyTree, bb=blackboard, trace=trace_logger)
1911
+ result = await runner.tick_until_complete()
1912
+ # Logs: "action: module.do_work | SUCCESS"
1913
+ # "condition: module.check_condition | SUCCESS"
1754
1914
  """
1755
1915
  def __init__(
1756
1916
  self,
@@ -1758,23 +1918,31 @@ class Runner(Generic[BB]):
1758
1918
  bb: BB,
1759
1919
  tb: Optional[Timebase] = None,
1760
1920
  exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
1921
+ trace: Optional[logging.Logger] = None,
1761
1922
  ) -> None:
1762
1923
  self.tree = tree
1763
1924
  self.bb: BB = bb
1764
1925
  self.tb = tb or MonotonicClock()
1765
1926
  self.exception_policy = exception_policy
1927
+ self.trace = trace
1766
1928
 
1767
1929
  if not hasattr(tree, "root"):
1768
1930
  raise ValueError("Tree namespace must have a 'root' attribute")
1769
1931
 
1770
1932
  self.root: Node[BB] = tree.root.to_node(
1771
- owner=tree, exception_policy=exception_policy
1933
+ exception_policy=exception_policy
1772
1934
  )
1773
1935
 
1774
1936
  async def tick(self) -> Status:
1775
- result = await self.root.tick(self.bb, self.tb)
1776
- self.tb.advance()
1777
- return result
1937
+ # Set trace logger in context for this tick
1938
+ token = _trace_logger_ctx.set(self.trace)
1939
+ try:
1940
+ result = await self.root.tick(self.bb, self.tb)
1941
+ self.tb.advance()
1942
+ return result
1943
+ finally:
1944
+ # Clear the context variable
1945
+ _trace_logger_ctx.reset(token)
1778
1946
 
1779
1947
  async def tick_until_complete(self, timeout: Optional[float] = None) -> Status:
1780
1948
  start = self.tb.now()