mycorrhizal 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. mycorrhizal/__init__.py +3 -0
  2. mycorrhizal/common/__init__.py +68 -0
  3. mycorrhizal/common/interface_builder.py +203 -0
  4. mycorrhizal/common/interfaces.py +412 -0
  5. mycorrhizal/common/timebase.py +99 -0
  6. mycorrhizal/common/wrappers.py +532 -0
  7. mycorrhizal/enoki/__init__.py +0 -0
  8. mycorrhizal/enoki/core.py +1545 -0
  9. mycorrhizal/enoki/testing_utils.py +529 -0
  10. mycorrhizal/enoki/util.py +220 -0
  11. mycorrhizal/hypha/__init__.py +0 -0
  12. mycorrhizal/hypha/core/__init__.py +107 -0
  13. mycorrhizal/hypha/core/builder.py +404 -0
  14. mycorrhizal/hypha/core/runtime.py +890 -0
  15. mycorrhizal/hypha/core/specs.py +234 -0
  16. mycorrhizal/hypha/util.py +38 -0
  17. mycorrhizal/rhizomorph/README.md +220 -0
  18. mycorrhizal/rhizomorph/__init__.py +0 -0
  19. mycorrhizal/rhizomorph/core.py +1729 -0
  20. mycorrhizal/rhizomorph/util.py +45 -0
  21. mycorrhizal/spores/__init__.py +124 -0
  22. mycorrhizal/spores/cache.py +208 -0
  23. mycorrhizal/spores/core.py +419 -0
  24. mycorrhizal/spores/dsl/__init__.py +48 -0
  25. mycorrhizal/spores/dsl/enoki.py +514 -0
  26. mycorrhizal/spores/dsl/hypha.py +399 -0
  27. mycorrhizal/spores/dsl/rhizomorph.py +351 -0
  28. mycorrhizal/spores/encoder/__init__.py +11 -0
  29. mycorrhizal/spores/encoder/base.py +42 -0
  30. mycorrhizal/spores/encoder/json.py +159 -0
  31. mycorrhizal/spores/extraction.py +484 -0
  32. mycorrhizal/spores/models.py +288 -0
  33. mycorrhizal/spores/transport/__init__.py +10 -0
  34. mycorrhizal/spores/transport/base.py +46 -0
  35. mycorrhizal-0.1.0.dist-info/METADATA +198 -0
  36. mycorrhizal-0.1.0.dist-info/RECORD +37 -0
  37. mycorrhizal-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,1729 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Rhizomorph - Asyncio Behavior Tree Framework
