mycorrhizal 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mycorrhizal/__init__.py +3 -0
- mycorrhizal/common/__init__.py +68 -0
- mycorrhizal/common/interface_builder.py +203 -0
- mycorrhizal/common/interfaces.py +412 -0
- mycorrhizal/common/timebase.py +99 -0
- mycorrhizal/common/wrappers.py +532 -0
- mycorrhizal/enoki/__init__.py +0 -0
- mycorrhizal/enoki/core.py +1545 -0
- mycorrhizal/enoki/testing_utils.py +529 -0
- mycorrhizal/enoki/util.py +220 -0
- mycorrhizal/hypha/__init__.py +0 -0
- mycorrhizal/hypha/core/__init__.py +107 -0
- mycorrhizal/hypha/core/builder.py +404 -0
- mycorrhizal/hypha/core/runtime.py +890 -0
- mycorrhizal/hypha/core/specs.py +234 -0
- mycorrhizal/hypha/util.py +38 -0
- mycorrhizal/rhizomorph/README.md +220 -0
- mycorrhizal/rhizomorph/__init__.py +0 -0
- mycorrhizal/rhizomorph/core.py +1729 -0
- mycorrhizal/rhizomorph/util.py +45 -0
- mycorrhizal/spores/__init__.py +124 -0
- mycorrhizal/spores/cache.py +208 -0
- mycorrhizal/spores/core.py +419 -0
- mycorrhizal/spores/dsl/__init__.py +48 -0
- mycorrhizal/spores/dsl/enoki.py +514 -0
- mycorrhizal/spores/dsl/hypha.py +399 -0
- mycorrhizal/spores/dsl/rhizomorph.py +351 -0
- mycorrhizal/spores/encoder/__init__.py +11 -0
- mycorrhizal/spores/encoder/base.py +42 -0
- mycorrhizal/spores/encoder/json.py +159 -0
- mycorrhizal/spores/extraction.py +484 -0
- mycorrhizal/spores/models.py +288 -0
- mycorrhizal/spores/transport/__init__.py +10 -0
- mycorrhizal/spores/transport/base.py +46 -0
- mycorrhizal-0.1.0.dist-info/METADATA +198 -0
- mycorrhizal-0.1.0.dist-info/RECORD +37 -0
- mycorrhizal-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1729 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Rhizomorph - Asyncio Behavior Tree Framework
|
|
4
|
+
|
|
5
|
+
A decorator-based DSL for defining and executing behavior trees with support for
|
|
6
|
+
asyncio, multi-file composition, and type-safe blackboard interfaces.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
from mycorrhizal.rhizomorph.core import bt, Runner, Status
|
|
10
|
+
|
|
11
|
+
@bt.tree
|
|
12
|
+
def MyBehaviorTree():
|
|
13
|
+
@bt.action
|
|
14
|
+
async def do_work(bb) -> Status:
|
|
15
|
+
# Do some work
|
|
16
|
+
return Status.SUCCESS
|
|
17
|
+
|
|
18
|
+
@bt.condition
|
|
19
|
+
def should_work(bb) -> bool:
|
|
20
|
+
return bb.work_available
|
|
21
|
+
|
|
22
|
+
@bt.root
|
|
23
|
+
@bt.sequence
|
|
24
|
+
def root():
|
|
25
|
+
yield should_work
|
|
26
|
+
yield do_work
|
|
27
|
+
|
|
28
|
+
# Run the tree
|
|
29
|
+
runner = Runner(MyBehaviorTree, bb=blackboard)
|
|
30
|
+
result = await runner.tick_until_complete()
|
|
31
|
+
|
|
32
|
+
Key Classes:
|
|
33
|
+
Status - Result status for behavior tree nodes (SUCCESS, FAILURE, RUNNING, etc.)
|
|
34
|
+
Node - Base class for all behavior tree nodes
|
|
35
|
+
Action - Leaf node that executes a function
|
|
36
|
+
Condition - Leaf node that evaluates a predicate
|
|
37
|
+
Sequence - Composite that runs children in order
|
|
38
|
+
Selector - Composite that runs children until one succeeds
|
|
39
|
+
Parallel - Composite that runs children concurrently
|
|
40
|
+
|
|
41
|
+
Multi-file Composition:
|
|
42
|
+
Use bt.subtree() to reference trees defined in other modules:
|
|
43
|
+
from other_module import OtherTree
|
|
44
|
+
|
|
45
|
+
@bt.tree
|
|
46
|
+
def MainTree():
|
|
47
|
+
@bt.root
|
|
48
|
+
@bt.sequence
|
|
49
|
+
def root():
|
|
50
|
+
yield bt.subtree(OtherTree, owner=MainTree)
|
|
51
|
+
"""
|
|
52
|
+
from __future__ import annotations
|
|
53
|
+
|
|
54
|
+
import asyncio
|
|
55
|
+
import inspect
|
|
56
|
+
import logging
|
|
57
|
+
import traceback
|
|
58
|
+
from dataclasses import dataclass, field
|
|
59
|
+
from enum import Enum
|
|
60
|
+
from typing import (
|
|
61
|
+
Any,
|
|
62
|
+
Callable,
|
|
63
|
+
Dict,
|
|
64
|
+
Generator,
|
|
65
|
+
Generic,
|
|
66
|
+
List,
|
|
67
|
+
Optional,
|
|
68
|
+
Tuple,
|
|
69
|
+
TypeVar,
|
|
70
|
+
Union,
|
|
71
|
+
Set,
|
|
72
|
+
Protocol,
|
|
73
|
+
)
|
|
74
|
+
from typing import Sequence as SequenceT
|
|
75
|
+
from types import SimpleNamespace
|
|
76
|
+
|
|
77
|
+
from mycorrhizal.common.timebase import *
|
|
78
|
+
|
|
79
|
+
logger = logging.getLogger(__name__)
|
|
80
|
+
|
|
81
|
+
BB = TypeVar("BB")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# ======================================================================================
|
|
85
|
+
# Interface Integration Helper
|
|
86
|
+
# ======================================================================================
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _create_interface_view_if_needed(bb: Any, func: Callable) -> Any:
|
|
90
|
+
"""
|
|
91
|
+
Create a constrained view if the function has an interface type hint on its
|
|
92
|
+
first parameter (typically named 'bb').
|
|
93
|
+
|
|
94
|
+
This enables type-safe, constrained access to blackboard state based on
|
|
95
|
+
interface definitions created with @blackboard_interface.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
bb: The blackboard instance
|
|
99
|
+
func: The function to check for interface type hints
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
Either the original blackboard or a constrained view based on interface metadata
|
|
103
|
+
"""
|
|
104
|
+
from typing import get_type_hints
|
|
105
|
+
|
|
106
|
+
try:
|
|
107
|
+
sig = inspect.signature(func)
|
|
108
|
+
params = list(sig.parameters.values())
|
|
109
|
+
|
|
110
|
+
# Check first parameter (usually 'bb')
|
|
111
|
+
if params and params[0].name == 'bb':
|
|
112
|
+
bb_type = get_type_hints(func).get('bb')
|
|
113
|
+
|
|
114
|
+
# If type hint exists and has interface metadata
|
|
115
|
+
if bb_type and hasattr(bb_type, '_readonly_fields'):
|
|
116
|
+
from mycorrhizal.common.wrappers import create_view_from_protocol
|
|
117
|
+
return create_view_from_protocol(bb, bb_type)
|
|
118
|
+
except Exception:
|
|
119
|
+
# If anything goes wrong with type inspection, fall back to original bb
|
|
120
|
+
pass
|
|
121
|
+
|
|
122
|
+
return bb
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# ======================================================================================
|
|
126
|
+
# Function signature protocols
|
|
127
|
+
# ======================================================================================
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _supports_timebase(func: Callable) -> bool:
|
|
131
|
+
"""Check if a function accepts a 'tb' parameter."""
|
|
132
|
+
try:
|
|
133
|
+
sig = inspect.signature(func)
|
|
134
|
+
return 'tb' in sig.parameters
|
|
135
|
+
except (ValueError, TypeError):
|
|
136
|
+
return False
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
async def _call_node_function(func: Callable, bb: Any, tb: Timebase) -> Any:
    """
    Call a node function with appropriate parameters based on its signature.

    The blackboard is always passed as keyword ``bb``; ``tb`` is passed only if
    the function declares a ``tb`` parameter. If the function has an interface
    type hint on its 'bb' parameter, a constrained view will be created
    automatically to enforce access control.

    Any awaitable result is awaited. Checking the returned value rather than
    ``inspect.iscoroutinefunction(func)`` also covers objects with an
    ``async def __call__`` and sync wrappers around async functions. Those would
    otherwise return an un-awaited coroutine, which callers would treat as a
    plain (truthy) value.
    """
    # Create interface view if function has interface type hint
    kwargs: Dict[str, Any] = {"bb": _create_interface_view_if_needed(bb, func)}
    if _supports_timebase(func):
        kwargs["tb"] = tb

    result = func(**kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# ======================================================================================
|
|
162
|
+
# Status / helpers
|
|
163
|
+
# ======================================================================================
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class Status(Enum):
    """Result status for behavior tree node execution.

    Attributes:
        SUCCESS - Node completed successfully
        FAILURE - Node failed
        RUNNING - Node has not finished yet; tick it again to continue
        CANCELLED - The node's work was cancelled (Action/Condition map an
            asyncio.CancelledError to this status)
        ERROR - A node function raised and the exception was logged instead of
            propagated (ExceptionPolicy.LOG_AND_CONTINUE)

    Note:
        Composite nodes use these statuses to determine control flow:
        - Sequence stops at the first FAILURE, ERROR or CANCELLED child and
          returns that status; a RUNNING child pauses it
        - Selector stops at the first SUCCESS (or RUNNING) child; if every
          child returns anything else it reports FAILURE
        - Parallel ticks all children and reports SUCCESS once
          ``success_threshold`` children succeed, FAILURE once
          ``failure_threshold`` children return FAILURE/ERROR/CANCELLED,
          otherwise RUNNING
    """

    SUCCESS = 1
    FAILURE = 2
    RUNNING = 3
    CANCELLED = 4
    ERROR = 5
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class ExceptionPolicy(Enum):
    """How a node reacts when the user function it wraps raises.

    LOG_AND_CONTINUE: log the traceback and report ``Status.ERROR``.
    PROPAGATE: log the traceback, then re-raise to the caller of ``tick``.
    """

    LOG_AND_CONTINUE = 1  # default used by the node constructors
    PROPAGATE = 2
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _name_of(obj: Any) -> str:
|
|
198
|
+
if hasattr(obj, "__name__"):
|
|
199
|
+
return obj.__name__
|
|
200
|
+
if hasattr(obj, "name"):
|
|
201
|
+
return str(obj.name)
|
|
202
|
+
return f"{obj.__class__.__name__}@{id(obj):x}"
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
# ======================================================================================
|
|
206
|
+
# Recursion Detection
|
|
207
|
+
# ======================================================================================
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class RecursionError(Exception):
    """Raised when recursive behavior tree structure is detected.

    NOTE: this name shadows Python's builtin ``RecursionError`` inside this
    module, and it is not a subclass of the builtin. Code that catches the
    builtin's interpreter-level recursion failures should refer to
    ``builtins.RecursionError`` explicitly.
    """
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# ======================================================================================
|
|
217
|
+
# Node model
|
|
218
|
+
# ======================================================================================
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class Node(Generic[BB]):
    """Abstract base for every behavior tree node.

    Subclasses implement :meth:`tick`. The helpers :meth:`_ensure_entered` and
    :meth:`_finish` give all nodes the same lifecycle:
    ``on_enter`` runs once before the first tick of an activation, and
    ``on_exit`` runs as soon as a tick produces any status other than
    ``Status.RUNNING``.

    Args:
        name: Optional node name for logging/debugging; when falsy, a name is
            derived with ``_name_of(self)``.
        exception_policy: How exceptions raised while ticking are handled.

    Attributes:
        name: Node name
        parent: Parent node, assigned by the composite/decorator that owns it
        exception_policy: Exception handling policy
        _entered: True between ``on_enter`` and the terminal ``_finish``
        _last_status: Status most recently recorded by ``_finish``

    Methods:
        tick(bb, tb): Execute the node, return Status
        on_enter(bb, tb): Hook called when an activation starts
        on_exit(bb, status, tb): Hook called with the terminal (non-RUNNING) status
        reset(): Clear activation state for reuse
    """

    def __init__(
        self,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        self.name: str = name if name else _name_of(self)
        self.parent: Optional[Node[BB]] = None
        self.exception_policy = exception_policy
        self._entered: bool = False
        self._last_status: Optional[Status] = None

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        """Run one step of this node and report its status (subclasses override)."""
        raise NotImplementedError

    async def on_enter(self, bb: BB, tb: Timebase) -> None:
        """Hook invoked once at the start of an activation; no-op by default."""

    async def on_exit(self, bb: BB, status: Status, tb: Timebase) -> None:
        """Hook invoked with the terminal status of an activation; no-op by default."""

    def reset(self) -> None:
        """Forget activation state so the next tick starts a fresh activation."""
        self._last_status = None
        self._entered = False

    async def _ensure_entered(self, bb: BB, tb: Timebase) -> None:
        """Run ``on_enter`` if the current activation has not been entered yet."""
        if self._entered:
            return
        await self.on_enter(bb, tb)
        # Only mark the node as entered once on_enter completed without raising.
        self._entered = True

    async def _finish(self, bb: BB, status: Status, tb: Timebase) -> Status:
        """Record *status*; for a terminal status, run ``on_exit`` and end the activation."""
        self._last_status = status
        if status is Status.RUNNING:
            return status
        try:
            await self.on_exit(bb, status, tb)
        finally:
            # Even if on_exit raises, the next tick must begin a new activation.
            self._entered = False
        return status
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# ======================================================================================
|
|
286
|
+
# Leaves
|
|
287
|
+
# ======================================================================================
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class Action(Node[BB]):
    """Leaf node that executes a function.

    Action nodes wrap sync or async functions that perform work or
    check conditions. The function's result is mapped to a Status:
    - Status enum (SUCCESS, FAILURE, RUNNING, etc.) -> used as-is
    - bool (True -> SUCCESS, False -> FAILURE)
    - anything else, including None -> SUCCESS

    Args:
        func: Function to execute (sync or async)
        exception_policy: How to handle exceptions raised by ``func``
        name: Optional name for this action (keyword-only); defaults to the
            function's name

    The function signature can be:
        - func(bb) -> Status | bool | None
        - func(bb, tb) -> Status | bool | None

    where bb is the blackboard and tb is an optional timebase.

    If ``func`` is cancelled, the node finishes with CANCELLED. Any other
    exception is logged. It is then re-raised under
    ``ExceptionPolicy.PROPAGATE``; otherwise the node finishes with ERROR.
    """

    def __init__(
        self,
        func: Callable[..., Any],
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
        *,
        name: Optional[str] = None,
    ) -> None:
        super().__init__(
            name=name or _name_of(func), exception_policy=exception_policy
        )
        self._func = func

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        try:
            result = await _call_node_function(self._func, bb, tb)
        except asyncio.CancelledError:
            # Cancellation is reported as a status rather than re-raised.
            return await self._finish(bb, Status.CANCELLED, tb)
        except Exception:
            logger.error(
                f"Exception in Action '{self.name}':\n{traceback.format_exc()}"
            )
            if self.exception_policy == ExceptionPolicy.PROPAGATE:
                raise
            return await self._finish(bb, Status.ERROR, tb)

        if isinstance(result, Status):
            status = result
        elif isinstance(result, bool):
            status = Status.SUCCESS if result else Status.FAILURE
        else:
            status = Status.SUCCESS
        return await self._finish(bb, status, tb)
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
class Condition(Action[BB]):
    """Predicate leaf node.

    The wrapped function's result is judged by truthiness: truthy -> SUCCESS,
    falsy -> FAILURE. A returned Status is passed through unchanged (accepted,
    but discouraged for conditions).
    """

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        """Evaluate the predicate once and translate its result into a Status.

        Cancellation maps to CANCELLED. Other exceptions are logged, then either
        re-raised (ExceptionPolicy.PROPAGATE) or reported as ERROR.
        """
        await self._ensure_entered(bb, tb)
        try:
            outcome = await _call_node_function(self._func, bb, tb)
        except asyncio.CancelledError:
            return await self._finish(bb, Status.CANCELLED, tb)
        except Exception:
            logger.error(
                f"Exception in Condition '{self.name}':\n{traceback.format_exc()}"
            )
            if self.exception_policy == ExceptionPolicy.PROPAGATE:
                raise
            return await self._finish(bb, Status.ERROR, tb)

        if not isinstance(outcome, Status):
            outcome = Status.SUCCESS if outcome else Status.FAILURE
        return await self._finish(bb, outcome, tb)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
# ======================================================================================
|
|
367
|
+
# Composites
|
|
368
|
+
# ======================================================================================
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
class Sequence(Node[BB]):
    """Sequence (AND) composite.

    Children are ticked in order:
    - a RUNNING child pauses the sequence, which returns RUNNING;
    - the first FAILURE, ERROR or CANCELLED child ends the sequence with that
      same status;
    - if every child succeeds, the sequence succeeds.

    With ``memory=True`` a paused sequence resumes at the child that was
    running. With ``memory=False`` it restarts from the first child on every
    tick.
    """

    def __init__(
        self,
        children: SequenceT[Node[BB]],
        *,
        memory: bool = True,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(name or "Sequence", exception_policy=exception_policy)
        self.children = list(children)
        for child in self.children:
            child.parent = self
        self.memory = memory
        # Index of the child to resume from (only consulted when memory=True).
        self._idx = 0

    def reset(self) -> None:
        """Reset this node, its resume index and every child."""
        super().reset()
        self._idx = 0
        for child in self.children:
            child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        position = self._idx if self.memory else 0
        total = len(self.children)
        while position < total:
            outcome = await self.children[position].tick(bb, tb)
            if outcome is Status.RUNNING:
                if self.memory:
                    self._idx = position
                return Status.RUNNING
            if outcome in (Status.FAILURE, Status.ERROR, Status.CANCELLED):
                self._idx = 0
                return await self._finish(bb, outcome, tb)
            position += 1
        self._idx = 0
        return await self._finish(bb, Status.SUCCESS, tb)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
class Selector(Node[BB]):
    """Selector (fallback) composite.

    Children are tried in order:
    - the first SUCCESS wins;
    - a RUNNING child pauses the selector, which returns RUNNING;
    - any other status (FAILURE, ERROR, CANCELLED) moves on to the next child.

    If every child is exhausted, the selector reports FAILURE. With
    ``memory=True`` a paused selector resumes at the child that was running.
    """

    def __init__(
        self,
        children: SequenceT[Node[BB]],
        *,
        memory: bool = True,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(name or "Selector", exception_policy=exception_policy)
        self.children = list(children)
        for child in self.children:
            child.parent = self
        self.memory = memory
        # Index of the child to resume from (only consulted when memory=True).
        self._idx = 0

    def reset(self) -> None:
        """Reset this node, its resume index and every child."""
        super().reset()
        self._idx = 0
        for child in self.children:
            child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        position = self._idx if self.memory else 0
        total = len(self.children)
        while position < total:
            outcome = await self.children[position].tick(bb, tb)
            if outcome is Status.SUCCESS:
                self._idx = 0
                return await self._finish(bb, Status.SUCCESS, tb)
            if outcome is Status.RUNNING:
                if self.memory:
                    self._idx = position
                return Status.RUNNING
            position += 1
        self._idx = 0
        return await self._finish(bb, Status.FAILURE, tb)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
class Parallel(Node[BB]):
    """
    Parallel: tick all children concurrently on every tick.

    - SUCCESS once at least ``success_threshold`` (k, clamped to [1, n])
      children return SUCCESS (checked first).
    - FAILURE once at least ``failure_threshold`` children return FAILURE,
      ERROR or CANCELLED. The threshold defaults to n-k+1, the point at which
      success can no longer be reached.
    - Otherwise RUNNING.

    When the outcome is decided, children that were still RUNNING are reset so
    the next activation starts them fresh. If a child raises (e.g. under
    ExceptionPolicy.PROPAGATE) or this tick is cancelled, sibling tasks that are
    still pending are cancelled instead of being left running in the background.
    """

    def __init__(
        self,
        children: SequenceT[Node[BB]],
        *,
        success_threshold: int,
        failure_threshold: Optional[int] = None,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(name or "Parallel", exception_policy=exception_policy)
        self.children = list(children)
        for ch in self.children:
            ch.parent = self
        self.n = len(self.children)
        self.k = max(1, min(success_threshold, self.n))
        self.f = (
            failure_threshold
            if failure_threshold is not None
            else (self.n - self.k + 1)
        )

    def reset(self) -> None:
        super().reset()
        for ch in self.children:
            ch.reset()

    def _halt_running(self, results: SequenceT[Status]) -> None:
        """Reset children whose latest status was RUNNING (their work is abandoned)."""
        for ch, st in zip(self.children, results):
            if st is Status.RUNNING:
                ch.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        tasks = [asyncio.create_task(ch.tick(bb, tb)) for ch in self.children]
        try:
            results = await asyncio.gather(*tasks, return_exceptions=False)
        finally:
            # gather() does not cancel siblings when one child raises; make sure
            # none of them keeps mutating the blackboard detached from the tree.
            # On the normal path every task is already done, so this is a no-op.
            for task in tasks:
                if not task.done():
                    task.cancel()

        succ = sum(1 for s in results if s is Status.SUCCESS)
        fail = sum(
            1 for s in results if s in (Status.FAILURE, Status.ERROR, Status.CANCELLED)
        )
        if succ >= self.k:
            self._halt_running(results)
            return await self._finish(bb, Status.SUCCESS, tb)
        if fail >= self.f:
            self._halt_running(results)
            return await self._finish(bb, Status.FAILURE, tb)
        return Status.RUNNING
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
# ======================================================================================
|
|
504
|
+
# Decorators (wrappers)
|
|
505
|
+
# ======================================================================================
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
class Inverter(Node[BB]):
    """Decorator that swaps SUCCESS and FAILURE of its child.

    RUNNING is passed through while the child is still working. ERROR and
    CANCELLED are passed through unchanged. Every non-RUNNING status ends this
    node's activation, so ``on_exit`` runs and the next tick re-enters.
    """

    def __init__(
        self,
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Inverter({_name_of(child)})", exception_policy=exception_policy
        )
        self.child = child
        self.child.parent = self

    def reset(self) -> None:
        super().reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        st = await self.child.tick(bb, tb)
        if st is Status.RUNNING:
            return Status.RUNNING
        if st is Status.SUCCESS:
            st = Status.FAILURE
        elif st is Status.FAILURE:
            st = Status.SUCCESS
        # ERROR / CANCELLED are terminal too: finish so the activation is closed
        # (previously they were returned without _finish, leaving _entered set).
        return await self._finish(bb, st, tb)
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
class Retry(Node[BB]):
    """Decorator that re-runs its child after retryable outcomes.

    - SUCCESS from the child yields SUCCESS.
    - A status listed in ``retry_on`` triggers another attempt while fewer than
      ``max_attempts`` (clamped to >= 1) attempts have been used. The child is
      reset and RUNNING is returned, so the next attempt starts on the
      following tick.
    - Any other final outcome yields FAILURE.
    """

    def __init__(
        self,
        child: Node[BB],
        *,
        max_attempts: int,
        retry_on: Tuple[Status, ...] = (Status.FAILURE, Status.ERROR),
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Retry({_name_of(child)},{max_attempts})",
            exception_policy=exception_policy,
        )
        self.child = child
        self.child.parent = self
        self.max_attempts = max(1, int(max_attempts))
        self.retry_on = retry_on
        # Attempts consumed in the current activation.
        self._attempt = 0

    def reset(self) -> None:
        super().reset()
        self._attempt = 0
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        outcome = await self.child.tick(bb, tb)
        if outcome is Status.RUNNING:
            return Status.RUNNING
        if outcome is Status.SUCCESS:
            self._attempt = 0
            return await self._finish(bb, Status.SUCCESS, tb)

        self._attempt += 1
        can_retry = outcome in self.retry_on and self._attempt < self.max_attempts
        if can_retry:
            # Clean slate for the child; the next attempt happens on the next tick.
            self.child.reset()
            return Status.RUNNING

        self._attempt = 0
        return await self._finish(bb, Status.FAILURE, tb)
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
class Timeout(Node[BB]):
    """Decorator that fails its child once a deadline has passed.

    The deadline is ``seconds`` (clamped to >= 0) after the activation starts,
    measured with ``tb.now()``. It is checked only before ticking the child, so
    a child tick already in progress is not interrupted. On expiry the child is
    reset and this node finishes with FAILURE. Otherwise the child's terminal
    status is passed through.
    """

    def __init__(
        self,
        child: Node[BB],
        *,
        seconds: float,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Timeout({_name_of(child)},{seconds}s)",
            exception_policy=exception_policy,
        )
        self.child = child
        self.child.parent = self
        self.seconds = max(0.0, float(seconds))
        self._deadline: Optional[float] = None

    def reset(self) -> None:
        super().reset()
        self._deadline = None
        self.child.reset()

    async def on_enter(self, bb: BB, tb: Timebase) -> None:
        self._deadline = tb.now() + self.seconds

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        if self._deadline is None:
            # Defensive fallback in case on_enter did not set it (e.g. overridden).
            self._deadline = tb.now() + self.seconds
        expired = tb.now() > self._deadline
        if expired:
            self.child.reset()
            return await self._finish(bb, Status.FAILURE, tb)
        outcome = await self.child.tick(bb, tb)
        if outcome is Status.RUNNING:
            return outcome
        return await self._finish(bb, outcome, tb)
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
class Succeeder(Node[BB]):
    """Decorator that reports SUCCESS whenever its child finishes.

    Any non-RUNNING child status (FAILURE, ERROR, CANCELLED included) becomes
    SUCCESS; RUNNING is passed through.
    """

    def __init__(
        self,
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Succeeder({_name_of(child)})", exception_policy=exception_policy
        )
        self.child = child
        self.child.parent = self

    def reset(self) -> None:
        super().reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        child_status = await self.child.tick(bb, tb)
        if child_status is not Status.RUNNING:
            return await self._finish(bb, Status.SUCCESS, tb)
        return Status.RUNNING
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
class Failer(Node[BB]):
    """Decorator that reports FAILURE whenever its child finishes.

    Any non-RUNNING child status (SUCCESS included) becomes FAILURE; RUNNING
    is passed through.
    """

    def __init__(
        self,
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Failer({_name_of(child)})", exception_policy=exception_policy
        )
        self.child = child
        self.child.parent = self

    def reset(self) -> None:
        super().reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        child_status = await self.child.tick(bb, tb)
        if child_status is not Status.RUNNING:
            return await self._finish(bb, Status.FAILURE, tb)
        return Status.RUNNING
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
class RateLimit(Node[BB]):
    """
    Throttle *starting* the child to at most 1 per period (or hz).
    If child is RUNNING, do not throttle (avoid starvation).

    While throttled, this node returns RUNNING without ticking the child. After
    the child finishes, the next start is allowed ``period`` seconds after the
    tick in which it finished.

    Exactly one of ``hz`` or ``period`` must be given:
    - ``hz`` must be > 0;
    - a negative ``period`` is clamped to 0 (no throttling).
    """

    def __init__(
        self,
        child: Node[BB],
        *,
        hz: Optional[float] = None,
        period: Optional[float] = None,
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        if (hz is not None) and (period is not None):
            raise ValueError("RateLimit requires exactly one of (hz, period)")

        if period is not None:
            per = period
        elif hz is not None:
            frequency = float(hz)
            if not frequency > 0.0:
                # hz=0 would raise an opaque ZeroDivisionError. A negative rate
                # would be clamped to a 0s period, silently disabling throttling.
                raise ValueError(f"RateLimit hz must be > 0, got {hz!r}")
            per = 1.0 / frequency
        else:
            raise ValueError("RateLimit requires exactly one of (hz, period)")

        super().__init__(
            name or f"RateLimit({_name_of(child)},{per:.6f}s)",
            exception_policy=exception_policy,
        )
        self.child = child
        self.child.parent = self
        self._period = max(0.0, per)
        # Earliest tb.now() at which a new child run may start (None = any time).
        self._next_allowed: Optional[float] = None
        # Child's most recent status; RUNNING means a run is in progress.
        self._last: Optional[Status] = None

    def reset(self) -> None:
        super().reset()
        self._next_allowed = None
        self._last = None
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        if self._last is Status.RUNNING:
            # A run is already in progress: keep ticking it, never throttle.
            st = await self.child.tick(bb, tb)
            self._last = st
            if st is Status.RUNNING:
                return Status.RUNNING
            self._next_allowed = tb.now() + self._period
            return await self._finish(bb, st, tb)

        now = tb.now()
        if self._next_allowed is not None and now < self._next_allowed:
            return Status.RUNNING
        st = await self.child.tick(bb, tb)
        self._last = st
        if st is Status.RUNNING:
            return Status.RUNNING
        self._next_allowed = now + self._period
        return await self._finish(bb, st, tb)
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
class Gate(Node[BB]):
    """
    Guard a child with a condition, re-evaluated on every tick:
      SUCCESS → tick child
      RUNNING → RUNNING
      else    → FAILURE

    If the guard closes while the child is mid-activation (it returned RUNNING
    on an earlier tick), the child is reset. The next time the gate opens, the
    child starts fresh, with ``on_enter`` called again, instead of resuming
    stale state.
    """

    def __init__(
        self,
        condition: Node[BB],
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Gate(cond={_name_of(condition)}, child={_name_of(child)})",
            exception_policy=exception_policy,
        )
        self.condition = condition
        self.child = child
        self.condition.parent = self
        self.child.parent = self

    def reset(self) -> None:
        super().reset()
        self.condition.reset()
        self.child.reset()

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)
        c = await self.condition.tick(bb, tb)
        if c is Status.RUNNING:
            return Status.RUNNING
        if c is not Status.SUCCESS:
            # Node._finish clears _entered on any terminal status, so a still-set
            # flag means the child was abandoned mid-activation.
            if self.child._entered:
                self.child.reset()
            return await self._finish(bb, Status.FAILURE, tb)
        st = await self.child.tick(bb, tb)
        if st is Status.RUNNING:
            return Status.RUNNING
        return await self._finish(bb, st, tb)
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
class Match(Node[BB]):
    """
    Pattern-matching dispatch node.

    Evaluates a key function against the blackboard, then checks each case in
    order and ticks the first matching case's child:
    - If that child returns RUNNING, later ticks resume the same child without
      re-evaluating the key.
    - Any other child status (SUCCESS, FAILURE, ERROR, CANCELLED) is returned
      immediately.

    Cases can match by:
    - Type: isinstance(value, case_type)
    - Predicate: case_predicate(value) returns True
    - Value: value == case_value
    - Default: always matches (should be last)

    If no case matches and there's no default, returns FAILURE.

    Exceptions raised while evaluating the key function or a case predicate
    follow ``exception_policy``:
    - cancellation finishes the node with CANCELLED;
    - any other exception is logged, then re-raised under PROPAGATE, or
      reported as ERROR otherwise.
    """

    def __init__(
        self,
        key_fn: Callable[[Any], Any],
        cases: List[Tuple[Any, Node[BB]]],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"Match({_name_of(key_fn)})",
            exception_policy=exception_policy,
        )
        self._key_fn = key_fn
        self._cases = cases
        for _, child in self._cases:
            child.parent = self
        # Index of the case whose child is currently RUNNING, if any.
        self._matched_idx: Optional[int] = None

    def reset(self) -> None:
        super().reset()
        self._matched_idx = None
        for _, child in self._cases:
            child.reset()

    def _matches(self, matcher: Any, value: Any) -> bool:
        """Check if a matcher matches the given value."""
        if matcher is _DefaultCase:
            return True
        # Types are callable too, so the isinstance check must come first.
        if isinstance(matcher, type):
            return isinstance(value, matcher)
        if callable(matcher):
            return bool(matcher(value))
        return value == matcher

    def _select_case(self, value: Any) -> Optional[int]:
        """Return the index of the first case whose matcher accepts *value*, or None."""
        for i, (matcher, _) in enumerate(self._cases):
            if self._matches(matcher, value):
                return i
        return None

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        await self._ensure_entered(bb, tb)

        if self._matched_idx is not None:
            _, child = self._cases[self._matched_idx]
            st = await child.tick(bb, tb)
            if st is Status.RUNNING:
                return Status.RUNNING
            self._matched_idx = None
            return await self._finish(bb, st, tb)

        try:
            value = await _call_node_function(self._key_fn, bb, tb)
            chosen = self._select_case(value)
        except asyncio.CancelledError:
            return await self._finish(bb, Status.CANCELLED, tb)
        except Exception:
            logger.error(
                f"Exception in Match '{self.name}' while selecting a case:\n"
                f"{traceback.format_exc()}"
            )
            if self.exception_policy == ExceptionPolicy.PROPAGATE:
                raise
            return await self._finish(bb, Status.ERROR, tb)

        if chosen is None:
            return await self._finish(bb, Status.FAILURE, tb)

        _, child = self._cases[chosen]
        st = await child.tick(bb, tb)
        if st is Status.RUNNING:
            self._matched_idx = chosen
            return Status.RUNNING
        return await self._finish(bb, st, tb)
|
|
841
|
+
|
|
842
|
+
|
|
843
|
+
class DoWhile(Node[BB]):
    """
    Loop decorator that repeats its child while a condition is true.

    Per tick:
      1. Unless the body is mid-run, tick the condition:
         - RUNNING -> return RUNNING
         - not SUCCESS (condition false) -> finish with SUCCESS (loop complete)
      2. Tick the body:
         - RUNNING -> return RUNNING; later ticks resume the body without
           re-checking the condition
         - SUCCESS -> reset body and condition, return RUNNING so the
           condition is re-checked next tick (at most one iteration per tick,
           which prevents an infinite loop within a single tick)
         - anything else -> finish with FAILURE (loop aborted)
    """

    def __init__(
        self,
        condition: Node[BB],
        child: Node[BB],
        name: Optional[str] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        super().__init__(
            name or f"DoWhile(cond={_name_of(condition)}, child={_name_of(child)})",
            exception_policy=exception_policy,
        )
        self.condition = condition
        self.child = child
        for sub in (self.condition, self.child):
            sub.parent = self
        # True while the body returned RUNNING on an earlier tick.
        self._child_running = False

    def reset(self) -> None:
        """Reset this node, the condition and the body, and clear the mid-run flag."""
        super().reset()
        self.condition.reset()
        self.child.reset()
        self._child_running = False

    async def _after_child(self, bb: BB, tb: Timebase, st: Status) -> Status:
        """Translate one body result into this node's status for the current tick."""
        if st is Status.RUNNING:
            self._child_running = True
            return Status.RUNNING
        self._child_running = False
        if st is not Status.SUCCESS:
            # Any non-SUCCESS terminal result aborts the loop.
            return await self._finish(bb, Status.FAILURE, tb)
        # Iteration finished: re-arm both nodes and loop on the next tick.
        self.child.reset()
        self.condition.reset()
        return Status.RUNNING

    async def tick(self, bb: BB, tb: Timebase) -> Status:
        """Check the condition (unless the body is mid-run), then tick the body."""
        await self._ensure_entered(bb, tb)

        if not self._child_running:
            cond_status = await self.condition.tick(bb, tb)
            if cond_status is Status.RUNNING:
                return Status.RUNNING
            if cond_status is not Status.SUCCESS:
                # Condition is false: the loop is complete.
                return await self._finish(bb, Status.SUCCESS, tb)

        body_status = await self.child.tick(bb, tb)
        return await self._after_child(bb, tb, body_status)
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
# ======================================================================================
|
|
922
|
+
# Authoring DSL (NodeSpec + bt namespace)
|
|
923
|
+
# ======================================================================================
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
class NodeSpecKind(Enum):
    """Discriminator for the kinds of node a ``NodeSpec`` can describe."""

    # Leaves wrapping a single user function.
    ACTION = "action"
    CONDITION = "condition"
    # Composites whose children come from a generator factory.
    SEQUENCE = "sequence"
    SELECTOR = "selector"
    PARALLEL = "parallel"
    # Wrappers, mounted trees and control-flow constructs.
    DECORATOR = "decorator"
    SUBTREE = "subtree"
    MATCH = "match"
    DO_WHILE = "do_while"
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
class _DefaultCase:
|
|
939
|
+
"""Sentinel for default case in match expressions."""
|
|
940
|
+
pass
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
@dataclass
class CaseSpec:
    """
    One ``matcher -> child`` arm of a match expression.

    ``label`` names the arm (for example as a Mermaid edge label). When left
    empty it is derived from the matcher:
      - ``"default"`` for the default sentinel
      - the class name for types
      - the callable's display name for predicates
      - ``repr()`` for plain values
    """

    matcher: Any
    child: "NodeSpec"
    label: str = ""

    def __post_init__(self):
        if self.label:
            return
        m = self.matcher
        if m is _DefaultCase:
            derived = "default"
        elif isinstance(m, type):
            derived = m.__name__
        elif callable(m):
            derived = _name_of(m)
        else:
            derived = repr(m)
        self.label = derived
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
@dataclass
|
|
963
|
+
class NodeSpec:
|
|
964
|
+
kind: NodeSpecKind
|
|
965
|
+
name: str
|
|
966
|
+
payload: Any = None
|
|
967
|
+
children: List["NodeSpec"] = field(default_factory=list)
|
|
968
|
+
|
|
969
|
+
def __hash__(self):
|
|
970
|
+
return hash((self.kind, self.name, id(self.payload), tuple(self.children)))
|
|
971
|
+
|
|
972
|
+
def to_node(
|
|
973
|
+
self,
|
|
974
|
+
owner: Optional[Any] = None,
|
|
975
|
+
exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
|
|
976
|
+
) -> Node[Any]:
|
|
977
|
+
match self.kind:
|
|
978
|
+
case NodeSpecKind.ACTION:
|
|
979
|
+
return Action(self.payload, exception_policy=exception_policy)
|
|
980
|
+
case NodeSpecKind.CONDITION:
|
|
981
|
+
return Condition(self.payload, exception_policy=exception_policy)
|
|
982
|
+
case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
|
|
983
|
+
owner_effective = getattr(self, "owner", None) or owner
|
|
984
|
+
factory = self.payload["factory"]
|
|
985
|
+
expanded = _bt_expand_children(factory, owner_effective)
|
|
986
|
+
self.children = expanded
|
|
987
|
+
built = [
|
|
988
|
+
ch.to_node(owner_effective, exception_policy) for ch in expanded
|
|
989
|
+
]
|
|
990
|
+
match self.kind:
|
|
991
|
+
case NodeSpecKind.SEQUENCE:
|
|
992
|
+
return Sequence(
|
|
993
|
+
built,
|
|
994
|
+
memory=self.payload.get("memory", True),
|
|
995
|
+
name=self.name,
|
|
996
|
+
exception_policy=exception_policy,
|
|
997
|
+
)
|
|
998
|
+
case NodeSpecKind.SELECTOR:
|
|
999
|
+
return Selector(
|
|
1000
|
+
built,
|
|
1001
|
+
memory=self.payload.get("memory", True),
|
|
1002
|
+
name=self.name,
|
|
1003
|
+
exception_policy=exception_policy,
|
|
1004
|
+
)
|
|
1005
|
+
case NodeSpecKind.PARALLEL:
|
|
1006
|
+
return Parallel(
|
|
1007
|
+
built,
|
|
1008
|
+
success_threshold=self.payload["success_threshold"],
|
|
1009
|
+
failure_threshold=self.payload.get("failure_threshold"),
|
|
1010
|
+
name=self.name,
|
|
1011
|
+
exception_policy=exception_policy,
|
|
1012
|
+
)
|
|
1013
|
+
|
|
1014
|
+
case NodeSpecKind.DECORATOR:
|
|
1015
|
+
assert len(self.children) == 1, "Decorator must wrap exactly one child"
|
|
1016
|
+
child_node = self.children[0].to_node(owner, exception_policy)
|
|
1017
|
+
builder = self.payload
|
|
1018
|
+
return builder(child_node)
|
|
1019
|
+
|
|
1020
|
+
case NodeSpecKind.SUBTREE:
|
|
1021
|
+
subtree_root = self.payload["root"]
|
|
1022
|
+
return subtree_root.to_node(exception_policy=exception_policy)
|
|
1023
|
+
|
|
1024
|
+
case NodeSpecKind.MATCH:
|
|
1025
|
+
key_fn = self.payload["key_fn"]
|
|
1026
|
+
case_specs: List[CaseSpec] = self.payload["cases"]
|
|
1027
|
+
cases = [
|
|
1028
|
+
(cs.matcher, cs.child.to_node(owner, exception_policy))
|
|
1029
|
+
for cs in case_specs
|
|
1030
|
+
]
|
|
1031
|
+
return Match(
|
|
1032
|
+
key_fn,
|
|
1033
|
+
cases,
|
|
1034
|
+
name=self.name,
|
|
1035
|
+
exception_policy=exception_policy,
|
|
1036
|
+
)
|
|
1037
|
+
|
|
1038
|
+
case NodeSpecKind.DO_WHILE:
|
|
1039
|
+
cond_spec = self.payload["condition"]
|
|
1040
|
+
child_spec = self.children[0]
|
|
1041
|
+
return DoWhile(
|
|
1042
|
+
cond_spec.to_node(owner, exception_policy),
|
|
1043
|
+
child_spec.to_node(owner, exception_policy),
|
|
1044
|
+
name=self.name,
|
|
1045
|
+
exception_policy=exception_policy,
|
|
1046
|
+
)
|
|
1047
|
+
|
|
1048
|
+
case _:
|
|
1049
|
+
raise ValueError(f"Unknown spec kind: {self.kind}")
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def _spec_has_factory(spec: Any) -> bool:
    """Return True when ``spec`` carries a composite generator factory in its payload."""
    payload = getattr(spec, "payload", None)
    return isinstance(payload, dict) and "factory" in payload