4
+
5
+ A decorator-based DSL for defining and executing behavior trees with support for
6
+ asyncio, multi-file composition, and type-safe blackboard interfaces.
7
+
8
+ Usage:
9
+ from mycorrhizal.rhizomorph.core import bt, Runner, Status
10
+
11
+ @bt.tree
12
+ def MyBehaviorTree():
13
+ @bt.action
14
+ async def do_work(bb) -> Status:
15
+ # Do some work
16
+ return Status.SUCCESS
17
+
18
+ @bt.condition
19
+ def should_work(bb) -> bool:
20
+ return bb.work_available
21
+
22
+ @bt.root
23
+ @bt.sequence
24
+ def root():
25
+ yield should_work
26
+ yield do_work
27
+
28
+ # Run the tree
29
+ runner = Runner(MyBehaviorTree, bb=blackboard)
30
+ result = await runner.tick_until_complete()
31
+
32
+ Key Classes:
33
+ Status - Result status for behavior tree nodes (SUCCESS, FAILURE, RUNNING, etc.)
34
+ Node - Base class for all behavior tree nodes
35
+ Action - Leaf node that executes a function
36
+ Condition - Leaf node that evaluates a predicate
37
+ Sequence - Composite that runs children in order
38
+ Selector - Composite that runs children until one succeeds
39
+ Parallel - Composite that runs children concurrently
40
+
41
+ Multi-file Composition:
42
+ Use bt.subtree() to reference trees defined in other modules:
43
+ from other_module import OtherTree
44
+
45
+ @bt.tree
46
+ def MainTree():
47
+ @bt.root
48
+ @bt.sequence
49
+ def root():
50
+ yield bt.subtree(OtherTree, owner=MainTree)
51
+ """
52
+ from __future__ import annotations
53
+
54
+ import asyncio
55
+ import inspect
56
+ import logging
57
+ import traceback
58
+ from dataclasses import dataclass, field
59
+ from enum import Enum
60
+ from typing import (
61
+ Any,
62
+ Callable,
63
+ Dict,
64
+ Generator,
65
+ Generic,
66
+ List,
67
+ Optional,
68
+ Tuple,
69
+ TypeVar,
70
+ Union,
71
+ Set,
72
+ Protocol,
73
+ )
74
+ from typing import Sequence as SequenceT
75
+ from types import SimpleNamespace
76
+
77
+ from mycorrhizal.common.timebase import *
78
+
79
# Module-level logger; Action/Condition report user-code exceptions through it.
logger: logging.Logger = logging.getLogger(__name__)

# Type variable for the blackboard object that every node receives.
BB = TypeVar("BB")
82
+
83
+
84
+ # ======================================================================================
85
+ # Interface Integration Helper
86
+ # ======================================================================================
87
+
88
+
89
def _create_interface_view_if_needed(bb: Any, func: Callable) -> Any:
    """
    Return a constrained blackboard view when *func* asks for one.

    If the first parameter of *func* is named ``bb`` and its resolved type hint
    carries interface metadata (a ``_readonly_fields`` attribute, as produced by
    interfaces created with ``@blackboard_interface``), the blackboard is wrapped
    with ``create_view_from_protocol`` so the function only sees what the
    interface exposes.

    Args:
        bb: The blackboard instance.
        func: The node function whose signature is inspected.

    Returns:
        The constrained view, or *bb* itself when no interface applies. If
        introspection or view creation raises, *bb* is returned as a best-effort
        fallback. The failure is logged at DEBUG level with its traceback rather
        than discarded, because the fallback silently disables access control.
    """
    from typing import get_type_hints

    try:
        params = list(inspect.signature(func).parameters.values())
        # Only functions whose first parameter is literally named 'bb' qualify.
        if not params or params[0].name != 'bb':
            return bb
        bb_type = get_type_hints(func).get('bb')
        if bb_type and hasattr(bb_type, '_readonly_fields'):
            from mycorrhizal.common.wrappers import create_view_from_protocol
            return create_view_from_protocol(bb, bb_type)
    except Exception:
        # Best-effort fallback to the raw blackboard, but leave a trace so an
        # unresolvable hint (e.g. a forward reference) can be diagnosed.
        logger.debug(
            "Interface view unavailable for %r; passing the raw blackboard",
            getattr(func, "__qualname__", func),
            exc_info=True,
        )
    return bb
123
+
124
+
125
+ # ======================================================================================
126
+ # Function signature protocols
127
+ # ======================================================================================
128
+
129
+
130
def _supports_timebase(func: Callable) -> bool:
    """Return True when *func* declares a parameter literally named ``tb``.

    Callables without an introspectable signature (``inspect.signature``
    raising ValueError/TypeError) are reported as not accepting ``tb``.
    """
    try:
        parameters = inspect.signature(func).parameters
    except (ValueError, TypeError):
        return False
    return "tb" in parameters
137
+
138
+
139
async def _call_node_function(func: Callable, bb: Any, tb: Timebase) -> Any:
    """
    Call a node function with the keyword arguments its signature accepts.

    The blackboard is always passed as ``bb=``. The timebase is passed as
    ``tb=`` only when the function declares a ``tb`` parameter. If the function
    has an interface type hint on ``bb``, a constrained view is passed instead
    of the raw blackboard (see ``_create_interface_view_if_needed``).

    Any awaitable result is awaited. This covers ``async def`` functions and
    also sync callables that return a coroutine, such as a lambda wrapping an
    async call or a decorator without ``functools.wraps``. Checking
    ``iscoroutinefunction`` alone would leave those un-awaited, and their
    coroutine object would be treated as a plain return value.

    Returns:
        The function's (awaited) return value.
    """
    kwargs: Dict[str, Any] = {"bb": _create_interface_view_if_needed(bb, func)}
    if _supports_timebase(func):
        kwargs["tb"] = tb

    result = func(**kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result
159
+
160
+
161
+ # ======================================================================================
162
+ # Status / helpers
163
+ # ======================================================================================
164
+
165
+
166
class Status(Enum):
    """Outcome a behavior-tree node reports after being ticked.

    Members:
        SUCCESS: the node finished and achieved its goal.
        FAILURE: the node finished without achieving its goal.
        RUNNING: the node has not finished; tick it again later.
        CANCELLED: the node's work was cancelled (Action/Condition map
            ``asyncio.CancelledError`` to this).
        ERROR: the node's user code raised (Action/Condition return this
            under ``ExceptionPolicy.LOG_AND_CONTINUE``).

    How composites react:
        * Sequence returns on the first FAILURE/ERROR/CANCELLED child.
        * Selector returns on the first SUCCESS child.
        * Parallel counts child outcomes against its success/failure
          thresholds on every tick.
    """

    SUCCESS = 1
    FAILURE = 2
    RUNNING = 3
    CANCELLED = 4
    ERROR = 5
188
+
189
+
190
class ExceptionPolicy(Enum):
    """What an action/condition node does when its function raises.

    LOG_AND_CONTINUE: log the traceback and report ``Status.ERROR``.
    PROPAGATE: log the traceback and re-raise to the caller.
    """

    LOG_AND_CONTINUE = 1
    PROPAGATE = 2
195
+
196
+
197
def _name_of(obj: Any) -> str:
    """Best-effort human-readable label for *obj*, used for node names and logs.

    Preference order: ``obj.__name__`` (functions, classes) as-is, then
    ``str(obj.name)``, then ``ClassName@<hex id>``.
    """
    missing = object()

    dunder_name = getattr(obj, "__name__", missing)
    if dunder_name is not missing:
        return dunder_name

    plain_name = getattr(obj, "name", missing)
    if plain_name is not missing:
        return str(plain_name)

    return f"{obj.__class__.__name__}@{id(obj):x}"
203
+
204
+
205
+ # ======================================================================================
206
+ # Recursion Detection
207
+ # ======================================================================================
208
+
209
+
210
class RecursionError(Exception):
    """Raised when a recursive behavior-tree structure is detected.

    NOTE: this name shadows the builtin ``RecursionError`` inside this module
    and for anyone importing it from here. It is a plain ``Exception``
    subclass, not a subclass of the builtin, so ``except
    builtins.RecursionError`` does not catch it.
    """
214
+
215
+
216
+ # ======================================================================================
217
+ # Node model
218
+ # ======================================================================================
219
+
220
+
221
class Node(Generic[BB]):
    """Base class for every behavior-tree node.

    A node is driven by awaiting :meth:`tick`, which yields a :class:`Status`.
    Subclasses implement ``tick`` and usually bracket their work with two
    helpers:

    * :meth:`_ensure_entered` fires ``on_enter`` once per activation.
    * :meth:`_finish` records the status. For any status other than RUNNING,
      it also fires ``on_exit`` and ends the activation, so the next tick
      re-enters.

    Args:
        name: Label for debugging/logging. When omitted, ``_name_of(self)`` is
            used, which for a fresh instance typically yields
            ``ClassName@<hex id>``.
        exception_policy: How exceptions raised by user code are handled.

    Attributes:
        name: Node label.
        parent: Enclosing node, assigned by composites/decorators.
        exception_policy: Exception handling policy.
        _entered: Whether ``on_enter`` has fired for the current activation.
        _last_status: Last status passed to ``_finish`` (None initially).
    """

    def __init__(
        self,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        # Computed first, before any other attribute exists on the instance.
        self.name: str = name or _name_of(self)
        self.parent: Optional[Node[BB]] = None
        self.exception_policy = exception_policy
        self._entered: bool = False
        self._last_status: Optional[Status] = None

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        """Advance this node by one step; subclasses must override."""
        raise NotImplementedError

    async def on_enter(self, bb: BB, tb: Timebase) -> None:
        """Hook run at the start of an activation (default: no-op)."""

    async def on_exit(self, bb: BB, status: Status, tb: Timebase) -> None:
        """Hook run when an activation ends with a non-RUNNING status (default: no-op)."""

    def reset(self) -> None:
        """Forget activation state so the node starts fresh on its next tick."""
        self._entered, self._last_status = False, None

    async def _ensure_entered(self, bb: BB, tb: Timebase) -> None:
        # The flag is set only after on_enter returns, so a raising hook is
        # retried on the next tick.
        if self._entered:
            return
        await self.on_enter(bb, tb)
        self._entered = True

    async def _finish(self, bb: BB, status: Status, tb: Timebase) -> Status:
        self._last_status = status
        if status is Status.RUNNING:
            return status
        try:
            await self.on_exit(bb, status, tb)
        finally:
            # The activation ends even if on_exit raises.
            self._entered = False
        return status
283
+
284
+
285
+ # ======================================================================================
286
+ # Leaves
287
+ # ======================================================================================
288
+
289
+
290
class Action(Node[BB]):
    """Leaf node that runs a user function and maps its result to a Status.

    The function may be sync or async. It is called as ``func(bb=...)``, or as
    ``func(bb=..., tb=...)`` when it declares a ``tb`` parameter.

    Result mapping:
        * ``Status``: returned unchanged.
        * ``bool``: ``True`` becomes SUCCESS, ``False`` becomes FAILURE.
        * Anything else (including ``None``): SUCCESS.

    Exceptions raised by the function:
        * ``asyncio.CancelledError``: reported as CANCELLED, not re-raised.
        * Any other ``Exception``: logged with its traceback. Under
          ``ExceptionPolicy.PROPAGATE`` it is re-raised; otherwise ERROR is
          returned.

    Args:
        func: Function to execute.
        exception_policy: How exceptions from ``func`` are handled.
        name: Optional display name (keyword-only); defaults to ``func``'s name.
    """

    def __init__(
        self,
        func: Callable[..., Any],
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
        *,
        name: Optional[str] = None,
    ) -> None:
        super().__init__(
            name=name or _name_of(func), exception_policy=exception_policy
        )
        self._func = func

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        try:
            result = await _call_node_function(self._func, bb, tb)
        except asyncio.CancelledError:
            # NOTE(review): cancellation is converted into a status rather than
            # re-raised, so an enclosing task's cancel() is absorbed here.
            return await self._finish(bb, Status.CANCELLED, tb)
        except Exception:
            logger.exception("Exception in Action '%s'", self.name)
            if self.exception_policy == ExceptionPolicy.PROPAGATE:
                raise
            return await self._finish(bb, Status.ERROR, tb)

        if isinstance(result, Status):
            status = result
        elif isinstance(result, bool):
            status = Status.SUCCESS if result else Status.FAILURE
        else:
            status = Status.SUCCESS
        return await self._finish(bb, status, tb)
340
+
341
+
342
class Condition(Action[BB]):
    """Boolean leaf: truthy result gives SUCCESS, falsy gives FAILURE.

    Unlike Action, a non-Status result is coerced with ``bool()``, so ``None``
    maps to FAILURE. A returned ``Status`` is passed through, but this is
    discouraged.

    ``asyncio.CancelledError`` is reported as CANCELLED. Other exceptions are
    logged with their traceback and either re-raised
    (``ExceptionPolicy.PROPAGATE``) or reported as ERROR.
    """

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        try:
            result = await _call_node_function(self._func, bb, tb)
        except asyncio.CancelledError:
            return await self._finish(bb, Status.CANCELLED, tb)
        except Exception:
            logger.exception("Exception in Condition '%s'", self.name)
            if self.exception_policy == ExceptionPolicy.PROPAGATE:
                raise
            return await self._finish(bb, Status.ERROR, tb)

        if isinstance(result, Status):
            return await self._finish(bb, result, tb)
        return await self._finish(
            bb, Status.SUCCESS if result else Status.FAILURE, tb
        )
364
+
365
+
366
+ # ======================================================================================
367
+ # Composites
368
+ # ======================================================================================
369
+
370
+
371
class Sequence(Node[BB]):
    """Sequence (AND): tick children in order until one does not succeed.

    * A FAILURE, ERROR or CANCELLED child ends the sequence with that status.
    * A RUNNING child makes the sequence report RUNNING.
    * If every child succeeds, the sequence succeeds (an empty one succeeds).

    Args:
        children: Child nodes in execution order.
        memory: When True, the next tick resumes at the child that was RUNNING.
            When False (reactive), every tick restarts from the first child.
            In that case, if an earlier child ends or pauses the sequence, a
            later child that was RUNNING is reset so it does not resume from
            stale state afterwards.
        name: Display name (defaults to "Sequence").
        exception_policy: Stored on the node.
    """

    def __init__(
        self,
        children: SequenceT[Node[BB]],
        *,
        memory: bool = True,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(name or "Sequence", exception_policy=exception_policy)
        self.children = list(children)
        for ch in self.children:
            ch.parent = self
        self.memory = memory
        self._idx = 0
        # Index of the child that returned RUNNING on the previous tick, if any.
        self._running_idx: Optional[int] = None

    def reset(self) -> None:
        super().reset()
        self._idx = 0
        self._running_idx = None
        for ch in self.children:
            ch.reset()

    def _drop_preempted(self, idx: int) -> None:
        """Reset a previously RUNNING child that was skipped this tick.

        This can only happen in reactive mode, when a child before the
        RUNNING one (index < previous running index) ends or pauses the
        sequence.
        """
        prev = self._running_idx
        self._running_idx = None
        if prev is not None and prev > idx:
            self.children[prev].reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        start = self._idx if self.memory else 0
        for i in range(start, len(self.children)):
            st = await self.children[i].tick(bb, tb)
            if st is Status.RUNNING:
                self._drop_preempted(i)
                self._running_idx = i
                if self.memory:
                    self._idx = i
                return Status.RUNNING
            if st in (Status.FAILURE, Status.ERROR, Status.CANCELLED):
                self._drop_preempted(i)
                self._idx = 0
                return await self._finish(bb, st, tb)
        # Every child (including any previously RUNNING one) was ticked to completion.
        self._running_idx = None
        self._idx = 0
        return await self._finish(bb, Status.SUCCESS, tb)
409
+
410
+
411
class Selector(Node[BB]):
    """Selector (fallback): tick children in order until one succeeds.

    * The first SUCCESS child makes the selector succeed.
    * A RUNNING child makes the selector report RUNNING.
    * FAILURE, ERROR and CANCELLED children move on to the next child. If none
      succeeds, the selector reports FAILURE (an empty selector fails).

    Args:
        children: Child nodes in priority order.
        memory: When True, the next tick resumes at the child that was RUNNING.
            When False (reactive), every tick restarts from the first child.
            In that case, if an earlier child succeeds or pauses the selector,
            a later child that was RUNNING is reset so it does not resume from
            stale state afterwards.
        name: Display name (defaults to "Selector").
        exception_policy: Stored on the node.
    """

    def __init__(
        self,
        children: SequenceT[Node[BB]],
        *,
        memory: bool = True,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(name or "Selector", exception_policy=exception_policy)
        self.children = list(children)
        for ch in self.children:
            ch.parent = self
        self.memory = memory
        self._idx = 0
        # Index of the child that returned RUNNING on the previous tick, if any.
        self._running_idx: Optional[int] = None

    def reset(self) -> None:
        super().reset()
        self._idx = 0
        self._running_idx = None
        for ch in self.children:
            ch.reset()

    def _drop_preempted(self, idx: int) -> None:
        """Reset a previously RUNNING child that was skipped this tick.

        This can only happen in reactive mode, when a child before the
        RUNNING one (index < previous running index) succeeds or pauses the
        selector.
        """
        prev = self._running_idx
        self._running_idx = None
        if prev is not None and prev > idx:
            self.children[prev].reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        start = self._idx if self.memory else 0
        for i in range(start, len(self.children)):
            st = await self.children[i].tick(bb, tb)
            if st is Status.SUCCESS:
                self._drop_preempted(i)
                self._idx = 0
                return await self._finish(bb, Status.SUCCESS, tb)
            if st is Status.RUNNING:
                self._drop_preempted(i)
                self._running_idx = i
                if self.memory:
                    self._idx = i
                return Status.RUNNING
        # Every child (including any previously RUNNING one) was ticked to completion.
        self._running_idx = None
        self._idx = 0
        return await self._finish(bb, Status.FAILURE, tb)
449
+
450
+
451
class Parallel(Node[BB]):
    """
    Parallel: tick every child concurrently on each tick.

    Each tick, all children are ticked together via ``asyncio.gather`` and
    their statuses are counted:

    * at least ``k`` SUCCESS: SUCCESS
    * otherwise, at least ``f`` FAILURE/ERROR/CANCELLED: FAILURE
      (``f`` defaults to ``n - k + 1``, the point where success is unreachable)
    * otherwise: RUNNING

    The node keeps no per-child memory, so children that already finished are
    ticked again on later ticks. When the parallel finishes, children that
    returned RUNNING on that tick are reset, so they do not resume stale state
    on the next activation. If a child tick raises (e.g. under
    ``ExceptionPolicy.PROPAGATE``), sibling tasks still in flight are cancelled
    before the exception propagates.

    Args:
        children: Child nodes to run concurrently.
        success_threshold: ``k``; clamped to at most ``n`` and at least 1.
        failure_threshold: ``f``; defaults to ``n - k + 1``.
        name: Display name (defaults to "Parallel").
        exception_policy: Stored on the node.
    """

    def __init__(
        self,
        children: SequenceT[Node[BB]],
        *,
        success_threshold: int,
        failure_threshold: Optional[int] = None,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(name or "Parallel", exception_policy=exception_policy)
        self.children = list(children)
        for ch in self.children:
            ch.parent = self
        self.n = len(self.children)
        self.k = max(1, min(success_threshold, self.n))
        self.f = (
            failure_threshold
            if failure_threshold is not None
            else (self.n - self.k + 1)
        )

    def reset(self) -> None:
        super().reset()
        for ch in self.children:
            ch.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        tasks = [asyncio.create_task(ch.tick(bb, tb)) for ch in self.children]
        try:
            results = await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel the remaining awaitables when one of
            # them raises; cancel them so no orphaned child ticks keep running.
            for task in tasks:
                if not task.done():
                    task.cancel()
            raise

        succ = sum(1 for s in results if s is Status.SUCCESS)
        fail = sum(
            1 for s in results if s in (Status.FAILURE, Status.ERROR, Status.CANCELLED)
        )
        if succ >= self.k:
            outcome = Status.SUCCESS
        elif fail >= self.f:
            outcome = Status.FAILURE
        else:
            return Status.RUNNING

        # The outcome is decided; children still mid-run are abandoned.
        for child, st in zip(self.children, results):
            if st is Status.RUNNING:
                child.reset()
        return await self._finish(bb, outcome, tb)
501
+
502
+
503
+ # ======================================================================================
504
+ # Decorators (wrappers)
505
+ # ======================================================================================
506
+
507
+
508
class Inverter(Node[BB]):
    """Decorator that swaps its child's SUCCESS and FAILURE.

    RUNNING is passed through as RUNNING. ERROR and CANCELLED are passed
    through unchanged. Like every terminal status, they complete this node's
    activation: ``on_exit`` fires, and the node is re-entered on its next tick.
    """

    def __init__(
        self,
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Inverter({_name_of(child)})", exception_policy=exception_policy
        )
        self.child = child
        self.child.parent = self

    def reset(self) -> None:
        super().reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        st = await self.child.tick(bb, tb)
        if st is Status.RUNNING:
            return Status.RUNNING
        if st is Status.SUCCESS:
            st = Status.FAILURE
        elif st is Status.FAILURE:
            st = Status.SUCCESS
        # ERROR/CANCELLED also go through _finish so on_exit runs and the
        # activation flag is cleared (previously they were returned raw).
        return await self._finish(bb, st, tb)
533
+
534
+
535
class Retry(Node[BB]):
    """Decorator that re-runs its child after an unsuccessful outcome.

    Every non-SUCCESS completion of the child uses up one attempt. If the
    status is in ``retry_on`` and fewer than ``max_attempts`` attempts have
    been used, the child is reset and RUNNING is reported, so the retry happens
    on the next tick. Otherwise the node finishes with FAILURE; a child ERROR
    or CANCELLED is also reported as FAILURE. A child SUCCESS finishes the node
    with SUCCESS. The attempt counter is cleared whenever the node finishes.

    Args:
        child: Node to retry.
        max_attempts: Total attempts allowed; coerced with ``int`` and clamped
            to at least 1.
        retry_on: Child statuses that qualify for another attempt.
        name: Display name (defaults to ``Retry(<child>,<max_attempts>)``).
        exception_policy: Stored on the node.
    """

    def __init__(
        self,
        child: Node[BB],
        *,
        max_attempts: int,
        retry_on: Tuple[Status, ...] = (Status.FAILURE, Status.ERROR),
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        label = name or f"Retry({_name_of(child)},{max_attempts})"
        super().__init__(label, exception_policy=exception_policy)
        self.child = child
        child.parent = self
        self.max_attempts = max(1, int(max_attempts))
        self.retry_on = retry_on
        # Non-successful attempts used in the current activation.
        self._attempt = 0

    def reset(self) -> None:
        super().reset()
        self._attempt = 0
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        outcome = await self.child.tick(bb, tb)
        if outcome is Status.RUNNING:
            return outcome

        succeeded = outcome is Status.SUCCESS
        if not succeeded:
            self._attempt += 1
            if outcome in self.retry_on and self._attempt < self.max_attempts:
                # Re-arm the child; the next attempt runs on the next tick.
                self.child.reset()
                return Status.RUNNING

        self._attempt = 0
        final = Status.SUCCESS if succeeded else Status.FAILURE
        return await self._finish(bb, final, tb)
574
+
575
+
576
class Timeout(Node[BB]):
    """Decorator that fails its child once a time budget has elapsed.

    The deadline is ``tb.now() + seconds``, set when the node is entered. It is
    checked only at the start of each tick, so a single long-awaiting child
    tick is not interrupted. Once ``tb.now()`` is strictly past the deadline,
    the child is reset and the node finishes with FAILURE. Otherwise the
    child's terminal status is passed through.

    Args:
        child: Node to bound.
        seconds: Budget in timebase units; negative values are clamped to 0.
        name: Display name (defaults to ``Timeout(<child>,<seconds>s)``).
        exception_policy: Stored on the node.
    """

    def __init__(
        self,
        child: Node[BB],
        *,
        seconds: float,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        label = name or f"Timeout({_name_of(child)},{seconds}s)"
        super().__init__(label, exception_policy=exception_policy)
        self.child = child
        child.parent = self
        self.seconds = max(0.0, float(seconds))
        self._deadline: Optional[float] = None

    def reset(self) -> None:
        super().reset()
        self._deadline = None
        self.child.reset()

    async def on_enter(self, bb: BB, tb: Timebase) -> None:
        # Start the clock for this activation.
        self._deadline = tb.now() + self.seconds

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        if self._deadline is None:
            # Defensive: start the clock now if on_enter did not set it.
            self._deadline = tb.now() + self.seconds

        if tb.now() > self._deadline:
            self.child.reset()
            return await self._finish(bb, Status.FAILURE, tb)

        outcome = await self.child.tick(bb, tb)
        return outcome if outcome is Status.RUNNING else await self._finish(bb, outcome, tb)
613
+
614
+
615
class Succeeder(Node[BB]):
    """Decorator that reports SUCCESS whenever its child finishes.

    RUNNING passes through. Every terminal child status (SUCCESS, FAILURE,
    ERROR, CANCELLED) becomes SUCCESS.
    """

    def __init__(
        self,
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        label = name or f"Succeeder({_name_of(child)})"
        super().__init__(label, exception_policy=exception_policy)
        self.child = child
        child.parent = self

    def reset(self) -> None:
        super().reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        if await self.child.tick(bb, tb) is Status.RUNNING:
            return Status.RUNNING
        return await self._finish(bb, Status.SUCCESS, tb)
638
+
639
+
640
class Failer(Node[BB]):
    """Decorator that reports FAILURE whenever its child finishes.

    RUNNING passes through. Every terminal child status (SUCCESS, FAILURE,
    ERROR, CANCELLED) becomes FAILURE.
    """

    def __init__(
        self,
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        label = name or f"Failer({_name_of(child)})"
        super().__init__(label, exception_policy=exception_policy)
        self.child = child
        child.parent = self

    def reset(self) -> None:
        super().reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        if await self.child.tick(bb, tb) is Status.RUNNING:
            return Status.RUNNING
        return await self._finish(bb, Status.FAILURE, tb)
663
+
664
+
665
class RateLimit(Node[BB]):
    """
    Throttle *starting* the child to at most once per period.

    A child that is already RUNNING is always resumed without throttling, to
    avoid starvation. While waiting for the next allowed start, the node
    reports RUNNING. After the child finishes, the next start is allowed
    ``period`` later. For a run that resumed across ticks, this is measured
    from the completion tick's ``tb.now()``. For a run that finished in the
    tick it started, it is measured from that tick's starting ``tb.now()``.

    Args:
        child: Node to throttle.
        hz: Maximum start rate; converted to a period of ``1 / hz``.
        period: Minimum spacing between starts. Exactly one of ``hz`` and
            ``period`` must be given; negative periods are clamped to 0.
        name: Display name (defaults to ``RateLimit(<child>,<period>s)``).
        exception_policy: Stored on the node.
    """

    def __init__(
        self,
        child: Node[BB],
        *,
        hz: Optional[float] = None,
        period: Optional[float] = None,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        # Both given or both missing is an error.
        if (hz is None) == (period is None):
            raise ValueError("RateLimit requires exactly one of (hz, period)")
        per = period if period is not None else 1.0 / float(hz)

        label = name or f"RateLimit({_name_of(child)},{per:.6f}s)"
        super().__init__(label, exception_policy=exception_policy)
        self.child = child
        child.parent = self
        self._period = max(0.0, per)
        self._next_allowed: Optional[float] = None
        # Child status from the previous child tick (None before the first).
        self._last: Optional[Status] = None

    def reset(self) -> None:
        super().reset()
        self._next_allowed = None
        self._last = None
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        resuming = self._last is Status.RUNNING
        started = None if resuming else tb.now()
        if (
            started is not None
            and self._next_allowed is not None
            and started < self._next_allowed
        ):
            # Too soon to start a new run.
            return Status.RUNNING

        outcome = await self.child.tick(bb, tb)
        self._last = outcome
        if outcome is Status.RUNNING:
            return Status.RUNNING

        reference = tb.now() if resuming else started
        self._next_allowed = reference + self._period
        return await self._finish(bb, outcome, tb)
725
+
726
+
727
class Gate(Node[BB]):
    """
    Guard a child with a condition that is re-evaluated on every tick:
      SUCCESS: tick the child and pass through its status.
      RUNNING: report RUNNING (the child is not ticked).
      anything else: report FAILURE.

    If the condition stops succeeding while the child is mid-run, the child is
    reset. Otherwise it would resume from stale state the next time the gate
    opens.
    """

    def __init__(
        self,
        condition: Node[BB],
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Gate(cond={_name_of(condition)}, child={_name_of(child)})",
            exception_policy=exception_policy,
        )
        self.condition = condition
        self.child = child
        self.condition.parent = self
        self.child.parent = self
        # True while the child returned RUNNING on its most recent tick.
        self._child_running = False

    def reset(self) -> None:
        super().reset()
        self._child_running = False
        self.condition.reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        c = await self.condition.tick(bb, tb)
        if c is Status.RUNNING:
            return Status.RUNNING
        if c is not Status.SUCCESS:
            if self._child_running:
                # Guard closed mid-run: abandon the child so it starts fresh.
                self._child_running = False
                self.child.reset()
            return await self._finish(bb, Status.FAILURE, tb)
        st = await self.child.tick(bb, tb)
        self._child_running = st is Status.RUNNING
        if st is Status.RUNNING:
            return Status.RUNNING
        return await self._finish(bb, st, tb)
767
+
768
+
769
class Match(Node[BB]):
    """
    Pattern-matching dispatch node.

    On a fresh activation, ``key_fn`` is evaluated against the blackboard via
    the same calling convention as actions. The cases are then tried in order,
    and the first matching case's child is ticked. A matcher can be:

    * the default sentinel: always matches (put it last);
    * a type: matches when ``isinstance(value, matcher)``;
    * any other callable: matches when ``matcher(value)`` is truthy;
    * any other value: matches when ``value == matcher``.

    If the chosen child is RUNNING, later ticks resume that same child without
    re-evaluating the key. Its terminal status is returned as this node's
    status. With no matching case, the node reports FAILURE.
    """

    def __init__(
        self,
        key_fn: Callable[[Any], Any],
        cases: List[Tuple[Any, Node[BB]]],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        label = name or f"Match({_name_of(key_fn)})"
        super().__init__(label, exception_policy=exception_policy)
        self._key_fn = key_fn
        self._cases = cases
        for _, case_child in cases:
            case_child.parent = self
        # Index of the case whose child is mid-run, or None.
        self._matched_idx: Optional[int] = None

    def reset(self) -> None:
        super().reset()
        self._matched_idx = None
        for _, case_child in self._cases:
            case_child.reset()

    def _matches(self, matcher: Any, value: Any) -> bool:
        """Check if a matcher matches the given value."""
        if matcher is _DefaultCase:
            return True
        if isinstance(matcher, type):
            return isinstance(value, matcher)
        return bool(matcher(value)) if callable(matcher) else value == matcher

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)

        if self._matched_idx is None:
            value = await _call_node_function(self._key_fn, bb, tb)
            chosen = next(
                (
                    idx
                    for idx, (matcher, _) in enumerate(self._cases)
                    if self._matches(matcher, value)
                ),
                None,
            )
            if chosen is None:
                return await self._finish(bb, Status.FAILURE, tb)
        else:
            chosen = self._matched_idx

        _, case_child = self._cases[chosen]
        outcome = await case_child.tick(bb, tb)
        if outcome is Status.RUNNING:
            self._matched_idx = chosen
            return Status.RUNNING
        self._matched_idx = None
        return await self._finish(bb, outcome, tb)
841
+
842
+
843
class DoWhile(Node[BB]):
    """
    Loop decorator that repeats its child while a condition holds.

    Per tick:
      1. If the child is mid-run, resume it without re-checking the condition.
      2. Otherwise evaluate the condition:
         RUNNING: report RUNNING.
         anything other than SUCCESS: the loop is done, report SUCCESS.
         SUCCESS: tick the child.
      3. Child result:
         RUNNING: report RUNNING and resume the child next tick.
         SUCCESS: reset child and condition, report RUNNING (the condition is
           re-checked next tick, so one tick never loops indefinitely).
         anything else (FAILURE/ERROR/CANCELLED): the loop aborts with FAILURE.
    """

    def __init__(
        self,
        condition: Node[BB],
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        label = name or f"DoWhile(cond={_name_of(condition)}, child={_name_of(child)})"
        super().__init__(label, exception_policy=exception_policy)
        self.condition = condition
        self.child = child
        condition.parent = self
        child.parent = self
        # True while the child returned RUNNING on its most recent tick.
        self._child_running = False

    def reset(self) -> None:
        super().reset()
        self.condition.reset()
        self.child.reset()
        self._child_running = False

    async def _after_child(self, bb: BB, st: Status, tb: Timebase) -> Status:
        """Translate a finished child status into this loop's status."""
        if st is Status.SUCCESS:
            # One iteration done: re-arm both nodes and loop on the next tick.
            self.child.reset()
            self.condition.reset()
            return Status.RUNNING
        return await self._finish(bb, Status.FAILURE, tb)

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)

        if not self._child_running:
            gate = await self.condition.tick(bb, tb)
            if gate is Status.RUNNING:
                return Status.RUNNING
            if gate is not Status.SUCCESS:
                return await self._finish(bb, Status.SUCCESS, tb)

        st = await self.child.tick(bb, tb)
        self._child_running = st is Status.RUNNING
        if self._child_running:
            return Status.RUNNING
        return await self._after_child(bb, st, tb)
919
+
920
+
921
+ # ======================================================================================
922
+ # Authoring DSL (NodeSpec + bt namespace)
923
+ # ======================================================================================
924
+
925
+
926
class NodeSpecKind(Enum):
    """Kind tag carried by each NodeSpec in the authoring DSL.

    Every member's value is its lowercase string identifier.
    """

    ACTION = "action"
    CONDITION = "condition"
    SEQUENCE = "sequence"
    SELECTOR = "selector"
    PARALLEL = "parallel"
    DECORATOR = "decorator"
    SUBTREE = "subtree"
    MATCH = "match"
    DO_WHILE = "do_while"
936
+
937
+
938
+ class _DefaultCase:
939
+ """Sentinel for default case in match expressions."""
940
+ pass
941
+
942
+
943
@dataclass
class CaseSpec:
    """One arm of a match expression, as produced by ``bt.case`` / ``bt.defaultcase``.

    Fields:
      matcher: the default-case marker, a type, a predicate callable, or a plain value
      child: spec executed when this arm is selected
      label: display label; derived from ``matcher`` when left empty
    """

    matcher: Any
    child: "NodeSpec"
    label: str = ""

    def __post_init__(self):
        if self.label:
            return  # an explicit label always wins
        m = self.matcher
        if m is _DefaultCase:
            self.label = "default"
        elif isinstance(m, type):
            self.label = m.__name__
        elif callable(m):
            self.label = _name_of(m)
        else:
            self.label = repr(m)