def _bt_expand_children(
    factory: Callable[..., Generator[Any, None, None]],
    owner: Optional[Any],
    expansion_stack: Optional[Union[Set[str], Tuple[str, ...]]] = None,
) -> List[NodeSpec]:
    """
    Execute a composite factory and return the child specs it yields.

    Nested composite children are also expanded eagerly, with their results
    discarded. This is done only so that cycles are reported here, at build
    time, instead of surfacing as unbounded recursion later.

    Args:
        factory: generator function yielding child nodes (functions carrying a
            ``node_spec``, or ``NodeSpec`` objects) or lists/tuples of them.
        owner: namespace passed to the factory for resolving ``N.<name>``
            references. Factories that take no argument are called without it.
        expansion_stack: names of the factories currently being expanded,
            outermost first. A set is still accepted, but loses the ordering of
            the cycle report.

    Returns:
        The child specs, in the order yielded.

    Raises:
        RecursionError: if ``factory`` is already being expanded (a cycle).
        TypeError: if ``factory`` does not return a generator.
    """
    # Use an ordered tuple (outermost first) so the cycle report reads in call
    # order. Joining a set printed the chain in arbitrary order.
    stack: Tuple[str, ...] = tuple(expansion_stack) if expansion_stack else ()
    factory_name = _name_of(factory)

    # NOTE(review): cycles are detected by display name, so two distinct
    # factories that share a name are treated as the same composite.
    if factory_name in stack:
        chain = " -> ".join(stack + (factory_name,))
        raise RecursionError(
            f"Recursive behavior tree structure detected: {chain}\n"
            "Behavior trees must be acyclic. Consider using:\n"
            " - A Selector with memory to iterate through options\n"
            " - A Sequence with conditions to control flow\n"
            " - A Retry decorator for repeated attempts\n"
            " - State in the blackboard to track progress"
        )
    stack = stack + (factory_name,)

    # Factories normally take the owner namespace (``def root(N): ...``), but
    # zero-argument factories are accepted too.
    try:
        gen = factory(owner)
    except TypeError:
        gen = factory()

    if not inspect.isgenerator(gen):
        # getattr: not every callable (e.g. functools.partial) has __name__.
        raise TypeError(
            f"Composite factory {getattr(factory, '__name__', factory_name)} "
            "must be a generator (use 'yield')."
        )

    out: List[NodeSpec] = []
    for yielded in gen:
        # A factory may yield a single node or a list/tuple of nodes.
        batch = yielded if isinstance(yielded, (list, tuple)) else (yielded,)
        for item in batch:
            spec = bt.as_spec(item)
            if _spec_has_factory(spec):
                # Still recorded on the spec, as before. The recursion below
                # uses the local ``stack`` instead, so a later overwrite of this
                # shared attribute cannot skew the cycle check.
                spec._expansion_stack = stack  # type: ignore
            out.append(spec)

    for spec in out:
        if spec.kind in (
            NodeSpecKind.SEQUENCE,
            NodeSpecKind.SELECTOR,
            NodeSpecKind.PARALLEL,
        ) and _spec_has_factory(spec):
            child_owner = getattr(spec, "owner", None) or owner
            _bt_expand_children(spec.payload["factory"], child_owner, stack)

    return out