960
+
961
+
962
+ @dataclass
963
+ class NodeSpec:
964
+ kind: NodeSpecKind
965
+ name: str
966
+ payload: Any = None
967
+ children: List["NodeSpec"] = field(default_factory=list)
968
+
969
+ def __hash__(self):
970
+ return hash((self.kind, self.name, id(self.payload), tuple(self.children)))
971
+
972
+ def to_node(
973
+ self,
974
+ owner: Optional[Any] = None,
975
+ exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
976
+ ) -> Node[Any]:
977
+ match self.kind:
978
+ case NodeSpecKind.ACTION:
979
+ return Action(self.payload, exception_policy=exception_policy)
980
+ case NodeSpecKind.CONDITION:
981
+ return Condition(self.payload, exception_policy=exception_policy)
982
+ case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
983
+ owner_effective = getattr(self, "owner", None) or owner
984
+ factory = self.payload["factory"]
985
+ expanded = _bt_expand_children(factory, owner_effective)
986
+ self.children = expanded
987
+ built = [
988
+ ch.to_node(owner_effective, exception_policy) for ch in expanded
989
+ ]
990
+ match self.kind:
991
+ case NodeSpecKind.SEQUENCE:
992
+ return Sequence(
993
+ built,
994
+ memory=self.payload.get("memory", True),
995
+ name=self.name,
996
+ exception_policy=exception_policy,
997
+ )
998
+ case NodeSpecKind.SELECTOR:
999
+ return Selector(
1000
+ built,
1001
+ memory=self.payload.get("memory", True),
1002
+ name=self.name,
1003
+ exception_policy=exception_policy,
1004
+ )
1005
+ case NodeSpecKind.PARALLEL:
1006
+ return Parallel(
1007
+ built,
1008
+ success_threshold=self.payload["success_threshold"],
1009
+ failure_threshold=self.payload.get("failure_threshold"),
1010
+ name=self.name,
1011
+ exception_policy=exception_policy,
1012
+ )
1013
+
1014
+ case NodeSpecKind.DECORATOR:
1015
+ assert len(self.children) == 1, "Decorator must wrap exactly one child"
1016
+ child_node = self.children[0].to_node(owner, exception_policy)
1017
+ builder = self.payload
1018
+ return builder(child_node)
1019
+
1020
+ case NodeSpecKind.SUBTREE:
1021
+ subtree_root = self.payload["root"]
1022
+ return subtree_root.to_node(exception_policy=exception_policy)
1023
+
1024
+ case NodeSpecKind.MATCH:
1025
+ key_fn = self.payload["key_fn"]
1026
+ case_specs: List[CaseSpec] = self.payload["cases"]
1027
+ cases = [
1028
+ (cs.matcher, cs.child.to_node(owner, exception_policy))
1029
+ for cs in case_specs
1030
+ ]
1031
+ return Match(
1032
+ key_fn,
1033
+ cases,
1034
+ name=self.name,
1035
+ exception_policy=exception_policy,
1036
+ )
1037
+
1038
+ case NodeSpecKind.DO_WHILE:
1039
+ cond_spec = self.payload["condition"]
1040
+ child_spec = self.children[0]
1041
+ return DoWhile(
1042
+ cond_spec.to_node(owner, exception_policy),
1043
+ child_spec.to_node(owner, exception_policy),
1044
+ name=self.name,
1045
+ exception_policy=exception_policy,
1046
+ )
1047
+
1048
+ case _:
1049
+ raise ValueError(f"Unknown spec kind: {self.kind}")
1050
+
1051
+
1052
def _bt_expand_children(
    factory: Callable[..., Generator[Any, None, None]],
    owner: Optional[Any],
    expansion_stack: Optional[Union[Tuple[str, ...], Set[str]]] = None,
) -> List[NodeSpec]:
    """
    Execute a composite factory to get its child specs.

    Each yielded value may be a single node (decorated function or NodeSpec) or
    a list/tuple of them; every item is normalized with ``bt.as_spec``. Child
    composites are then expanded recursively (results discarded) purely so a
    cyclic structure is reported before anything is built.

    Args:
        factory: The generator function that yields child specs. It is called
            as ``factory(owner)``; if that raises TypeError it is retried as
            ``factory()``.
        owner: The namespace object for resolving N references.
        expansion_stack: Names of the factories currently being expanded,
            outermost first (internal; callers normally omit it). A set is
            still accepted, but then the reported chain order is arbitrary.

    Returns:
        The child specs in yield order.

    Raises:
        RecursionError: If ``factory``'s name is already being expanded higher up.
        TypeError: If the factory does not return a generator.
    """
    # Ordered tuple (not a set) so the error message shows the real expansion path.
    stack: Tuple[str, ...] = tuple(expansion_stack or ())

    factory_name = _name_of(factory)
    if factory_name in stack:
        chain = " -> ".join((*stack, factory_name))
        raise RecursionError(
            f"Recursive behavior tree structure detected: {chain}\n"
            f"Behavior trees must be acyclic. Consider using:\n"
            f" - A Selector with memory to iterate through options\n"
            f" - A Sequence with conditions to control flow\n"
            f" - A Retry decorator for repeated attempts\n"
            f" - State in the blackboard to track progress"
        )
    stack = (*stack, factory_name)

    try:
        gen = factory(owner)
    except TypeError:
        # Zero-argument factory (no ``N`` parameter).
        gen = factory()

    if not inspect.isgenerator(gen):
        # getattr: callables such as functools.partial have no __name__, and an
        # AttributeError here would mask the real problem.
        display = getattr(factory, "__name__", factory_name)
        raise TypeError(
            f"Composite factory {display} must be a generator (use 'yield')."
        )

    def _normalize(item: Any) -> NodeSpec:
        """Convert one yielded item to a spec and tag composites with the current stack."""
        spec = bt.as_spec(item)
        payload = getattr(spec, "payload", None)
        if isinstance(payload, dict) and "factory" in payload:
            spec._expansion_stack = stack  # type: ignore[attr-defined]
        return spec

    out: List[NodeSpec] = []
    for yielded in gen:
        batch = yielded if isinstance(yielded, (list, tuple)) else (yielded,)
        out.extend(_normalize(item) for item in batch)

    # Validation pass only: expand nested composites to surface cycles early.
    composite_kinds = (
        NodeSpecKind.SEQUENCE,
        NodeSpecKind.SELECTOR,
        NodeSpecKind.PARALLEL,
    )
    for spec in out:
        if spec.kind in composite_kinds and hasattr(spec, "_expansion_stack"):
            child_owner = getattr(spec, "owner", None) or owner
            _bt_expand_children(
                spec.payload["factory"], child_owner, spec._expansion_stack  # type: ignore[attr-defined]
            )

    return out
1126
+
1127
+
1128
+ # --------------------------------------------------------------------------------------
1129
+ # Fluent wrapper chain (left-to-right)
1130
+ # --------------------------------------------------------------------------------------
1131
+
1132
+
1133
class _WrapperChain:
    """
    Fluent factory for decorator stacks that read left→right.

        chain = bt.failer().gate(battery_ok).timeout(0.12)
        yield chain(engage)

    Every builder method returns a *new* chain, so a partial chain can safely
    be shared as a common prefix:

        guarded = bt.gate(battery_ok)
        yield guarded.timeout(1.0)(engage)   # Gate -> Timeout -> engage
        yield guarded.retry(3)(dock)         # Gate -> Retry -> dock
    """

    def __init__(
        self,
        builders: Optional[List[Callable[..., Any]]] = None,
        labels: Optional[List[str]] = None,
    ) -> None:
        # Parallel lists, outermost wrapper first.
        self._builders: List[Callable[..., Any]] = list(builders or [])
        self._labels: List[str] = list(labels or [])

    def _append(self, label: str, builder: Callable[..., Any]) -> "_WrapperChain":
        """Return a copy of this chain with one more (innermost) wrapper.

        Copying instead of mutating ``self`` prevents wrappers added through one
        reference from silently showing up in every other use of the same chain.
        """
        return _WrapperChain(self._builders + [builder], self._labels + [label])

    def failer(self) -> "_WrapperChain":
        return self._append("Failer", lambda ch: Failer(ch))

    def succeeder(self) -> "_WrapperChain":
        return self._append("Succeeder", lambda ch: Succeeder(ch))

    def inverter(self) -> "_WrapperChain":
        return self._append("Inverter", lambda ch: Inverter(ch))

    def timeout(self, seconds: float) -> "_WrapperChain":
        return self._append(
            f"Timeout({seconds}s)", lambda ch: Timeout(ch, seconds=seconds)
        )

    def retry(
        self,
        max_attempts: int,
        retry_on: Tuple[Status, ...] = (Status.FAILURE, Status.ERROR),
    ) -> "_WrapperChain":
        return self._append(
            f"Retry({max_attempts})",
            lambda ch: Retry(ch, max_attempts=max_attempts, retry_on=retry_on),
        )

    def ratelimit(
        self, *, hz: Optional[float] = None, period: Optional[float] = None
    ) -> "_WrapperChain":
        # The label shows the effective period; hz takes precedence when both are given.
        label = "RateLimit(?)"
        if hz is not None:
            label = f"RateLimit({1.0/float(hz):.6f}s)"
        elif period is not None:
            label = f"RateLimit({float(period):.6f}s)"
        return self._append(label, lambda ch: RateLimit(ch, hz=hz, period=period))

    def gate(
        self, condition_spec_or_fn: Union["NodeSpec", Callable[[Any], Any]]
    ) -> "_WrapperChain":
        cond_spec = bt.as_spec(condition_spec_or_fn)
        # NOTE(review): the gate condition is built with to_node()'s defaults,
        # not the Runner's owner/exception policy - confirm intended.
        return self._append(
            f"Gate(cond={_name_of(cond_spec)})",
            lambda ch: Gate(cond_spec.to_node(), ch),
        )

    def __call__(self, inner: Union["NodeSpec", Callable[[Any], Any]]) -> "NodeSpec":
        """
        Apply the chain to a child spec → nested decorator NodeSpecs.
        Left→right call order becomes outermost→…→innermost when built.
        """
        spec = bt.as_spec(inner)
        result = spec
        # Wrap innermost first so the first-appended wrapper ends up outermost.
        for label, builder in reversed(list(zip(self._labels, self._builders))):
            result = NodeSpec(
                kind=NodeSpecKind.DECORATOR,
                name=f"{label}({_name_of(result)})",
                payload=builder,
                children=[result],
            )
        return result

    def __rshift__(self, inner: Union["NodeSpec", Callable[[Any], Any]]) -> "NodeSpec":
        return self(inner)