|
|
1126
|
+
|
|
1127
|
+
|
|
1128
|
+
# --------------------------------------------------------------------------------------
|
|
1129
|
+
# Fluent wrapper chain (left-to-right)
|
|
1130
|
+
# --------------------------------------------------------------------------------------
|
|
1131
|
+
|
|
1132
|
+
|
|
1133
|
+
class _WrapperChain:
    """
    Fluent factory for decorator stacks that read left→right.

        chain = bt.failer().gate(battery_ok).timeout(0.12)
        yield chain(engage)

    The first wrapper named is the outermost. Every fluent method returns a NEW
    chain and leaves the receiver untouched, so a partial chain can be shared
    as a common prefix:

        guarded = bt.gate(battery_ok)
        yield guarded.timeout(1.0)(engage)   # Gate(Timeout(engage))
        yield guarded(retreat)               # Gate(retreat), no Timeout
    """

    def __init__(
        self,
        builders: Optional[List[Callable[..., Any]]] = None,
        labels: Optional[List[str]] = None,
    ) -> None:
        # Parallel lists, outermost wrapper first. Each builder maps a child
        # node to the node that wraps it.
        self._builders: List[Callable[..., Any]] = list(builders or [])
        self._labels: List[str] = list(labels or [])

    def _append(self, label: str, builder: Callable[..., Any]) -> "_WrapperChain":
        """Return a copy of this chain with ``builder`` added as the innermost wrapper."""
        # Copy rather than mutate. Returning a mutated ``self`` would make a
        # shared prefix chain pick up wrappers added by later, unrelated calls.
        return type(self)(self._builders + [builder], self._labels + [label])

    def failer(self) -> "_WrapperChain":
        return self._append("Failer", lambda ch: Failer(ch))

    def succeeder(self) -> "_WrapperChain":
        return self._append("Succeeder", lambda ch: Succeeder(ch))

    def inverter(self) -> "_WrapperChain":
        return self._append("Inverter", lambda ch: Inverter(ch))

    def timeout(self, seconds: float) -> "_WrapperChain":
        return self._append(
            f"Timeout({seconds}s)", lambda ch: Timeout(ch, seconds=seconds)
        )

    def retry(
        self,
        max_attempts: int,
        retry_on: Tuple[Status, ...] = (Status.FAILURE, Status.ERROR),
    ) -> "_WrapperChain":
        return self._append(
            f"Retry({max_attempts})",
            lambda ch: Retry(ch, max_attempts=max_attempts, retry_on=retry_on),
        )

    def ratelimit(
        self, *, hz: Optional[float] = None, period: Optional[float] = None
    ) -> "_WrapperChain":
        label = "RateLimit(?)"
        if hz is not None:
            label = f"RateLimit({1.0/float(hz):.6f}s)"
        elif period is not None:
            label = f"RateLimit({float(period):.6f}s)"
        return self._append(label, lambda ch: RateLimit(ch, hz=hz, period=period))

    def gate(
        self, condition_spec_or_fn: Union["NodeSpec", Callable[[Any], Any]]
    ) -> "_WrapperChain":
        cond_spec = bt.as_spec(condition_spec_or_fn)
        return self._append(
            f"Gate(cond={_name_of(cond_spec)})",
            lambda ch: Gate(cond_spec.to_node(), ch),
        )

    def __call__(self, inner: Union["NodeSpec", Callable[[Any], Any]]) -> "NodeSpec":
        """
        Apply the chain to a child spec, producing nested decorator NodeSpecs.

        Left→right call order becomes outermost→…→innermost when built.
        """
        spec = bt.as_spec(inner)
        result = spec
        for label, builder in reversed(list(zip(self._labels, self._builders))):
            result = NodeSpec(
                kind=NodeSpecKind.DECORATOR,
                name=f"{label}({_name_of(result)})",
                payload=builder,
                children=[result],
            )
        return result

    def __rshift__(self, inner: Union["NodeSpec", Callable[[Any], Any]]) -> "NodeSpec":
        """``chain >> child`` is shorthand for ``chain(child)``."""
        return self(inner)