1215
+
1216
+
1217
+ # --------------------------------------------------------------------------------------
1218
+ # Match expression builders
1219
+ # --------------------------------------------------------------------------------------
1220
+
1221
+
1222
class _CaseBuilder:
    """Second stage of ``bt.case(matcher)(child)``: binds a child spec to the stored matcher."""

    def __init__(self, matcher: Any) -> None:
        self._matcher = matcher

    def __call__(self, child: Union["NodeSpec", Callable[[Any], Any]]) -> CaseSpec:
        return CaseSpec(matcher=self._matcher, child=bt.as_spec(child))
1231
+
1232
+
1233
class _MatchBuilder:
    """Second stage of ``bt.match(key_fn)(*cases)``: validates the cases and emits a MATCH spec."""

    def __init__(self, key_fn: Callable[[Any], Any], name: Optional[str] = None) -> None:
        self._key_fn = key_fn
        self._name = name

    def __call__(self, *cases: CaseSpec) -> NodeSpec:
        if not cases:
            raise ValueError("bt.match() requires at least one case")

        for candidate in cases:
            if isinstance(candidate, CaseSpec):
                continue
            raise TypeError(
                f"bt.match() expects CaseSpec instances (from bt.case() or bt.defaultcase()), "
                f"got {type(candidate).__name__}"
            )

        arms = [candidate.child for candidate in cases]
        shown = self._name or _name_of(self._key_fn)
        return NodeSpec(
            kind=NodeSpecKind.MATCH,
            name=f"Match({shown})",
            payload={"key_fn": self._key_fn, "cases": list(cases)},
            children=arms,
        )
1264
+
1265
+
1266
class _DoWhileBuilder:
    """Second stage of ``bt.do_while(condition)(body)``: pairs the stored condition with a body."""

    def __init__(self, condition_spec: NodeSpec) -> None:
        self._condition_spec = condition_spec

    def __call__(self, child: Union["NodeSpec", Callable[[Any], Any]]) -> NodeSpec:
        body = bt.as_spec(child)
        cond = self._condition_spec
        # The condition travels in the payload; the body is the only child.
        return NodeSpec(
            kind=NodeSpecKind.DO_WHILE,
            name=f"DoWhile({_name_of(cond)})",
            payload={"condition": cond},
            children=[body],
        )