|
|
1215
|
+
|
|
1216
|
+
|
|
1217
|
+
# --------------------------------------------------------------------------------------
|
|
1218
|
+
# Match expression builders
|
|
1219
|
+
# --------------------------------------------------------------------------------------
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
class _CaseBuilder:
    """
    Partially applied ``bt.case(matcher)``.

    Calling it with a child node produces a ``CaseSpec`` for that matcher.
    """

    def __init__(self, matcher: Any) -> None:
        self._matcher = matcher

    def __call__(self, child: Union["NodeSpec", Callable[[Any], Any]]) -> CaseSpec:
        return CaseSpec(matcher=self._matcher, child=bt.as_spec(child))
|
|
1231
|
+
|
|
1232
|
+
|
|
1233
|
+
class _MatchBuilder:
    """
    Partially applied ``bt.match(key_fn, name=...)``.

    Calling it with case specs produces a MATCH ``NodeSpec``.
    """

    def __init__(self, key_fn: Callable[[Any], Any], name: Optional[str] = None) -> None:
        self._key_fn = key_fn
        self._name = name

    def __call__(self, *cases: CaseSpec) -> NodeSpec:
        """
        Assemble the MATCH spec from ``cases``, kept in the given order.

        Raises:
            ValueError: if no cases are supplied.
            TypeError: if an argument is not a ``CaseSpec``.
        """
        if not cases:
            raise ValueError("bt.match() requires at least one case")

        invalid = [c for c in cases if not isinstance(c, CaseSpec)]
        if invalid:
            raise TypeError(
                f"bt.match() expects CaseSpec instances (from bt.case() or bt.defaultcase()), "
                f"got {type(invalid[0]).__name__}"
            )

        case_list = list(cases)
        arm_children = [cs.child for cs in case_list]
        label = self._name or _name_of(self._key_fn)

        return NodeSpec(
            kind=NodeSpecKind.MATCH,
            name=f"Match({label})",
            payload={"key_fn": self._key_fn, "cases": case_list},
            children=arm_children,
        )
|
|
1264
|
+
|
|
1265
|
+
|
|
1266
|
+
class _DoWhileBuilder:
    """
    Partially applied ``bt.do_while(condition)``.

    Calling it with the loop body produces a DO_WHILE ``NodeSpec``. The
    condition is stored in the payload; the body is the single child.
    """

    def __init__(self, condition_spec: NodeSpec) -> None:
        self._condition_spec = condition_spec

    def __call__(self, child: Union["NodeSpec", Callable[[Any], Any]]) -> NodeSpec:
        body = bt.as_spec(child)
        cond = self._condition_spec
        return NodeSpec(
            kind=NodeSpecKind.DO_WHILE,
            name=f"DoWhile({_name_of(cond)})",
            payload={"condition": cond},
            children=[body],
        )
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
# --------------------------------------------------------------------------------------
|
|
1285
|
+
# User-facing decorator/constructor namespace
|
|
1286
|
+
# --------------------------------------------------------------------------------------
|
|
1287
|
+
|
|
1288
|
+
|
|
1289
|
+
class _BT:
    """
    User-facing decorator/constructor namespace for authoring behavior trees.

    - Leaf and composite decorators attach a ``NodeSpec`` to the decorated
      function as ``fn.node_spec``.
    - Wrapper helpers return fluent ``_WrapperChain`` objects.
    - ``tree`` collects the nodes declared inside a function into a namespace.
    """

    # TODO: Fix decorators to not throw a fit when passing timebase arguments to actions/conditions
    def __init__(self):
        # One list per ``@bt.tree`` body currently executing (innermost last).
        # Node decorators append ``(name, fn)`` to the innermost list.
        self._tracking_stack: List[List[Tuple[str, Any]]] = []

    def _register(self, fn: Callable[..., Any], spec: NodeSpec) -> Callable[..., Any]:
        """Attach ``spec`` to ``fn``, record ``fn`` in the innermost tree being built, return ``fn``."""
        fn.node_spec = spec  # type: ignore
        if self._tracking_stack:
            self._tracking_stack[-1].append((fn.__name__, fn))
        return fn

    def action(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator to mark a function as an action node."""
        spec = NodeSpec(kind=NodeSpecKind.ACTION, name=_name_of(fn), payload=fn)
        return self._register(fn, spec)

    def condition(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator to mark a function as a condition node."""
        spec = NodeSpec(kind=NodeSpecKind.CONDITION, name=_name_of(fn), payload=fn)
        return self._register(fn, spec)

    def sequence(self, *, memory: bool = True):
        """Decorator to mark a generator function as a sequence composite."""

        def deco(
            factory: Callable[..., Generator[Any, None, None]],
        ) -> Callable[..., Generator[Any, None, None]]:
            spec = NodeSpec(
                kind=NodeSpecKind.SEQUENCE,
                name=_name_of(factory),
                payload={"factory": factory, "memory": memory},
            )
            return self._register(factory, spec)

        return deco

    def selector(self, *, memory: bool = True, reactive: bool = False):
        """Decorator to mark a generator function as a selector composite."""

        def deco(
            factory: Callable[..., Generator[Any, None, None]],
        ) -> Callable[..., Generator[Any, None, None]]:
            spec = NodeSpec(
                kind=NodeSpecKind.SELECTOR,
                name=_name_of(factory),
                payload={"factory": factory, "memory": memory, "reactive": reactive},
            )
            return self._register(factory, spec)

        return deco

    def parallel(
        self, *, success_threshold: int, failure_threshold: Optional[int] = None
    ):
        """Decorator to mark a generator function as a parallel composite."""

        def deco(
            factory: Callable[..., Generator[Any, None, None]],
        ) -> Callable[..., Generator[Any, None, None]]:
            spec = NodeSpec(
                kind=NodeSpecKind.PARALLEL,
                name=_name_of(factory),
                payload={
                    "factory": factory,
                    "success_threshold": success_threshold,
                    "failure_threshold": failure_threshold,
                },
            )
            return self._register(factory, spec)

        return deco

    def inverter(self) -> _WrapperChain:
        return _WrapperChain().inverter()

    def succeeder(self) -> _WrapperChain:
        return _WrapperChain().succeeder()

    def failer(self) -> _WrapperChain:
        return _WrapperChain().failer()

    def timeout(self, seconds: float) -> _WrapperChain:
        return _WrapperChain().timeout(seconds)

    def retry(
        self,
        max_attempts: int,
        retry_on: Tuple[Status, ...] = (Status.FAILURE, Status.ERROR),
    ) -> _WrapperChain:
        return _WrapperChain().retry(max_attempts, retry_on=retry_on)

    def ratelimit(
        self, hz: Optional[float] = None, period: Optional[float] = None
    ) -> _WrapperChain:
        return _WrapperChain().ratelimit(hz=hz, period=period)

    def gate(self, condition: Union[NodeSpec, Callable[[Any], Any]]) -> _WrapperChain:
        cond_spec = self.as_spec(condition)
        return _WrapperChain().gate(cond_spec)

    def match(
        self, key_fn: Callable[[Any], Any], name: Optional[str] = None
    ) -> "_MatchBuilder":
        """
        Create a pattern-matching dispatch node.

        Usage:
            bt.match(lambda bb: bb.current_action, name="action_type")(
                bt.case(ImageAction)(bt.subtree(handle_image)),
                bt.case(MoveAction)(bt.subtree(handle_move)),
                bt.case(lambda a: a.priority > 5)(bt.subtree(handle_urgent)),
                bt.defaultcase(log_unknown),
            )

        Args:
            key_fn: Function that extracts the value to match against from the blackboard
            name: Optional display name for the match node (useful when key_fn is a lambda)

        Returns:
            A builder that accepts case specs and returns a NodeSpec
        """
        return _MatchBuilder(key_fn, name=name)

    def case(self, matcher: Any) -> "_CaseBuilder":
        """
        Define a case for a match expression.

        Args:
            matcher: Can be:
                - A type: matches via isinstance(value, matcher)
                - A callable: matches if matcher(value) returns True
                - Any other value: matches via value == matcher

        Returns:
            A builder that accepts a child node spec
        """
        return _CaseBuilder(matcher)

    def defaultcase(self, child: Union[NodeSpec, Callable[[Any], Any]]) -> CaseSpec:
        """
        Define a default case for a match expression (matches anything).

        Args:
            child: The node to execute if this case matches

        Returns:
            A CaseSpec that always matches
        """
        child_spec = self.as_spec(child)
        return CaseSpec(matcher=_DefaultCase, child=child_spec, label="default")

    def do_while(
        self, condition: Union[NodeSpec, Callable[[Any], Any]]
    ) -> "_DoWhileBuilder":
        """
        Create a loop that repeats its child while a condition is true.

        Usage:
            @bt.condition
            def samples_remain(bb):
                return bb.sample_index < bb.total_samples

            yield bt.do_while(samples_remain)(process_sample)

        Behavior:
            1. Evaluate condition
            2. If condition is FALSE → return SUCCESS (loop complete)
            3. If condition is TRUE → tick child
               - If child returns RUNNING → return RUNNING (resume child next tick)
               - If child returns SUCCESS → reset child, return RUNNING (re-check next tick)
               - If child returns FAILURE → return FAILURE (loop aborted)

        Args:
            condition: A condition node or function to evaluate each iteration

        Returns:
            A builder that accepts a child node spec
        """
        cond_spec = self.as_spec(condition)
        return _DoWhileBuilder(cond_spec)

    def subtree(self, tree: SimpleNamespace) -> NodeSpec:
        """
        Mount another tree's root spec as a subtree.

        Args:
            tree: A tree namespace created with @bt.tree

        Raises:
            ValueError: if ``tree`` has no ``root`` (no composite marked @bt.root).
        """
        if not hasattr(tree, "root"):
            raise ValueError(
                "Tree namespace must have a 'root' attribute. Did you forget @bt.root on a composite?"
            )

        root_spec = tree.root
        tree_name = getattr(tree, "_tree_name", root_spec.name)
        return NodeSpec(
            kind=NodeSpecKind.SUBTREE,
            name=tree_name,
            payload={"root": root_spec},
            children=[root_spec],
        )

    def tree(self, fn: Callable[[], Any]) -> SimpleNamespace:
        """
        Decorator to create a behavior tree namespace.

        ``fn`` is called once. Every node decorated inside it becomes an
        attribute of the returned namespace, and the composite marked with
        ``@bt.root`` becomes ``namespace.root``.

        Usage:
            @bt.tree
            def MyTree():
                @bt.action
                def my_action(bb: BB) -> Status:
                    ...

                @bt.root
                @bt.sequence()
                def root(N):
                    yield N.my_action
        """
        created_nodes: List[Tuple[str, Any]] = []
        self._tracking_stack.append(created_nodes)

        try:
            fn()
        finally:
            self._tracking_stack.pop()

        nodes = {
            name: node for name, node in created_nodes if hasattr(node, "node_spec")
        }
        namespace = SimpleNamespace(**nodes)

        # Store the tree's name for use in subtree references
        namespace._tree_name = fn.__name__

        # Composite factories resolve ``N.<name>`` against this namespace.
        for node in nodes.values():
            node.node_spec.owner = namespace

        # If several composites are marked @bt.root, the first in ``nodes``
        # order (declaration order) is used.
        root_nodes = [v for v in nodes.values() if hasattr(v.node_spec, "is_root")]
        if root_nodes:
            namespace.root = root_nodes[0].node_spec

        namespace.to_mermaid = lambda: _generate_mermaid(namespace)

        return namespace

    def root(
        self, fn: Callable[..., Generator[Any, None, None]]
    ) -> Callable[..., Generator[Any, None, None]]:
        """
        Mark a composite as the root of the tree.

        Must be applied above the composite decorator (``@bt.root`` and then
        ``@bt.sequence()``), because it flags the ``node_spec`` that decorator
        attaches.

        Raises:
            TypeError: if ``fn`` has no ``node_spec`` (e.g. wrong decorator order).
        """
        spec = getattr(fn, "node_spec", None)
        if spec is None:
            raise TypeError(
                f"@bt.root applied to {fn!r}, which is not a BT node (missing "
                "node_spec attribute). Place @bt.root above the composite "
                "decorator, e.g. @bt.root then @bt.sequence()."
            )
        spec.is_root = True  # type: ignore
        return fn

    def as_spec(self, maybe: Union[NodeSpec, Callable[[Any], Any]]) -> NodeSpec:
        """
        Convert a function or NodeSpec to a NodeSpec.

        Raises:
            TypeError: if ``maybe`` is neither a NodeSpec nor a decorated BT node.
        """
        if isinstance(maybe, NodeSpec):
            return maybe

        spec = getattr(maybe, "node_spec", None)
        if spec is None:
            raise TypeError(
                f"{maybe!r} is not a BT node (missing node_spec attribute)."
            )
        return spec
|
|
1563
|
+
|
|
1564
|
+
|
|
1565
|
+
# Module-level singleton through which the authoring DSL is used
# (``@bt.action``, ``@bt.sequence()``, ``bt.match(...)``, ``@bt.tree``, ...).
# Helpers in this module (e.g. child expansion, wrapper chains, builders) also
# call ``bt.as_spec`` through this global.
bt = _BT()
|
|
1566
|
+
|
|
1567
|
+
|
|
1568
|
+
# ======================================================================================
|
|
1569
|
+
# Mermaid Generation
|
|
1570
|
+
# ======================================================================================
|
|
1571
|
+
|
|
1572
|
+
|
|
1573
|
+
def _generate_mermaid(tree: SimpleNamespace) -> str:
|
|
1574
|
+
"""
|
|
1575
|
+
Render a static structure graph for a tree namespace.
|
|
1576
|
+
"""
|
|
1577
|
+
if not hasattr(tree, "root"):
|
|
1578
|
+
raise ValueError("Tree namespace must have a 'root' attribute")
|
|
1579
|
+
|
|
1580
|
+
lines: List[str] = ["flowchart TD"]
|
|
1581
|
+
node_ids: Dict[NodeSpec, str] = {}
|
|
1582
|
+
counter = 0
|
|
1583
|
+
|
|
1584
|
+
def nid(spec: NodeSpec) -> str:
|
|
1585
|
+
nonlocal counter
|
|
1586
|
+
if spec in node_ids:
|
|
1587
|
+
return node_ids[spec]
|
|
1588
|
+
counter += 1
|
|
1589
|
+
node_ids[spec] = f"N{counter}"
|
|
1590
|
+
return node_ids[spec]
|
|
1591
|
+
|
|
1592
|
+
def label(spec: NodeSpec) -> str:
|
|
1593
|
+
match spec.kind:
|
|
1594
|
+
case NodeSpecKind.ACTION | NodeSpecKind.CONDITION:
|
|
1595
|
+
return f"{spec.kind.value.upper()}<br/>{spec.name}"
|
|
1596
|
+
case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
|
|
1597
|
+
return f"{spec.kind.value.capitalize()}<br/>{spec.name}"
|
|
1598
|
+
case NodeSpecKind.DECORATOR:
|
|
1599
|
+
return f"Decor<br/>{spec.name}"
|
|
1600
|
+
case NodeSpecKind.SUBTREE:
|
|
1601
|
+
return f"Subtree<br/>{spec.name}"
|
|
1602
|
+
case NodeSpecKind.MATCH:
|
|
1603
|
+
return f"Match<br/>{spec.name}"
|
|
1604
|
+
case NodeSpecKind.DO_WHILE:
|
|
1605
|
+
return f"DoWhile<br/>{spec.name}"
|
|
1606
|
+
case _:
|
|
1607
|
+
return spec.name
|
|
1608
|
+
|
|
1609
|
+
def ensure_children(spec: NodeSpec) -> List[NodeSpec]:
|
|
1610
|
+
match spec.kind:
|
|
1611
|
+
case NodeSpecKind.SEQUENCE | NodeSpecKind.SELECTOR | NodeSpecKind.PARALLEL:
|
|
1612
|
+
owner = getattr(spec, "owner", None) or tree
|
|
1613
|
+
factory = spec.payload["factory"]
|
|
1614
|
+
spec.children = _bt_expand_children(factory, owner)
|
|
1615
|
+
case NodeSpecKind.SUBTREE:
|
|
1616
|
+
subtree_root = spec.payload["root"]
|
|
1617
|
+
spec.children = [subtree_root]
|
|
1618
|
+
case NodeSpecKind.MATCH:
|
|
1619
|
+
case_specs: List[CaseSpec] = spec.payload["cases"]
|
|
1620
|
+
spec.children = [cs.child for cs in case_specs]
|
|
1621
|
+
case NodeSpecKind.DO_WHILE:
|
|
1622
|
+
# Children already set (just the body), but we also want to show condition
|
|
1623
|
+
cond_spec = spec.payload["condition"]
|
|
1624
|
+
spec.children = [cond_spec] + spec.children
|
|
1625
|
+
return spec.children
|
|
1626
|
+
|
|
1627
|
+
def walk(spec: NodeSpec) -> None:
|
|
1628
|
+
this_id = nid(spec)
|
|
1629
|
+
shape = (
|
|
1630
|
+
"((%s))"
|
|
1631
|
+
if spec.kind in (NodeSpecKind.ACTION, NodeSpecKind.CONDITION)
|
|
1632
|
+
else '["%s"]'
|
|
1633
|
+
) % label(spec)
|
|
1634
|
+
lines.append(f" {this_id}{shape}")
|
|
1635
|
+
|
|
1636
|
+
children = ensure_children(spec)
|
|
1637
|
+
|
|
1638
|
+
if spec.kind == NodeSpecKind.MATCH:
|
|
1639
|
+
case_specs: List[CaseSpec] = spec.payload["cases"]
|
|
1640
|
+
for case_spec, child in zip(case_specs, children):
|
|
1641
|
+
child_id = nid(child)
|
|
1642
|
+
edge_label = case_spec.label.replace('"', "'")
|
|
1643
|
+
lines.append(f' {this_id} -->|"{edge_label}"| {child_id}')
|
|
1644
|
+
walk(child)
|
|
1645
|
+
elif spec.kind == NodeSpecKind.DO_WHILE:
|
|
1646
|
+
# First child is condition, second is body
|
|
1647
|
+
cond_id = nid(children[0])
|
|
1648
|
+
lines.append(f' {this_id} -->|"condition"| {cond_id}')
|
|
1649
|
+
walk(children[0])
|
|
1650
|
+
if len(children) > 1:
|
|
1651
|
+
body_id = nid(children[1])
|
|
1652
|
+
lines.append(f' {this_id} -->|"body"| {body_id}')
|
|
1653
|
+
walk(children[1])
|
|
1654
|
+
else:
|
|
1655
|
+
for child in children:
|
|
1656
|
+
child_id = nid(child)
|
|
1657
|
+
lines.append(f" {this_id} --> {child_id}")
|
|
1658
|
+
walk(child)
|
|
1659
|
+
|
|
1660
|
+
walk(tree.root)
|
|
1661
|
+
return "\n".join(lines)
|
|
1662
|
+
|
|
1663
|
+
|
|
1664
|
+
# ======================================================================================
|
|
1665
|
+
# Runner
|
|
1666
|
+
# ======================================================================================
|
|
1667
|
+
|
|
1668
|
+
|
|
1669
|
+
class Runner(Generic[BB]):
    """Runtime for executing behavior trees.

    The Runner manages the execution of a behavior tree, handling tick
    calls, timebase management, and result processing.

    Args:
        tree: A behavior tree namespace with a 'root' attribute (from @bt.tree)
        bb: Blackboard containing shared state
        tb: Optional timebase for time management (defaults to MonotonicClock)
        exception_policy: How to handle exceptions during tree execution

    Methods:
        tick(): Execute one tick of the behavior tree
        tick_until_complete(): Run the tree until it returns a terminal status

    Example:
        @bt.tree
        def MyTree():
            @bt.root
            @bt.sequence
            def root():
                yield check_condition
                yield do_work

        runner = Runner(MyTree, bb=blackboard)
        result = await runner.tick_until_complete()
    """

    def __init__(
        self,
        tree: SimpleNamespace,
        bb: BB,
        tb: Optional[Timebase] = None,
        exception_policy: ExceptionPolicy = ExceptionPolicy.LOG_AND_CONTINUE,
    ) -> None:
        """Build the runtime node graph for ``tree``.

        Raises:
            ValueError: If ``tree`` has no ``root`` attribute.
        """
        # Validate before assigning any attributes so a bad tree fails fast.
        if not hasattr(tree, "root"):
            raise ValueError("Tree namespace must have a 'root' attribute")

        self.tree = tree
        self.bb: BB = bb
        # Explicit None check. `tb or MonotonicClock()` would silently replace
        # a caller-supplied timebase whose truth value happens to be False.
        self.tb = tb if tb is not None else MonotonicClock()
        self.exception_policy = exception_policy

        self.root: Node[BB] = tree.root.to_node(
            owner=tree, exception_policy=exception_policy
        )

    async def tick(self) -> Status:
        """Tick the root node once with the blackboard and timebase, then advance the timebase.

        Returns:
            The status the root node produced for this tick.
        """
        result = await self.root.tick(self.bb, self.tb)
        self.tb.advance()
        return result

    async def tick_until_complete(self, timeout: Optional[float] = None) -> Status:
        """Tick repeatedly until the tree reaches a terminal status.

        Between non-terminal ticks, control is yielded to the event loop with
        ``asyncio.sleep(0)`` so other tasks can run.

        Args:
            timeout: Maximum elapsed time, measured as the difference of
                ``self.tb.now()`` readings, before giving up. ``None`` means
                no limit. The tree is always ticked at least once.

        Returns:
            SUCCESS, FAILURE, CANCELLED or ERROR as returned by the tree.
            Returns CANCELLED if ``timeout`` elapsed first. This method only
            returns; it does not cancel or reset nodes that are still running.
        """
        terminal = (Status.SUCCESS, Status.FAILURE, Status.CANCELLED, Status.ERROR)
        start = self.tb.now()
        while True:
            st = await self.tick()
            if st in terminal:
                return st
            if timeout is not None and (self.tb.now() - start) > timeout:
                return Status.CANCELLED
            await asyncio.sleep(0)
|