1282
+
1283
+
1284
+ # --------------------------------------------------------------------------------------
1285
+ # User-facing decorator/constructor namespace
1286
+ # --------------------------------------------------------------------------------------
1287
+
1288
+
1289
class _BT:
    """User-facing decorator/constructor namespace, exposed as the module-level ``bt``.

    Node decorators attach a ``node_spec`` attribute to the decorated function.
    While a ``@bt.tree`` body is running they also record the function, so the
    tree can gather it into its namespace.
    """

    # TODO: Fix decorators to not throw a fit when passing timebase arguments to actions/conditions
    def __init__(self):
        # One list of (name, function) pairs per active @bt.tree call, innermost last.
        self._tracking_stack: List[List[Tuple[str, Any]]] = []

    # -- shared helpers ------------------------------------------------------------

    def _register(self, fn: Callable[..., Any], spec: "NodeSpec") -> Callable[..., Any]:
        """Attach ``spec`` to ``fn`` and record ``fn`` in the innermost active tree, if any."""
        fn.node_spec = spec  # type: ignore[attr-defined]
        if self._tracking_stack:
            self._tracking_stack[-1].append((fn.__name__, fn))
        return fn

    def _composite(self, kind: "NodeSpecKind", **options: Any):
        """Return a decorator that registers a generator factory as a ``kind`` composite.

        ``options`` are stored in the spec payload next to the factory.
        """

        def deco(
            factory: Callable[..., Generator[Any, None, None]],
        ) -> Callable[..., Generator[Any, None, None]]:
            spec = NodeSpec(
                kind=kind,
                name=_name_of(factory),
                payload={"factory": factory, **options},
            )
            return self._register(factory, spec)

        return deco

    # -- leaves and composites -----------------------------------------------------

    def action(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator to mark a function as an action node."""
        spec = NodeSpec(kind=NodeSpecKind.ACTION, name=_name_of(fn), payload=fn)
        return self._register(fn, spec)

    def condition(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator to mark a function as a condition node."""
        spec = NodeSpec(kind=NodeSpecKind.CONDITION, name=_name_of(fn), payload=fn)
        return self._register(fn, spec)

    def sequence(self, *, memory: bool = True):
        """Decorator to mark a generator function as a sequence composite."""
        return self._composite(NodeSpecKind.SEQUENCE, memory=memory)

    def selector(self, *, memory: bool = True, reactive: bool = False):
        """Decorator to mark a generator function as a selector composite."""
        return self._composite(
            NodeSpecKind.SELECTOR, memory=memory, reactive=reactive
        )

    def parallel(
        self, *, success_threshold: int, failure_threshold: Optional[int] = None
    ):
        """Decorator to mark a generator function as a parallel composite."""
        return self._composite(
            NodeSpecKind.PARALLEL,
            success_threshold=success_threshold,
            failure_threshold=failure_threshold,
        )

    # -- decorator chains ------------------------------------------------------------

    def inverter(self) -> _WrapperChain:
        return _WrapperChain().inverter()

    def succeeder(self) -> _WrapperChain:
        return _WrapperChain().succeeder()

    def failer(self) -> _WrapperChain:
        return _WrapperChain().failer()

    def timeout(self, seconds: float) -> _WrapperChain:
        return _WrapperChain().timeout(seconds)

    def retry(
        self,
        max_attempts: int,
        retry_on: Tuple[Status, ...] = (Status.FAILURE, Status.ERROR),
    ) -> _WrapperChain:
        return _WrapperChain().retry(max_attempts, retry_on=retry_on)

    def ratelimit(
        self, hz: Optional[float] = None, period: Optional[float] = None
    ) -> _WrapperChain:
        return _WrapperChain().ratelimit(hz=hz, period=period)

    def gate(self, condition: Union[NodeSpec, Callable[[Any], Any]]) -> _WrapperChain:
        cond_spec = self.as_spec(condition)
        return _WrapperChain().gate(cond_spec)

    # -- control-flow builders -------------------------------------------------------

    def match(
        self, key_fn: Callable[[Any], Any], name: Optional[str] = None
    ) -> "_MatchBuilder":
        """
        Create a pattern-matching dispatch node.

        Usage:
            bt.match(lambda bb: bb.current_action, name="action_type")(
                bt.case(ImageAction)(bt.subtree(handle_image)),
                bt.case(MoveAction)(bt.subtree(handle_move)),
                bt.case(lambda a: a.priority > 5)(bt.subtree(handle_urgent)),
                bt.defaultcase(log_unknown),
            )

        Args:
            key_fn: Function that extracts the value to match against from the blackboard
            name: Optional display name for the match node (useful when key_fn is a lambda)

        Returns:
            A builder that accepts case specs and returns a NodeSpec
        """
        return _MatchBuilder(key_fn, name=name)

    def case(self, matcher: Any) -> "_CaseBuilder":
        """
        Define a case for a match expression.

        Args:
            matcher: Can be:
                - A type: matches via isinstance(value, matcher)
                - A callable: matches if matcher(value) returns True
                - Any other value: matches via value == matcher

        Returns:
            A builder that accepts a child node spec
        """
        return _CaseBuilder(matcher)

    def defaultcase(self, child: Union[NodeSpec, Callable[[Any], Any]]) -> CaseSpec:
        """
        Define a default case for a match expression (matches anything).

        Args:
            child: The node to execute if this case matches

        Returns:
            A CaseSpec that always matches
        """
        child_spec = self.as_spec(child)
        return CaseSpec(matcher=_DefaultCase, child=child_spec, label="default")

    def do_while(
        self, condition: Union[NodeSpec, Callable[[Any], Any]]
    ) -> "_DoWhileBuilder":
        """
        Create a loop that repeats its child while a condition is true.

        Usage:
            @bt.condition
            def samples_remain(bb):
                return bb.sample_index < bb.total_samples

            yield bt.do_while(samples_remain)(process_sample)

        Behavior:
            1. Evaluate condition
            2. If condition is FALSE → return SUCCESS (loop complete)
            3. If condition is TRUE → tick child
               - If child returns RUNNING → return RUNNING (resume child next tick)
               - If child returns SUCCESS → reset child, return RUNNING (re-check next tick)
               - If child returns FAILURE → return FAILURE (loop aborted)

        Args:
            condition: A condition node or function to evaluate each iteration

        Returns:
            A builder that accepts a child node spec
        """
        cond_spec = self.as_spec(condition)
        return _DoWhileBuilder(cond_spec)

    # -- trees -------------------------------------------------------------------------

    def subtree(self, tree: SimpleNamespace) -> NodeSpec:
        """
        Mount another tree's root spec as a subtree.

        Args:
            tree: A tree namespace created with @bt.tree

        Raises:
            ValueError: If ``tree`` has no ``root`` attribute.
        """
        if not hasattr(tree, "root"):
            raise ValueError(
                "Tree namespace must have a 'root' attribute. Did you forget @bt.root on a composite?"
            )

        root_spec = tree.root
        tree_name = getattr(tree, "_tree_name", root_spec.name)
        return NodeSpec(
            kind=NodeSpecKind.SUBTREE,
            name=tree_name,
            payload={"root": root_spec},
            children=[root_spec],
        )

    def tree(self, fn: Callable[[], Any]) -> SimpleNamespace:
        """
        Decorator to create a behavior tree namespace.

        ``fn`` is called once. Every node decorated while it runs becomes an
        attribute of the returned namespace, and its spec's ``owner`` is set to
        that namespace. The composite marked with ``@bt.root`` becomes
        ``namespace.root`` (if several are marked, the first one collected is
        used). The namespace also gets ``_tree_name`` and a ``to_mermaid()``
        helper.

        Usage:
            @bt.tree
            def MyTree():
                @bt.action
                def my_action(bb: BB) -> Status:
                    ...

                @bt.root
                @bt.sequence()
                def root(N):
                    yield N.my_action
        """
        created_nodes: List[Tuple[str, Any]] = []
        self._tracking_stack.append(created_nodes)
        try:
            fn()
        finally:
            self._tracking_stack.pop()

        nodes = {
            name: node for name, node in created_nodes if hasattr(node, "node_spec")
        }
        namespace = SimpleNamespace(**nodes)

        # Store the tree's name for use in subtree references
        namespace._tree_name = fn.__name__

        # Composite factories resolve ``N`` against this namespace.
        for node in nodes.values():
            node.node_spec.owner = namespace

        root_nodes = [v for v in nodes.values() if hasattr(v.node_spec, "is_root")]
        if root_nodes:
            namespace.root = root_nodes[0].node_spec

        namespace.to_mermaid = lambda: _generate_mermaid(namespace)

        return namespace

    def root(
        self, fn: Callable[..., Generator[Any, None, None]]
    ) -> Callable[..., Generator[Any, None, None]]:
        """Mark a composite as the root of the tree (apply above the composite decorator)."""
        fn.node_spec.is_root = True  # type: ignore
        return fn

    def as_spec(self, maybe: Union[NodeSpec, Callable[[Any], Any]]) -> NodeSpec:
        """Convert a decorated function or NodeSpec to a NodeSpec.

        Raises:
            TypeError: If ``maybe`` is neither a NodeSpec nor carries ``node_spec``.
        """
        if isinstance(maybe, NodeSpec):
            return maybe

        spec = getattr(maybe, "node_spec", None)
        if spec is None:
            raise TypeError(
                f"{maybe!r} is not a BT node (missing node_spec attribute)."
            )
        return spec
1563
+
1564
+
1565
# Module-wide authoring namespace: ``@bt.action``, ``@bt.sequence()``, ``bt.match(...)``, ...
bt: "_BT" = _BT()
1566
+
1567
+
1568
+ # ======================================================================================
1569
+ # Mermaid Generation
1570
+ # ======================================================================================
1571
+
1572
+
1573
+ def _generate_mermaid(tree: SimpleNamespace) -> str:
1574
+ """
1575
+ Render a static structure graph for a tree namespace.
1576
+ """
1577
+ if not hasattr(tree, "root"):
1578
+ raise ValueError("Tree namespace must have a 'root' attribute")
1579
+
1580
+ lines: List[str] = ["flowchart TD"]
1581
+ node_ids: Dict[NodeSpec, str] = {}
1582
+ counter = 0
1583
+
1584
+ def nid(spec: NodeSpec) -> str:
1585
+ nonlocal counter
1586
+ if spec in node_ids:
1587
+ return node_ids[spec]
1588
+ counter += 1
1589
+ node_ids[spec] = f"N{counter}"
1590
+ return node_ids[spec]
1591
+
1592
+ def label(spec: NodeSpec) -> str:
1593
+ match spec.kind:
1594
+ case NodeSpecKind.ACTION | NodeSpecKind.CONDITION:
1595
+ return f"{spec.kind.value.upper()}<br/>{spec.name}"
1596
+ case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
1597
+ return f"{spec.kind.value.capitalize()}<br/>{spec.name}"
1598
+ case NodeSpecKind.DECORATOR:
1599
+ return f"Decor<br/>{spec.name}"
1600
+ case NodeSpecKind.SUBTREE:
1601
+ return f"Subtree<br/>{spec.name}"
1602
+ case NodeSpecKind.MATCH:
1603
+ return f"Match<br/>{spec.name}"
1604
+ case NodeSpecKind.DO_WHILE:
1605
+ return f"DoWhile<br/>{spec.name}"
1606
+ case _:
1607
+ return spec.name
1608
+
1609
+ def ensure_children(spec: NodeSpec) -> List[NodeSpec]:
1610
+ match spec.kind:
1611
+ case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
1612
+ owner = getattr(spec, "owner", None) or tree
1613
+ factory = spec.payload["factory"]
1614
+ spec.children = _bt_expand_children(factory, owner)
1615
+ case NodeSpecKind.SUBTREE:
1616
+ subtree_root = spec.payload["root"]
1617
+ spec.children = [subtree_root]
1618
+ case NodeSpecKind.MATCH:
1619
+ case_specs: List[CaseSpec] = spec.payload["cases"]
1620
+ spec.children = [cs.child for cs in case_specs]
1621
+ case NodeSpecKind.DO_WHILE:
1622
+ # Children already set (just the body), but we also want to show condition
1623
+ cond_spec = spec.payload["condition"]
1624
+ spec.children = [cond_spec] + spec.children
1625
+ return spec.children
1626
+
1627
+ def walk(spec: NodeSpec) -> None:
1628
+ this_id = nid(spec)
1629
+ shape = (
1630
+ "((%s))"
1631
+ if spec.kind in (NodeSpecKind.ACTION, NodeSpecKind.CONDITION)
1632
+ else '["%s"]'
1633
+ ) % label(spec)
1634
+ lines.append(f" {this_id}{shape}")
1635
+
1636
+ children = ensure_children(spec)
1637
+
1638
+ if spec.kind == NodeSpecKind.MATCH:
1639
+ case_specs: List[CaseSpec] = spec.payload["cases"]
1640
+ for case_spec, child in zip(case_specs, children):
1641
+ child_id = nid(child)
1642
+ edge_label = case_spec.label.replace('"', "'")
1643
+ lines.append(f' {this_id} -->|"{edge_label}"| {child_id}')
1644
+ walk(child)
1645
+ elif spec.kind == NodeSpecKind.DO_WHILE:
1646
+ # First child is condition, second is body
1647
+ cond_id = nid(children[0])
1648
+ lines.append(f' {this_id} -->|"condition"| {cond_id}')
1649
+ walk(children[0])
1650
+ if len(children) > 1:
1651
+ body_id = nid(children[1])
1652
+ lines.append(f' {this_id} -->|"body"| {body_id}')
1653
+ walk(children[1])
1654
+ else:
1655
+ for child in children:
1656
+ child_id = nid(child)
1657
+ lines.append(f" {this_id} --> {child_id}")
1658
+ walk(child)
1659
+
1660
+ walk(tree.root)
1661
+ return "\n".join(lines)
1662
+
1663
+
1664
+ # ======================================================================================
1665
+ # Runner
1666
+ # ======================================================================================
1667
+
1668
+
1669
class Runner(Generic[BB]):
    """Runtime for executing behavior trees.

    The Runner manages the execution of a behavior tree, handling tick
    calls, timebase management, and result processing.

    Args:
        tree: A behavior tree namespace with a 'root' attribute (from @bt.tree)
        bb: Blackboard containing shared state
        tb: Optional timebase for time management (defaults to MonotonicClock)
        exception_policy: How to handle exceptions during tree execution

    Raises:
        ValueError: If ``tree`` has no 'root' attribute.
        TypeError: If ``tree.root`` is neither a NodeSpec nor a BT node.

    Methods:
        tick(): Execute one tick of the behavior tree
        tick_until_complete(): Run the tree until it returns a terminal status

    Example:
        @bt.tree
        def MyTree():
            @bt.root
            @bt.sequence()
            def root():
                yield check_condition
                yield do_work

        runner = Runner(MyTree, bb=blackboard)
        result = await runner.tick_until_complete()
    """

    def __init__(
        self,
        tree: SimpleNamespace,
        bb: BB,
        tb: Optional[Timebase] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        self.tree = tree
        self.bb: BB = bb
        # Explicit None check: a supplied timebase must never be replaced just
        # because the object happens to be falsy.
        self.tb = tb if tb is not None else MonotonicClock()
        self.exception_policy = exception_policy

        if not hasattr(tree, "root"):
            raise ValueError("Tree namespace must have a 'root' attribute")

        # Accept either the NodeSpec installed by @bt.root or a decorated
        # composite function that merely happens to be named ``root``.
        root_spec = bt.as_spec(tree.root)
        self.root: Node[BB] = root_spec.to_node(
            owner=tree, exception_policy=exception_policy
        )

    async def tick(self) -> Status:
        """Tick the root node once, then advance the timebase."""
        result = await self.root.tick(self.bb, self.tb)
        self.tb.advance()
        return result

    async def tick_until_complete(self, timeout: Optional[float] = None) -> Status:
        """Tick repeatedly until the root returns a terminal status.

        Terminal statuses are SUCCESS, FAILURE, CANCELLED and ERROR. Control is
        yielded to the event loop between ticks.

        Args:
            timeout: Optional limit, measured with ``self.tb.now()``. Once more
                than this has elapsed after a non-terminal tick, CANCELLED is
                returned (the tree itself is not reset here).
        """
        start = self.tb.now()
        terminal = (Status.SUCCESS, Status.FAILURE, Status.CANCELLED, Status.ERROR)
        while True:
            st = await self.tick()
            if st in terminal:
                return st
            if timeout is not None and (self.tb.now() - start) > timeout:
                return Status.CANCELLED
            await asyncio.sleep(0)