mycorrhizal 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mycorrhizal/__init__.py +3 -0
- mycorrhizal/common/__init__.py +68 -0
- mycorrhizal/common/interface_builder.py +203 -0
- mycorrhizal/common/interfaces.py +412 -0
- mycorrhizal/common/timebase.py +99 -0
- mycorrhizal/common/wrappers.py +532 -0
- mycorrhizal/enoki/__init__.py +0 -0
- mycorrhizal/enoki/core.py +1545 -0
- mycorrhizal/enoki/testing_utils.py +529 -0
- mycorrhizal/enoki/util.py +220 -0
- mycorrhizal/hypha/__init__.py +0 -0
- mycorrhizal/hypha/core/__init__.py +107 -0
- mycorrhizal/hypha/core/builder.py +404 -0
- mycorrhizal/hypha/core/runtime.py +890 -0
- mycorrhizal/hypha/core/specs.py +234 -0
- mycorrhizal/hypha/util.py +38 -0
- mycorrhizal/rhizomorph/README.md +220 -0
- mycorrhizal/rhizomorph/__init__.py +0 -0
- mycorrhizal/rhizomorph/core.py +1729 -0
- mycorrhizal/rhizomorph/util.py +45 -0
- mycorrhizal/spores/__init__.py +124 -0
- mycorrhizal/spores/cache.py +208 -0
- mycorrhizal/spores/core.py +419 -0
- mycorrhizal/spores/dsl/__init__.py +48 -0
- mycorrhizal/spores/dsl/enoki.py +514 -0
- mycorrhizal/spores/dsl/hypha.py +399 -0
- mycorrhizal/spores/dsl/rhizomorph.py +351 -0
- mycorrhizal/spores/encoder/__init__.py +11 -0
- mycorrhizal/spores/encoder/base.py +42 -0
- mycorrhizal/spores/encoder/json.py +159 -0
- mycorrhizal/spores/extraction.py +484 -0
- mycorrhizal/spores/models.py +288 -0
- mycorrhizal/spores/transport/__init__.py +10 -0
- mycorrhizal/spores/transport/base.py +46 -0
- mycorrhizal-0.1.0.dist-info/METADATA +198 -0
- mycorrhizal-0.1.0.dist-info/RECORD +37 -0
- mycorrhizal-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1545 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Enoki - Asyncio Finite State Machine Framework
|
|
4
|
+
|
|
5
|
+
A decorator-based DSL for defining and executing state machines with support for
|
|
6
|
+
asyncio, timeouts, message passing, and hierarchical state composition.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
from mycorrhizal.enoki.core import enoki, StateMachine, LabeledTransition
|
|
10
|
+
from enum import Enum, auto
|
|
11
|
+
|
|
12
|
+
@enoki.state()
|
|
13
|
+
def IdleState():
|
|
14
|
+
class Events(Enum):
|
|
15
|
+
START = auto()
|
|
16
|
+
QUIT = auto()
|
|
17
|
+
|
|
18
|
+
@enoki.on_state
|
|
19
|
+
async def on_state(ctx):
|
|
20
|
+
if ctx.msg == "start":
|
|
21
|
+
return Events.START
|
|
22
|
+
return None
|
|
23
|
+
|
|
24
|
+
@enoki.transitions
|
|
25
|
+
def transitions():
|
|
26
|
+
return [
|
|
27
|
+
LabeledTransition(Events.START, ProcessingState),
|
|
28
|
+
LabeledTransition(Events.QUIT, DoneState),
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
# Create and run the FSM
|
|
32
|
+
fsm = StateMachine(initial_state=IdleState, common_data={})
|
|
33
|
+
await fsm.initialize()
|
|
34
|
+
fsm.send_message("start")
|
|
35
|
+
await fsm.tick()
|
|
36
|
+
|
|
37
|
+
Key Classes:
|
|
38
|
+
StateMachine - Main FSM runtime with message queue and tick-based execution
|
|
39
|
+
StateConfiguration - Configuration for states (timeout, retries, terminal, can_dwell)
|
|
40
|
+
SharedContext - Context object passed to state handlers with msg and common data
|
|
41
|
+
LabeledTransition - Maps events to target states
|
|
42
|
+
|
|
43
|
+
Transition Types:
|
|
44
|
+
State references - Direct transition to another state
|
|
45
|
+
Again - Re-execute current state immediately
|
|
46
|
+
Unhandled - Wait for next message
|
|
47
|
+
Retry - Re-enter state with retry counter
|
|
48
|
+
Restart - Reset retry counter and wait for message
|
|
49
|
+
Repeat - Re-enter state from on_enter
|
|
50
|
+
Push(state1, state2, ...) - Push states onto stack
|
|
51
|
+
Pop - Pop and return to previous state
|
|
52
|
+
|
|
53
|
+
State Handlers:
|
|
54
|
+
on_state(ctx) - Main state logic, return transition or None to wait
|
|
55
|
+
on_enter(ctx) - Called when entering state
|
|
56
|
+
on_leave(ctx) - Called when leaving state
|
|
57
|
+
on_timeout(ctx) - Called if timeout expires
|
|
58
|
+
on_fail(ctx) - Called if exception occurs
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
from __future__ import annotations

import asyncio
import inspect
import itertools
import os
import sys
import time
import traceback
from asyncio import PriorityQueue
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from functools import cache
from pathlib import Path
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    TypeVar,
    Union,
)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Generic type parameter describing the ``common`` payload of SharedContext.
T = TypeVar("T")
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# ============================================================================
|
|
90
|
+
# Interface Integration Helper
|
|
91
|
+
# ============================================================================
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _create_interface_view_for_context(context: SharedContext, handler: Callable) -> SharedContext:
|
|
95
|
+
"""
|
|
96
|
+
Create a constrained view of context.common if the handler has an interface type hint.
|
|
97
|
+
|
|
98
|
+
This enables type-safe, constrained access to blackboard state based on
|
|
99
|
+
interface definitions created with @blackboard_interface.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
context: The SharedContext object
|
|
103
|
+
handler: The state handler function to check for interface type hints
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
Either the original context or a new context with constrained common field
|
|
107
|
+
"""
|
|
108
|
+
from typing import get_type_hints
|
|
109
|
+
|
|
110
|
+
try:
|
|
111
|
+
sig = inspect.signature(handler)
|
|
112
|
+
params = list(sig.parameters.values())
|
|
113
|
+
|
|
114
|
+
# Check first parameter (should be SharedContext)
|
|
115
|
+
if params and params[0].name == 'ctx':
|
|
116
|
+
ctx_type = get_type_hints(handler).get('ctx')
|
|
117
|
+
|
|
118
|
+
# If type hint exists and is a generic SharedContext with interface
|
|
119
|
+
if ctx_type and hasattr(ctx_type, '__args__'):
|
|
120
|
+
# Extract the type argument from SharedContext[InterfaceType]
|
|
121
|
+
interface_type = ctx_type.__args__[0]
|
|
122
|
+
|
|
123
|
+
# If it has interface metadata, create constrained view
|
|
124
|
+
if hasattr(interface_type, '_readonly_fields'):
|
|
125
|
+
from mycorrhizal.common.wrappers import create_view_from_protocol
|
|
126
|
+
from dataclasses import replace
|
|
127
|
+
|
|
128
|
+
# Create constrained view of common
|
|
129
|
+
constrained_common = create_view_from_protocol(context.common, interface_type)
|
|
130
|
+
|
|
131
|
+
# Return new context with constrained common
|
|
132
|
+
return replace(context, common=constrained_common)
|
|
133
|
+
except Exception:
|
|
134
|
+
# If anything goes wrong with type inspection, fall back to original context
|
|
135
|
+
pass
|
|
136
|
+
|
|
137
|
+
return context
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass
class StateConfiguration:
    """Per-state behavioural options.

    Args:
        timeout: Seconds to wait for a message before the state's on_timeout
            handler is invoked. None disables the timeout.
        retries: Number of retries the state allows when it returns Retry;
            once the limit is exceeded the state fails. None leaves it unset.
        terminal: When True, reaching this state completes the FSM
            (StateMachineComplete is raised).
        can_dwell: When True, on_state may return None and the state simply
            waits for the next message.

    Example:
        @enoki.state(config=StateConfiguration(timeout=5.0, retries=3))
        def MyState():
            # State with timeout and retry handling
            pass
    """

    timeout: Optional[float] = field(default=None)
    retries: Optional[int] = field(default=None)
    terminal: bool = field(default=False)
    can_dwell: bool = field(default=False)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@dataclass
class SharedContext(Generic[T]):
    """Context object handed to every state handler.

    It bundles the message currently being handled with the data shared by
    all states and a couple of helper callables.

    Attributes:
        send_message: Callable that enqueues a message on the state machine
        log: Logging callable (default: prints with an [FSM] prefix)
        common: Shared data given to StateMachine, reachable as ctx.common
        msg: The message being processed, or None

    The generic parameter T is the type of ``common``.

    Example:
        @enoki.on_state
        async def on_state(ctx: SharedContext):
            # Access shared data
            counter = ctx.common.get("counter", 0)

            # Check for messages
            if ctx.msg == "start":
                return Events.START

            # Send new messages
            ctx.send_message("ping")
    """

    send_message: Callable = field()
    log: Callable = field()
    common: T = field()
    msg: Optional[Any] = field(default=None)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# ============================================================================
|
|
203
|
+
# Transition Type Hierarchy
|
|
204
|
+
# ============================================================================
|
|
205
|
+
|
|
206
|
+
@dataclass
class TransitionType:
    """Base type for all transitions.

    Two transitions are equal when they are instances of classes with the
    same name, and the hash is derived from that same name so equal objects
    always hash equally. Transition objects are used as dictionary keys
    (e.g. continuation/renewal transitions in the registry's transition map),
    so every transition class must stay hashable.
    """

    # True for transitions after which the machine waits for a new message
    # (Unhandled and Restart set it); False otherwise.
    AWAITS: bool = False

    @property
    def name(self) -> str:
        """Name of the concrete transition class."""
        return self.__class__.__name__

    @property
    def base_name(self) -> str:
        """Name of the first direct base class, falling back to ``name``."""
        if hasattr(self.__class__, "__bases__") and self.__class__.__bases__:
            return self.__class__.__bases__[0].__name__
        return self.name

    def __hash__(self):
        # ``base_name`` is a property and must not be called. Hash only what
        # ``__eq__`` compares (``name``) so equal objects hash equally.
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, TransitionType):
            return self.name == other.name
        return False


# Why ``@dataclass(eq=False)`` below: a plain ``@dataclass`` on a subclass
# generates a new ``__eq__`` and, since the subclass body does not define
# ``__hash__``, sets ``__hash__ = None``. That would make e.g. ``Again()``
# unhashable and break dict-keyed transition lookups. With ``eq=False`` the
# subclasses inherit TransitionType's ``__eq__`` and ``__hash__``.


@dataclass(eq=False)
class StateTransition(TransitionType):
    """Represents a transition to a different state"""


@dataclass(eq=False)
class StateContinuation(TransitionType):
    """Represents staying in the same state"""


@dataclass(eq=False)
class StateRenewal(TransitionType):
    """Represents a transition back into the current state"""


@dataclass(eq=False)
class Again(StateContinuation):
    """Execute current state immediately without affecting retry counter"""


@dataclass(eq=False)
class Unhandled(StateContinuation):
    """Wait for next message/event without affecting retry counter"""

    AWAITS: bool = True


@dataclass(eq=False)
class Retry(StateContinuation):
    """Retry current state, decrementing retry counter, starting from on_enter"""


@dataclass(eq=False)
class Restart(StateRenewal):
    """Restart current state, reset retry counter, await message"""

    AWAITS: bool = True


@dataclass(eq=False)
class Repeat(StateRenewal):
    """Repeat current state, reset retry counter, execute on_enter immediately"""


@dataclass
class StateRef(StateTransition):
    """String-based state reference for breaking circular imports.

    Equality and hashing depend only on the referenced state name.
    """

    state: str = ""

    def __init__(self, state: str):
        self.state = state
        object.__setattr__(self, "AWAITS", False)

    def __hash__(self):
        return hash(self.state)

    def __eq__(self, other):
        if isinstance(other, StateRef):
            return self.state == other.state
        return False


@dataclass
class LabeledTransition:
    """A transition with a label (typically an enum value)"""

    label: Enum
    transition: TransitionType


# ============================================================================
# Push/Pop Transitions (state-stack manipulation)
# ============================================================================

@dataclass
class Push(StateTransition):
    """Push one or more states onto the stack.

    ``Push(A, B)`` stores its arguments, in order, in ``push_states``.
    """

    push_states: List[StateTransition] = field(default_factory=list)

    def __init__(self, *states: StateTransition):
        object.__setattr__(self, "push_states", list(states))

    def __hash__(self):
        return hash(".".join([s.name for s in self.push_states]))


@dataclass(eq=False)
class Pop(StateTransition):
    """Pop from the stack to return to previous state"""
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
# ============================================================================
|
|
337
|
+
# StateSpec - The heart of the decorator system
|
|
338
|
+
# ============================================================================
|
|
339
|
+
|
|
340
|
+
@dataclass
class StateSpec(StateTransition):
    """
    Description of a single state, assembled by the enoki decorators.

    Being a StateTransition, a StateSpec may appear directly in a transition
    list. It offers the same surface as the legacy State class, but it is
    built by decorators instead of metaclass machinery.
    """

    # --- identity ---------------------------------------------------------
    name: str = ""
    qualname: str = ""
    module: str = ""

    # --- configuration ----------------------------------------------------
    config: StateConfiguration = field(default_factory=StateConfiguration)

    # --- lifecycle handlers (StateRegistry reports a missing on_state) -----
    on_state: Optional[Callable[[SharedContext], Awaitable[TransitionType]]] = None
    on_enter: Optional[Callable[[SharedContext], Awaitable[None]]] = None
    on_leave: Optional[Callable[[SharedContext], Awaitable[None]]] = None
    on_timeout: Optional[Callable[[SharedContext], Awaitable[TransitionType]]] = None
    on_fail: Optional[Callable[[SharedContext], Awaitable[TransitionType]]] = None

    # --- zero-argument factory producing the transition list --------------
    transitions: Optional[Callable[[], List[Union[LabeledTransition, TransitionType]]]] = None

    # --- optional nested event enum ----------------------------------------
    Events: Optional[type[Enum]] = None

    # --- grouping / namespacing --------------------------------------------
    group_name: str = ""
    parent_state: Optional[StateSpec] = None

    # --- metadata -----------------------------------------------------------
    _is_root: bool = field(default=False, repr=False)

    @property
    def base_name(self) -> str:
        """Bare state name: the last dotted component of ``qualname``."""
        _, _, tail = self.qualname.rpartition(".")
        return tail

    @property
    def CONFIG(self) -> StateConfiguration:
        """Compatibility alias for ``config`` (legacy State class API)."""
        return self.config

    def __hash__(self):
        # Matches __eq__, which compares by name (or against a plain string).
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, StateSpec):
            other = other.name
        elif not isinstance(other, str):
            return False
        return self.name == other

    def get_transitions(self) -> List[Union[LabeledTransition, TransitionType]]:
        """Invoke the ``transitions`` factory for this state.

        Returns [] when no factory is set. A single transition returned by
        the factory is wrapped in a list; anything else is returned as-is.
        """
        factory = self.transitions
        if factory is None:
            return []
        produced = factory()
        is_single = isinstance(produced, (TransitionType, LabeledTransition))
        return [produced] if is_single else produced
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
# ============================================================================
|
|
410
|
+
# Global Registry
|
|
411
|
+
# ============================================================================
|
|
412
|
+
|
|
413
|
+
# Process-wide mapping of state name -> StateSpec, filled by register_state().
_state_registry: Dict[str, StateSpec] = {}


def register_state(state: StateSpec) -> None:
    """Store ``state`` in the global registry under ``state.name``.

    Registering another state with the same name replaces the earlier entry.
    """
    key = state.name
    _state_registry[key] = state


def get_state(name: str) -> Optional[StateSpec]:
    """Look up a registered state by name; None when it is not registered."""
    try:
        return _state_registry[name]
    except KeyError:
        return None


def get_all_states() -> Dict[str, StateSpec]:
    """Return a shallow copy of the registry; mutating it leaves the registry intact."""
    return dict(_state_registry)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
# ============================================================================
|
|
432
|
+
# StateRegistry - Validation and Resolution
|
|
433
|
+
# ============================================================================
|
|
434
|
+
|
|
435
|
+
class ValidationError(Exception):
    """Signals that a state machine definition failed validation or a state could not be resolved."""
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
@dataclass
class ValidationResult:
    """Outcome of validating a state machine.

    As built by StateRegistry.validate_all_states, ``valid`` is True exactly
    when ``errors`` is empty; ``warnings`` holds non-fatal findings and
    ``discovered_states`` the names of all states reached during traversal.
    """

    valid: bool = field()
    errors: list[str] = field()
    warnings: list[str] = field()
    discovered_states: set[str] = field()
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class StateRegistry:
|
|
450
|
+
"""
|
|
451
|
+
Manages state resolution and validation for StateSpec-based states.
|
|
452
|
+
|
|
453
|
+
This registry handles:
|
|
454
|
+
- Resolving StateRef strings to StateSpec objects
|
|
455
|
+
- Validating all states in a state machine
|
|
456
|
+
- Getting transition mappings for states
|
|
457
|
+
- Detecting circular references and invalid transitions
|
|
458
|
+
"""
|
|
459
|
+
|
|
460
|
+
def __init__(self):
|
|
461
|
+
self._state_cache: Dict[str, StateSpec] = {}
|
|
462
|
+
self._transition_cache: Dict[str, Dict] = {}
|
|
463
|
+
self._resolved_modules: Dict[str, Any] = {}
|
|
464
|
+
self._validation_errors: list[str] = []
|
|
465
|
+
self._validation_warnings: list[str] = []
|
|
466
|
+
|
|
467
|
+
def resolve_state(
|
|
468
|
+
self, state_ref: Union[str, StateRef, StateSpec], validate_only: bool = False
|
|
469
|
+
) -> Optional[StateSpec]:
|
|
470
|
+
"""
|
|
471
|
+
Resolve a state reference to an actual StateSpec object.
|
|
472
|
+
|
|
473
|
+
Args:
|
|
474
|
+
state_ref: Can be a StateSpec, a StateRef, or a fully qualified state name string
|
|
475
|
+
validate_only: If True, don't cache the result (used during validation)
|
|
476
|
+
|
|
477
|
+
Returns:
|
|
478
|
+
The resolved StateSpec, or None if not found (when validate_only=True)
|
|
479
|
+
"""
|
|
480
|
+
# If it's already a StateSpec, return it
|
|
481
|
+
if isinstance(state_ref, StateSpec):
|
|
482
|
+
# Track resolved module
|
|
483
|
+
if state_ref.module and state_ref.module not in self._resolved_modules:
|
|
484
|
+
mod = sys.modules.get(state_ref.module)
|
|
485
|
+
self._resolved_modules[state_ref.module] = mod
|
|
486
|
+
|
|
487
|
+
# Store short name for easier local file lookup
|
|
488
|
+
if mod and hasattr(mod, "__file__"):
|
|
489
|
+
fp = Path(getattr(mod, "__file__"))
|
|
490
|
+
self._resolved_modules[fp.stem] = mod
|
|
491
|
+
|
|
492
|
+
return state_ref
|
|
493
|
+
|
|
494
|
+
# If it's a StateRef, resolve the string
|
|
495
|
+
if isinstance(state_ref, StateRef):
|
|
496
|
+
state_name = state_ref.state
|
|
497
|
+
elif isinstance(state_ref, str):
|
|
498
|
+
state_name = state_ref
|
|
499
|
+
else:
|
|
500
|
+
return None
|
|
501
|
+
|
|
502
|
+
# Check cache first
|
|
503
|
+
if state_name in self._state_cache and not validate_only:
|
|
504
|
+
return self._state_cache[state_name]
|
|
505
|
+
|
|
506
|
+
# Try to get from global registry
|
|
507
|
+
state = get_state(state_name)
|
|
508
|
+
if state:
|
|
509
|
+
if not validate_only:
|
|
510
|
+
self._state_cache[state_name] = state
|
|
511
|
+
return state
|
|
512
|
+
|
|
513
|
+
# If not in global registry, we can't resolve it
|
|
514
|
+
# (In the old system, it would try dynamic imports, but with
|
|
515
|
+
# decorator-based states, everything should be pre-registered)
|
|
516
|
+
if validate_only:
|
|
517
|
+
self._validation_errors.append(
|
|
518
|
+
f"State '{state_name}' not found in registry"
|
|
519
|
+
)
|
|
520
|
+
return None
|
|
521
|
+
else:
|
|
522
|
+
raise ValidationError(f"State '{state_name}' not found in registry")
|
|
523
|
+
|
|
524
|
+
def validate_all_states(
|
|
525
|
+
self,
|
|
526
|
+
initial_state: Union[str, StateRef, StateSpec],
|
|
527
|
+
error_state: Optional[Union[str, StateRef, StateSpec]] = None,
|
|
528
|
+
) -> ValidationResult:
|
|
529
|
+
"""
|
|
530
|
+
Validate all states reachable from the initial state.
|
|
531
|
+
|
|
532
|
+
Args:
|
|
533
|
+
initial_state: The starting state
|
|
534
|
+
error_state: Optional error state
|
|
535
|
+
|
|
536
|
+
Returns:
|
|
537
|
+
ValidationResult with errors, warnings, and discovered states
|
|
538
|
+
"""
|
|
539
|
+
self._validation_errors = []
|
|
540
|
+
self._validation_warnings = []
|
|
541
|
+
|
|
542
|
+
# Discover all reachable states
|
|
543
|
+
discovered_states = set()
|
|
544
|
+
state_references = set()
|
|
545
|
+
resolved_references = set()
|
|
546
|
+
states_to_visit = []
|
|
547
|
+
|
|
548
|
+
# Resolve initial state
|
|
549
|
+
initial_resolved = self.resolve_state(initial_state, validate_only=True)
|
|
550
|
+
if initial_resolved:
|
|
551
|
+
states_to_visit.append(initial_resolved)
|
|
552
|
+
discovered_states.add(initial_resolved.name)
|
|
553
|
+
|
|
554
|
+
# Resolve error state if provided
|
|
555
|
+
if error_state:
|
|
556
|
+
error_resolved = self.resolve_state(error_state, validate_only=True)
|
|
557
|
+
if error_resolved:
|
|
558
|
+
states_to_visit.append(error_resolved)
|
|
559
|
+
discovered_states.add(error_resolved.name)
|
|
560
|
+
|
|
561
|
+
# Traverse all reachable states
|
|
562
|
+
visited = set()
|
|
563
|
+
while states_to_visit:
|
|
564
|
+
current_state = states_to_visit.pop()
|
|
565
|
+
state_name = current_state.name
|
|
566
|
+
|
|
567
|
+
if state_name in visited:
|
|
568
|
+
continue
|
|
569
|
+
visited.add(state_name)
|
|
570
|
+
|
|
571
|
+
def handle_state(s):
|
|
572
|
+
discovered_states.add(s.name if hasattr(s, 'name') else str(s))
|
|
573
|
+
states_to_visit.append(s)
|
|
574
|
+
|
|
575
|
+
def handle_reference(r):
|
|
576
|
+
state_references.add(r)
|
|
577
|
+
|
|
578
|
+
def handle_push(p):
|
|
579
|
+
for state in p.push_states:
|
|
580
|
+
match state:
|
|
581
|
+
case s if isinstance(s, StateSpec):
|
|
582
|
+
handle_state(s)
|
|
583
|
+
case r if isinstance(r, StateRef):
|
|
584
|
+
handle_reference(r)
|
|
585
|
+
case invalid:
|
|
586
|
+
self._validation_errors.append(
|
|
587
|
+
f"State '{state_name}' attempted to push an invalid transition: {invalid}"
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
try:
|
|
591
|
+
# Get transitions
|
|
592
|
+
raw_transitions = current_state.get_transitions()
|
|
593
|
+
|
|
594
|
+
# Handle different return types
|
|
595
|
+
match raw_transitions:
|
|
596
|
+
case list() | tuple():
|
|
597
|
+
pass # Already a list
|
|
598
|
+
case _ if isinstance(raw_transitions, (TransitionType, LabeledTransition)):
|
|
599
|
+
raw_transitions = [raw_transitions]
|
|
600
|
+
case _:
|
|
601
|
+
raw_transitions = []
|
|
602
|
+
|
|
603
|
+
# Validate each transition
|
|
604
|
+
for transition in raw_transitions:
|
|
605
|
+
match transition:
|
|
606
|
+
case _ if isinstance(transition, Push):
|
|
607
|
+
# Push transitions must have labels
|
|
608
|
+
self._validation_errors.append(
|
|
609
|
+
f"State '{state_name}' has a push transition without a label"
|
|
610
|
+
)
|
|
611
|
+
handle_push(transition)
|
|
612
|
+
case LabeledTransition(label, s):
|
|
613
|
+
match s:
|
|
614
|
+
case st if isinstance(st, StateSpec):
|
|
615
|
+
handle_state(st)
|
|
616
|
+
case r if isinstance(r, StateRef):
|
|
617
|
+
handle_reference(r)
|
|
618
|
+
case p if isinstance(p, Push):
|
|
619
|
+
handle_push(p)
|
|
620
|
+
case s if isinstance(s, StateSpec):
|
|
621
|
+
handle_state(s)
|
|
622
|
+
case r if isinstance(r, StateRef):
|
|
623
|
+
self._validation_warnings.append(
|
|
624
|
+
f"State '{state_name}' has a StateReference transition without a label"
|
|
625
|
+
)
|
|
626
|
+
handle_reference(r)
|
|
627
|
+
case _ if isinstance(transition, (StateContinuation, StateRenewal)):
|
|
628
|
+
# Valid continuation/renewal
|
|
629
|
+
pass
|
|
630
|
+
case invalid:
|
|
631
|
+
self._validation_errors.append(
|
|
632
|
+
f"State '{state_name}' has an invalid transition: {invalid}"
|
|
633
|
+
)
|
|
634
|
+
|
|
635
|
+
# Validate that the state has an on_state handler
|
|
636
|
+
if current_state.on_state is None:
|
|
637
|
+
self._validation_errors.append(
|
|
638
|
+
f"State {state_name} does not have an on_state handler defined"
|
|
639
|
+
)
|
|
640
|
+
|
|
641
|
+
# Make sure non-terminal states define transitions
|
|
642
|
+
if not current_state.config.terminal:
|
|
643
|
+
if current_state.transitions is None:
|
|
644
|
+
self._validation_errors.append(
|
|
645
|
+
f"Non-terminal state {state_name} does not define any transitions"
|
|
646
|
+
)
|
|
647
|
+
|
|
648
|
+
# Validate timeout configuration
|
|
649
|
+
if current_state.config.timeout is not None:
|
|
650
|
+
if (
|
|
651
|
+
not isinstance(current_state.config.timeout, (int, float))
|
|
652
|
+
or current_state.config.timeout <= 0
|
|
653
|
+
):
|
|
654
|
+
self._validation_errors.append(
|
|
655
|
+
f"State '{state_name}' has invalid timeout: {current_state.config.timeout}"
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
has_timeout_handler = current_state.on_timeout is not None
|
|
659
|
+
if not has_timeout_handler:
|
|
660
|
+
self._validation_warnings.append(
|
|
661
|
+
f"State '{state_name}' has timeout but no on_timeout handler defined"
|
|
662
|
+
)
|
|
663
|
+
|
|
664
|
+
# Pass 2: Validate all string references
|
|
665
|
+
for state_ref in list(state_references - resolved_references):
|
|
666
|
+
resolved = self.resolve_state(state_ref, validate_only=True)
|
|
667
|
+
if resolved:
|
|
668
|
+
resolved_references.add(state_ref)
|
|
669
|
+
if resolved.name not in discovered_states:
|
|
670
|
+
discovered_states.add(resolved.name)
|
|
671
|
+
states_to_visit.append(resolved)
|
|
672
|
+
|
|
673
|
+
except Exception as e:
|
|
674
|
+
self._validation_errors.append(
|
|
675
|
+
f"Error processing transitions for state '{state_name}': {traceback.format_exc()}"
|
|
676
|
+
)
|
|
677
|
+
|
|
678
|
+
return ValidationResult(
|
|
679
|
+
valid=len(self._validation_errors) == 0,
|
|
680
|
+
errors=self._validation_errors.copy(),
|
|
681
|
+
warnings=self._validation_warnings.copy(),
|
|
682
|
+
discovered_states=discovered_states,
|
|
683
|
+
)
|
|
684
|
+
|
|
685
|
+
def get_transitions(self, state: StateSpec) -> Dict:
|
|
686
|
+
"""
|
|
687
|
+
Get the transition mapping for a state, with caching.
|
|
688
|
+
|
|
689
|
+
Returns a dict mapping event labels/enum values to transition targets.
|
|
690
|
+
"""
|
|
691
|
+
state_name = state.name
|
|
692
|
+
|
|
693
|
+
if state_name in self._transition_cache:
|
|
694
|
+
return self._transition_cache[state_name]
|
|
695
|
+
|
|
696
|
+
resolved_transitions = {}
|
|
697
|
+
|
|
698
|
+
raw_transitions = state.get_transitions()
|
|
699
|
+
match raw_transitions:
|
|
700
|
+
case list() | tuple():
|
|
701
|
+
pass # Already a list
|
|
702
|
+
case _ if isinstance(raw_transitions, (TransitionType, LabeledTransition)):
|
|
703
|
+
raw_transitions = [raw_transitions]
|
|
704
|
+
case _:
|
|
705
|
+
raw_transitions = []
|
|
706
|
+
|
|
707
|
+
for transition in raw_transitions:
|
|
708
|
+
match transition:
|
|
709
|
+
case LabeledTransition(label, target):
|
|
710
|
+
# Use the label enum itself as the key
|
|
711
|
+
resolved_transitions[label] = target
|
|
712
|
+
# Also allow lookup by label name
|
|
713
|
+
resolved_transitions[label.name] = target
|
|
714
|
+
# Also allow lookup by label value
|
|
715
|
+
if hasattr(label, 'value'):
|
|
716
|
+
resolved_transitions[label.value] = target
|
|
717
|
+
case s if isinstance(s, StateSpec):
|
|
718
|
+
# Direct state reference
|
|
719
|
+
resolved_transitions[s.name] = s
|
|
720
|
+
resolved_transitions[s] = s
|
|
721
|
+
case r if isinstance(r, StateRef):
|
|
722
|
+
# String reference - resolve it
|
|
723
|
+
resolved = self.resolve_state(r)
|
|
724
|
+
if resolved:
|
|
725
|
+
resolved_transitions[r.state] = resolved
|
|
726
|
+
case _ if isinstance(transition, (StateContinuation, StateRenewal)):
|
|
727
|
+
# Special transitions - allow lookup by type
|
|
728
|
+
resolved_transitions[transition] = transition
|
|
729
|
+
|
|
730
|
+
self._transition_cache[state_name] = resolved_transitions
|
|
731
|
+
return resolved_transitions
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
# ============================================================================
|
|
735
|
+
# Message Types for StateMachine
|
|
736
|
+
# ============================================================================
|
|
737
|
+
|
|
738
|
+
# Shared source of creation-order numbers used to break priority ties.
_message_sequence = itertools.count()


@dataclass
class PrioritizedMessage:
    """Queue entry pairing a message with its priority.

    Lower ``priority`` sorts first. Entries with equal priority are ordered by
    creation (``sequence``), so heap-based queues such as
    ``asyncio.PriorityQueue`` deliver them FIFO; ``heapq`` is not stable on
    its own and would otherwise reorder equal-priority messages. Equality
    still considers ``priority`` only, as before.
    """

    priority: int
    item: Any = field(compare=False)
    # Hidden tie-breaker assigned automatically at construction time.
    sequence: int = field(
        default_factory=_message_sequence.__next__, compare=False, repr=False
    )

    def __lt__(self, other):
        """Order by priority, then by creation order among equal priorities."""
        if not isinstance(other, PrioritizedMessage):
            return NotImplemented
        return (self.priority, self.sequence) < (other.priority, other.sequence)
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
@dataclass
class EnokiInternalMessage:
    """Common base for messages the FSM generates internally."""
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
@dataclass
class TimeoutMessage(EnokiInternalMessage):
    """Internal FSM message describing a timeout event.

    Fields:
        state_name: Name of the state the timeout is associated with.
        timeout_id: Identifier of this particular timeout.
        timeout_duration: Length of the timeout (presumably seconds, like
            StateConfiguration.timeout — confirm against the runtime).
    """

    state_name: str = field()
    timeout_id: int = field()
    timeout_duration: float = field()
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
# ============================================================================
|
|
765
|
+
# Exceptions for StateMachine
|
|
766
|
+
# ============================================================================
|
|
767
|
+
|
|
768
|
+
class StateMachineComplete(Exception):
    """Signals that the state machine has entered a terminal state."""
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
class BlockedInUntimedState(Exception):
    """Signals that a state which may not dwell blocked without having a timeout."""
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
class PopFromEmptyStack(Exception):
    """Signals a Pop transition attempted while the state stack was empty."""
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
class NoStateToTick(Exception):
    """Signals that ``tick`` was invoked while no current state is attached."""
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
# ============================================================================
|
|
789
|
+
# StateMachine Runtime
|
|
790
|
+
# ============================================================================
|
|
791
|
+
|
|
792
|
+
class StateMachine:
|
|
793
|
+
"""Asyncio-native finite state machine for executing StateSpec-based states.
|
|
794
|
+
|
|
795
|
+
The StateMachine manages state execution, transitions, message passing,
|
|
796
|
+
and lifecycle handling. States are defined using the @enoki.state decorator
|
|
797
|
+
and should define on_state, on_enter, on_leave, on_timeout, or on_fail handlers.
|
|
798
|
+
|
|
799
|
+
Args:
|
|
800
|
+
initial_state: The initial state (state function, StateRef, or StateSpec)
|
|
801
|
+
error_state: Optional error state for handling exceptions
|
|
802
|
+
filter_fn: Optional function to filter messages (ctx, msg) -> bool
|
|
803
|
+
trap_fn: Optional function to trap exceptions (exc) -> None
|
|
804
|
+
on_error_fn: Optional function called on errors (ctx, exc) -> None
|
|
805
|
+
common_data: Shared data accessible via ctx.common in all states
|
|
806
|
+
|
|
807
|
+
Attributes:
|
|
808
|
+
current_state: The currently executing state
|
|
809
|
+
context: SharedContext containing msg and common data
|
|
810
|
+
state_stack: Stack for Push/Pop hierarchical state management
|
|
811
|
+
|
|
812
|
+
Methods:
|
|
813
|
+
initialize(): Initialize the state machine (must be called before tick)
|
|
814
|
+
tick(timeout): Process one state machine tick
|
|
815
|
+
send_message(msg): Send a message to the state machine
|
|
816
|
+
reset(): Reset to initial state
|
|
817
|
+
|
|
818
|
+
Example:
|
|
819
|
+
fsm = StateMachine(
|
|
820
|
+
initial_state=IdleState,
|
|
821
|
+
error_state=ErrorState,
|
|
822
|
+
common_data={"counter": 0}
|
|
823
|
+
)
|
|
824
|
+
await fsm.initialize()
|
|
825
|
+
fsm.send_message("start")
|
|
826
|
+
await fsm.tick()
|
|
827
|
+
"""
|
|
828
|
+
|
|
829
|
+
class MessagePriorities(IntEnum):
|
|
830
|
+
ERROR = 0
|
|
831
|
+
INTERNAL_MESSAGE = 1
|
|
832
|
+
MESSAGE = 2
|
|
833
|
+
|
|
834
|
+
def __init__(
|
|
835
|
+
self,
|
|
836
|
+
initial_state: Union[str, StateRef, StateSpec],
|
|
837
|
+
error_state: Optional[Union[str, StateRef, StateSpec]] = None,
|
|
838
|
+
filter_fn: Optional[Callable] = None,
|
|
839
|
+
trap_fn: Optional[Callable] = None,
|
|
840
|
+
on_error_fn: Optional[Callable] = None,
|
|
841
|
+
common_data: Optional[Any] = None,
|
|
842
|
+
):
|
|
843
|
+
|
|
844
|
+
self.registry = StateRegistry()
|
|
845
|
+
|
|
846
|
+
# The filter and trap functions
|
|
847
|
+
self._filter_fn = filter_fn or (lambda x, y: None)
|
|
848
|
+
self._trap_fn = trap_fn or (lambda x: None)
|
|
849
|
+
self._on_err_fn = on_error_fn or (lambda x, y: None)
|
|
850
|
+
|
|
851
|
+
# Validate the state machine structure
|
|
852
|
+
validation_result = self.registry.validate_all_states(
|
|
853
|
+
initial_state, error_state
|
|
854
|
+
)
|
|
855
|
+
|
|
856
|
+
if not validation_result.valid:
|
|
857
|
+
error_msg = "State machine validation failed:\n" + "\n".join(
|
|
858
|
+
validation_result.errors
|
|
859
|
+
)
|
|
860
|
+
if validation_result.warnings:
|
|
861
|
+
error_msg += "\n\nWarnings:\n" + "\n".join(validation_result.warnings)
|
|
862
|
+
raise ValidationError(error_msg)
|
|
863
|
+
|
|
864
|
+
# Log warnings
|
|
865
|
+
if validation_result.warnings:
|
|
866
|
+
for warning in validation_result.warnings:
|
|
867
|
+
self.log(f"WARNING: {warning}")
|
|
868
|
+
|
|
869
|
+
# Resolve initial and error states
|
|
870
|
+
self.initial_state = self.registry.resolve_state(initial_state)
|
|
871
|
+
self.error_state = (
|
|
872
|
+
self.registry.resolve_state(error_state) if error_state else None
|
|
873
|
+
)
|
|
874
|
+
|
|
875
|
+
# Asyncio-based infrastructure
|
|
876
|
+
self._message_queue = PriorityQueue()
|
|
877
|
+
self._timeout_task = None
|
|
878
|
+
self._timeout_counter = 0
|
|
879
|
+
self._current_timeout_id = None
|
|
880
|
+
|
|
881
|
+
# Runtime state
|
|
882
|
+
self.current_state: Optional[StateSpec] = None
|
|
883
|
+
self.state_stack: List[StateSpec] = []
|
|
884
|
+
self.context = SharedContext(
|
|
885
|
+
send_message=self.send_message,
|
|
886
|
+
log=self.log,
|
|
887
|
+
common=common_data if common_data is not None else dict(),
|
|
888
|
+
)
|
|
889
|
+
|
|
890
|
+
self.retry_counters: Dict[str, int] = {}
|
|
891
|
+
self.state_enter_times: Dict[str, float] = {}
|
|
892
|
+
|
|
893
|
+
# Initialize asynchronously
|
|
894
|
+
self._initialized = False
|
|
895
|
+
|
|
896
|
+
async def initialize(self):
|
|
897
|
+
"""Initialize the state machine asynchronously"""
|
|
898
|
+
if not self._initialized:
|
|
899
|
+
await self.reset()
|
|
900
|
+
self._initialized = True
|
|
901
|
+
|
|
902
|
+
def send_message(self, message: Any):
|
|
903
|
+
"""Send a message to the state machine"""
|
|
904
|
+
try:
|
|
905
|
+
if message and self._filter_fn(self.context, message):
|
|
906
|
+
return
|
|
907
|
+
self._send_message_internal(message)
|
|
908
|
+
except Exception as e:
|
|
909
|
+
self._send_message_internal(e)
|
|
910
|
+
|
|
911
|
+
def _send_message_internal(self, message: Any):
|
|
912
|
+
"""Internal message sending"""
|
|
913
|
+
|
|
914
|
+
match message:
|
|
915
|
+
case _ if isinstance(message, EnokiInternalMessage):
|
|
916
|
+
try:
|
|
917
|
+
self._message_queue.put_nowait(
|
|
918
|
+
PrioritizedMessage(
|
|
919
|
+
self.MessagePriorities.INTERNAL_MESSAGE, message
|
|
920
|
+
)
|
|
921
|
+
)
|
|
922
|
+
except Exception:
|
|
923
|
+
pass # Queue full
|
|
924
|
+
case _ if isinstance(message, Exception):
|
|
925
|
+
try:
|
|
926
|
+
self._message_queue.put_nowait(
|
|
927
|
+
PrioritizedMessage(self.MessagePriorities.ERROR, message)
|
|
928
|
+
)
|
|
929
|
+
except Exception:
|
|
930
|
+
pass
|
|
931
|
+
case _:
|
|
932
|
+
try:
|
|
933
|
+
self._message_queue.put_nowait(
|
|
934
|
+
PrioritizedMessage(self.MessagePriorities.MESSAGE, message)
|
|
935
|
+
)
|
|
936
|
+
except Exception:
|
|
937
|
+
pass
|
|
938
|
+
|
|
939
|
+
async def reset(self):
|
|
940
|
+
"""Reset the state machine to initial state"""
|
|
941
|
+
self._cancel_timeout()
|
|
942
|
+
self.state_stack = []
|
|
943
|
+
self.retry_counters = {}
|
|
944
|
+
self.state_enter_times = {}
|
|
945
|
+
|
|
946
|
+
# Clear message queue
|
|
947
|
+
while not self._message_queue.empty():
|
|
948
|
+
try:
|
|
949
|
+
self._message_queue.get_nowait()
|
|
950
|
+
except Exception:
|
|
951
|
+
break
|
|
952
|
+
|
|
953
|
+
# Transition into our initial state
|
|
954
|
+
await self._transition_to_state(self.initial_state)
|
|
955
|
+
|
|
956
|
+
def log(self, message: str):
|
|
957
|
+
"""Log a message"""
|
|
958
|
+
print(f"[FSM] {message}")
|
|
959
|
+
|
|
960
|
+
async def tick(self, timeout: Optional[Union[float, int]] = 0):
|
|
961
|
+
"""Process one state machine tick"""
|
|
962
|
+
|
|
963
|
+
if not self._initialized:
|
|
964
|
+
await self.initialize()
|
|
965
|
+
|
|
966
|
+
if not self.current_state:
|
|
967
|
+
raise NoStateToTick()
|
|
968
|
+
|
|
969
|
+
# Validate timeout parameter
|
|
970
|
+
match timeout:
|
|
971
|
+
case x if x and not isinstance(x, (int, float)):
|
|
972
|
+
raise ValueError(f"Tick timeout must be None, or an int/float >= 0")
|
|
973
|
+
case 0:
|
|
974
|
+
block = False
|
|
975
|
+
case None:
|
|
976
|
+
block = True
|
|
977
|
+
case x if x > 0:
|
|
978
|
+
block = True
|
|
979
|
+
case _:
|
|
980
|
+
raise ValueError("Tick timeout must be None or >= 0")
|
|
981
|
+
|
|
982
|
+
while True:
|
|
983
|
+
try:
|
|
984
|
+
if isinstance(self.context.msg, Exception):
|
|
985
|
+
raise self.context.msg
|
|
986
|
+
|
|
987
|
+
# Handle timeout messages specially
|
|
988
|
+
if isinstance(self.context.msg, TimeoutMessage):
|
|
989
|
+
if not self._handle_timeout_message():
|
|
990
|
+
return
|
|
991
|
+
transition = await self._call_on_timeout(self.current_state)
|
|
992
|
+
else:
|
|
993
|
+
transition = await self._call_on_state(self.current_state)
|
|
994
|
+
|
|
995
|
+
if self.current_state.config.terminal:
|
|
996
|
+
raise StateMachineComplete()
|
|
997
|
+
|
|
998
|
+
if transition in (None, Unhandled):
|
|
999
|
+
self._trap_fn(self.context)
|
|
1000
|
+
|
|
1001
|
+
self.context.msg = None
|
|
1002
|
+
|
|
1003
|
+
should_check_queue = await self._process_transition(transition, block)
|
|
1004
|
+
|
|
1005
|
+
# In manual mode (block=False), stop after processing one transition if not checking queue
|
|
1006
|
+
if not block and not should_check_queue:
|
|
1007
|
+
break
|
|
1008
|
+
|
|
1009
|
+
if should_check_queue:
|
|
1010
|
+
if block:
|
|
1011
|
+
message = await asyncio.wait_for(
|
|
1012
|
+
self._message_queue.get(), timeout=timeout
|
|
1013
|
+
)
|
|
1014
|
+
message = message.item
|
|
1015
|
+
else:
|
|
1016
|
+
message = self._message_queue.get_nowait().item
|
|
1017
|
+
self.context.msg = message
|
|
1018
|
+
except asyncio.QueueEmpty:
|
|
1019
|
+
break
|
|
1020
|
+
except asyncio.TimeoutError:
|
|
1021
|
+
break
|
|
1022
|
+
except Exception as e:
|
|
1023
|
+
if (
|
|
1024
|
+
isinstance(e, StateMachineComplete)
|
|
1025
|
+
or self.current_state.config.terminal
|
|
1026
|
+
):
|
|
1027
|
+
raise
|
|
1028
|
+
|
|
1029
|
+
next_transition = self._on_err_fn(self.context, e)
|
|
1030
|
+
if next_transition:
|
|
1031
|
+
await self._process_transition(next_transition, block)
|
|
1032
|
+
elif self.error_state and self.current_state != self.error_state:
|
|
1033
|
+
# Clear the exception message before transitioning
|
|
1034
|
+
self.context.msg = None
|
|
1035
|
+
await self._transition_to_state(self.error_state)
|
|
1036
|
+
elif self.current_state == self.error_state:
|
|
1037
|
+
# Already in error state and got another exception - stop
|
|
1038
|
+
break
|
|
1039
|
+
else:
|
|
1040
|
+
raise
|
|
1041
|
+
|
|
1042
|
+
return None
|
|
1043
|
+
|
|
1044
|
+
async def _call_on_state(self, state: StateSpec):
|
|
1045
|
+
"""Call the state's on_state handler"""
|
|
1046
|
+
if state.on_state is None:
|
|
1047
|
+
raise ValueError(f"State {state.name} has no on_state handler")
|
|
1048
|
+
# Create interface view if handler has interface type hint
|
|
1049
|
+
ctx_to_pass = _create_interface_view_for_context(self.context, state.on_state)
|
|
1050
|
+
return await state.on_state(ctx_to_pass)
|
|
1051
|
+
|
|
1052
|
+
async def _call_on_enter(self, state: StateSpec):
|
|
1053
|
+
"""Call the state's on_enter handler"""
|
|
1054
|
+
if state.on_enter is not None:
|
|
1055
|
+
self.log(f" [DEBUG] Calling on_enter for {state.name}")
|
|
1056
|
+
# Create interface view if handler has interface type hint
|
|
1057
|
+
ctx_to_pass = _create_interface_view_for_context(self.context, state.on_enter)
|
|
1058
|
+
await state.on_enter(ctx_to_pass)
|
|
1059
|
+
self.log(f" [DEBUG] on_enter completed for {state.name}")
|
|
1060
|
+
|
|
1061
|
+
async def _call_on_leave(self, state: StateSpec):
|
|
1062
|
+
"""Call the state's on_leave handler"""
|
|
1063
|
+
if state.on_leave is not None:
|
|
1064
|
+
# Create interface view if handler has interface type hint
|
|
1065
|
+
ctx_to_pass = _create_interface_view_for_context(self.context, state.on_leave)
|
|
1066
|
+
await state.on_leave(ctx_to_pass)
|
|
1067
|
+
|
|
1068
|
+
async def _call_on_timeout(self, state: StateSpec):
|
|
1069
|
+
"""Call the state's on_timeout handler"""
|
|
1070
|
+
if state.on_timeout is None:
|
|
1071
|
+
raise ValueError(f"State {state.name} has no on_timeout handler")
|
|
1072
|
+
# Create interface view if handler has interface type hint
|
|
1073
|
+
ctx_to_pass = _create_interface_view_for_context(self.context, state.on_timeout)
|
|
1074
|
+
return await state.on_timeout(ctx_to_pass)
|
|
1075
|
+
|
|
1076
|
+
async def _call_on_fail(self, state: StateSpec):
|
|
1077
|
+
"""Call the state's on_fail handler"""
|
|
1078
|
+
if state.on_fail is None:
|
|
1079
|
+
raise ValueError(f"State {state.name} has no on_fail handler")
|
|
1080
|
+
# Create interface view if handler has interface type hint
|
|
1081
|
+
ctx_to_pass = _create_interface_view_for_context(self.context, state.on_fail)
|
|
1082
|
+
return await state.on_fail(ctx_to_pass)
|
|
1083
|
+
|
|
1084
|
+
async def _process_transition(self, transition: Any, block: bool = False) -> bool:
|
|
1085
|
+
"""Process a state transition"""
|
|
1086
|
+
|
|
1087
|
+
valid_transitions = self.registry.get_transitions(self.current_state)
|
|
1088
|
+
|
|
1089
|
+
# Try to get the transition target
|
|
1090
|
+
target = valid_transitions.get(transition)
|
|
1091
|
+
|
|
1092
|
+
# If not found by the transition itself, try by value
|
|
1093
|
+
if target is None and transition in valid_transitions.values():
|
|
1094
|
+
target = transition
|
|
1095
|
+
elif transition is not None and target is None:
|
|
1096
|
+
raise ValueError(f"Unknown transition {transition}: {valid_transitions}")
|
|
1097
|
+
|
|
1098
|
+
# Process the transition
|
|
1099
|
+
|
|
1100
|
+
if transition in (None, Unhandled):
|
|
1101
|
+
if (
|
|
1102
|
+
self.current_state.config.can_dwell
|
|
1103
|
+
) or self.current_state.config.timeout:
|
|
1104
|
+
return True
|
|
1105
|
+
raise BlockedInUntimedState(
|
|
1106
|
+
f"{self.current_state.name} cannot dwell and does not have a timeout"
|
|
1107
|
+
)
|
|
1108
|
+
elif isinstance(target, StateSpec):
|
|
1109
|
+
await self._transition_to_state(target)
|
|
1110
|
+
elif isinstance(target, Push):
|
|
1111
|
+
await self._handle_push(target)
|
|
1112
|
+
elif isinstance(target, Pop) or target is Pop:
|
|
1113
|
+
await self._handle_pop()
|
|
1114
|
+
elif target == Again:
|
|
1115
|
+
# Re-execute current state immediately
|
|
1116
|
+
# In manual tick mode (block=False), stop after one execution
|
|
1117
|
+
# In automatic mode (block=True), continue the loop
|
|
1118
|
+
if block:
|
|
1119
|
+
# Automatic mode: continue the loop
|
|
1120
|
+
return True
|
|
1121
|
+
else:
|
|
1122
|
+
# Manual mode: stop after this execution
|
|
1123
|
+
return False
|
|
1124
|
+
elif target == Repeat:
|
|
1125
|
+
await self._transition_to_state(self.current_state)
|
|
1126
|
+
elif target == Restart:
|
|
1127
|
+
self._reset_state_context()
|
|
1128
|
+
await self._transition_to_state(self.current_state)
|
|
1129
|
+
return True
|
|
1130
|
+
elif target == Retry:
|
|
1131
|
+
await self._handle_retry()
|
|
1132
|
+
elif isinstance(target, (StateContinuation, StateRenewal)):
|
|
1133
|
+
# Other continuations/renewals - handle appropriately
|
|
1134
|
+
if target == Again:
|
|
1135
|
+
return True # Continue the loop
|
|
1136
|
+
elif target == Repeat:
|
|
1137
|
+
await self._transition_to_state(self.current_state)
|
|
1138
|
+
elif target == Restart:
|
|
1139
|
+
self._reset_state_context()
|
|
1140
|
+
await self._transition_to_state(self.current_state)
|
|
1141
|
+
return True
|
|
1142
|
+
elif target == Retry:
|
|
1143
|
+
await self._handle_retry()
|
|
1144
|
+
|
|
1145
|
+
return False
|
|
1146
|
+
|
|
1147
|
+
async def _transition_to_state(self, next_state: StateSpec):
|
|
1148
|
+
"""Transition to a new state"""
|
|
1149
|
+
|
|
1150
|
+
if self.current_state:
|
|
1151
|
+
await self._call_on_leave(self.current_state)
|
|
1152
|
+
self._cancel_timeout()
|
|
1153
|
+
|
|
1154
|
+
self.current_state = next_state
|
|
1155
|
+
self.state_enter_times[next_state.name] = time.time()
|
|
1156
|
+
self.log(f"Transitioned to {next_state.name}")
|
|
1157
|
+
self.log(f" [DEBUG] common object id: {id(self.context.common)}")
|
|
1158
|
+
self.log(f" [DEBUG] on_enter is None? {next_state.on_enter is None}")
|
|
1159
|
+
|
|
1160
|
+
self._start_timeout(next_state)
|
|
1161
|
+
await self._call_on_enter(next_state)
|
|
1162
|
+
self.log(f" [DEBUG] After on_enter, common: {self.context.common}")
|
|
1163
|
+
|
|
1164
|
+
def _start_timeout(self, state: StateSpec) -> Optional[int]:
|
|
1165
|
+
"""Start a timeout for the given state"""
|
|
1166
|
+
|
|
1167
|
+
if state.config.timeout is None:
|
|
1168
|
+
return None
|
|
1169
|
+
|
|
1170
|
+
self._cancel_timeout()
|
|
1171
|
+
|
|
1172
|
+
self._timeout_counter += 1
|
|
1173
|
+
timeout_id = self._timeout_counter
|
|
1174
|
+
self._current_timeout_id = timeout_id
|
|
1175
|
+
|
|
1176
|
+
async def timeout_handler():
|
|
1177
|
+
try:
|
|
1178
|
+
await asyncio.sleep(state.config.timeout)
|
|
1179
|
+
self._send_timeout_message(state.name, timeout_id, state.config.timeout)
|
|
1180
|
+
except asyncio.CancelledError:
|
|
1181
|
+
pass
|
|
1182
|
+
|
|
1183
|
+
self._timeout_task = asyncio.create_task(timeout_handler())
|
|
1184
|
+
return timeout_id
|
|
1185
|
+
|
|
1186
|
+
def _cancel_timeout(self):
|
|
1187
|
+
"""Cancel any active timeout"""
|
|
1188
|
+
if self._timeout_task is not None:
|
|
1189
|
+
self._timeout_task.cancel()
|
|
1190
|
+
self._timeout_task = None
|
|
1191
|
+
self._current_timeout_id = None
|
|
1192
|
+
|
|
1193
|
+
def _send_timeout_message(self, state_name: str, timeout_id: int, duration: float):
|
|
1194
|
+
"""Internal method to send timeout messages"""
|
|
1195
|
+
|
|
1196
|
+
timeout_msg = TimeoutMessage(state_name, timeout_id, duration)
|
|
1197
|
+
self._send_message_internal(timeout_msg)
|
|
1198
|
+
|
|
1199
|
+
def _handle_timeout_message(self) -> bool:
|
|
1200
|
+
"""Handle a timeout message"""
|
|
1201
|
+
|
|
1202
|
+
timeout_msg = self.context.msg
|
|
1203
|
+
self.context.msg = None
|
|
1204
|
+
|
|
1205
|
+
if timeout_msg.timeout_id != self._current_timeout_id:
|
|
1206
|
+
self.log(f"Ignoring stale timeout {timeout_msg.timeout_id}")
|
|
1207
|
+
return False
|
|
1208
|
+
|
|
1209
|
+
if timeout_msg.state_name != self.current_state.name:
|
|
1210
|
+
self.log(
|
|
1211
|
+
f"Ignoring timeout for {timeout_msg.state_name}, current state is {self.current_state.name}"
|
|
1212
|
+
)
|
|
1213
|
+
return False
|
|
1214
|
+
|
|
1215
|
+
self.log(
|
|
1216
|
+
f"Handling timeout for {timeout_msg.state_name} ({timeout_msg.timeout_duration}s)"
|
|
1217
|
+
)
|
|
1218
|
+
self._cancel_timeout()
|
|
1219
|
+
return True
|
|
1220
|
+
|
|
1221
|
+
async def _handle_retry(self):
|
|
1222
|
+
"""Handle retry logic with counter"""
|
|
1223
|
+
state_name = self.current_state.name
|
|
1224
|
+
|
|
1225
|
+
if state_name not in self.retry_counters:
|
|
1226
|
+
self.retry_counters[state_name] = 0
|
|
1227
|
+
|
|
1228
|
+
self.retry_counters[state_name] += 1
|
|
1229
|
+
|
|
1230
|
+
if (
|
|
1231
|
+
self.current_state.config.retries is not None
|
|
1232
|
+
and self.retry_counters[state_name] > self.current_state.config.retries
|
|
1233
|
+
):
|
|
1234
|
+
self.log(f"Retry limit exceeded for {state_name}")
|
|
1235
|
+
transition = await self._call_on_fail(self.current_state)
|
|
1236
|
+
await self._process_transition(transition)
|
|
1237
|
+
else:
|
|
1238
|
+
await self._transition_to_state(self.current_state)
|
|
1239
|
+
|
|
1240
|
+
def _reset_state_context(self):
|
|
1241
|
+
"""Reset context for current state"""
|
|
1242
|
+
state_name = self.current_state.name
|
|
1243
|
+
if state_name in self.retry_counters:
|
|
1244
|
+
del self.retry_counters[state_name]
|
|
1245
|
+
|
|
1246
|
+
async def _handle_push(self, push: Push):
|
|
1247
|
+
"""Handle push transition"""
|
|
1248
|
+
resolved_states = [self.registry.resolve_state(s) for s in push.push_states]
|
|
1249
|
+
|
|
1250
|
+
for state in reversed(resolved_states[1:]):
|
|
1251
|
+
self.state_stack.append(state)
|
|
1252
|
+
|
|
1253
|
+
await self._transition_to_state(resolved_states[0])
|
|
1254
|
+
|
|
1255
|
+
async def _handle_pop(self):
|
|
1256
|
+
"""Handle pop transition"""
|
|
1257
|
+
if not self.state_stack:
|
|
1258
|
+
raise PopFromEmptyStack()
|
|
1259
|
+
|
|
1260
|
+
next_state = self.state_stack.pop()
|
|
1261
|
+
await self._transition_to_state(next_state)
|
|
1262
|
+
|
|
1263
|
+
async def run(
|
|
1264
|
+
self, max_iterations: Optional[int] = None, timeout: Optional[float] = 1
|
|
1265
|
+
):
|
|
1266
|
+
"""Run the state machine until terminal state"""
|
|
1267
|
+
if not self._initialized:
|
|
1268
|
+
await self.initialize()
|
|
1269
|
+
|
|
1270
|
+
iteration = 0
|
|
1271
|
+
|
|
1272
|
+
while True:
|
|
1273
|
+
if max_iterations is not None and iteration >= max_iterations:
|
|
1274
|
+
break
|
|
1275
|
+
|
|
1276
|
+
if self.current_state.config.terminal:
|
|
1277
|
+
self.log(f"Reached terminal state: {self.current_state.name}")
|
|
1278
|
+
break
|
|
1279
|
+
|
|
1280
|
+
try:
|
|
1281
|
+
await self.tick(timeout=timeout)
|
|
1282
|
+
iteration += 1
|
|
1283
|
+
except StateMachineComplete:
|
|
1284
|
+
raise
|
|
1285
|
+
except Exception as e:
|
|
1286
|
+
self.log(f"Error in FSM: {e}")
|
|
1287
|
+
if self.error_state:
|
|
1288
|
+
await self._transition_to_state(self.error_state)
|
|
1289
|
+
else:
|
|
1290
|
+
raise
|
|
1291
|
+
|
|
1292
|
+
|
|
1293
|
+
# ============================================================================
|
|
1294
|
+
# Decorators
|
|
1295
|
+
# ============================================================================
|
|
1296
|
+
|
|
1297
|
+
class _EnokiDecoratorAPI:
|
|
1298
|
+
"""Decorator API for defining states"""
|
|
1299
|
+
|
|
1300
|
+
def __init__(self):
|
|
1301
|
+
self._tracking_stack: List[List[Tuple[str, Any]]] = []
|
|
1302
|
+
|
|
1303
|
+
def state(
|
|
1304
|
+
self,
|
|
1305
|
+
config: StateConfiguration = StateConfiguration(),
|
|
1306
|
+
name: Optional[str] = None,
|
|
1307
|
+
group: Optional[str] = None,
|
|
1308
|
+
) -> Callable[[Callable], StateSpec]:
|
|
1309
|
+
"""
|
|
1310
|
+
Decorator to define a state.
|
|
1311
|
+
|
|
1312
|
+
Args:
|
|
1313
|
+
config: StateConfiguration for this state
|
|
1314
|
+
name: Optional name (auto-generated if not provided)
|
|
1315
|
+
group: Optional group name for organization
|
|
1316
|
+
|
|
1317
|
+
The decorated function should contain nested decorated functions
|
|
1318
|
+
(@on_state, @on_enter, @transitions, etc.). No need to return a dict!
|
|
1319
|
+
"""
|
|
1320
|
+
|
|
1321
|
+
def decorator(func: Callable[..., Any]) -> StateSpec:
|
|
1322
|
+
# Get function metadata
|
|
1323
|
+
module = func.__module__
|
|
1324
|
+
qualname = func.__qualname__
|
|
1325
|
+
|
|
1326
|
+
# Auto-generate name if not provided
|
|
1327
|
+
if name is None:
|
|
1328
|
+
# Handle __main__ module
|
|
1329
|
+
if module == "__main__":
|
|
1330
|
+
filename = inspect.getfile(func)
|
|
1331
|
+
module = os.path.splitext(os.path.basename(filename))[0]
|
|
1332
|
+
state_name = f"{module}.{qualname}"
|
|
1333
|
+
else:
|
|
1334
|
+
state_name = name
|
|
1335
|
+
|
|
1336
|
+
# Set up tracking for this state
|
|
1337
|
+
tracked_items = []
|
|
1338
|
+
self._tracking_stack.append(tracked_items)
|
|
1339
|
+
|
|
1340
|
+
try:
|
|
1341
|
+
# Call the function to execute inner decorators
|
|
1342
|
+
func()
|
|
1343
|
+
finally:
|
|
1344
|
+
# Always pop the tracking stack
|
|
1345
|
+
self._tracking_stack.pop()
|
|
1346
|
+
|
|
1347
|
+
# Extract decorated methods from tracked items
|
|
1348
|
+
on_state = None
|
|
1349
|
+
on_enter = None
|
|
1350
|
+
on_leave = None
|
|
1351
|
+
on_timeout = None
|
|
1352
|
+
on_fail = None
|
|
1353
|
+
transitions = None
|
|
1354
|
+
Events = None
|
|
1355
|
+
|
|
1356
|
+
for item_name, item_fn in tracked_items:
|
|
1357
|
+
if hasattr(item_fn, "_enoki_on_state"):
|
|
1358
|
+
on_state = item_fn
|
|
1359
|
+
elif hasattr(item_fn, "_enoki_on_enter"):
|
|
1360
|
+
on_enter = item_fn
|
|
1361
|
+
elif hasattr(item_fn, "_enoki_on_leave"):
|
|
1362
|
+
on_leave = item_fn
|
|
1363
|
+
elif hasattr(item_fn, "_enoki_on_timeout"):
|
|
1364
|
+
on_timeout = item_fn
|
|
1365
|
+
elif hasattr(item_fn, "_enoki_on_fail"):
|
|
1366
|
+
on_fail = item_fn
|
|
1367
|
+
elif hasattr(item_fn, "_enoki_transitions"):
|
|
1368
|
+
transitions = item_fn
|
|
1369
|
+
elif hasattr(item_fn, "_enoki_events"):
|
|
1370
|
+
Events = item_fn
|
|
1371
|
+
|
|
1372
|
+
# Validate required methods
|
|
1373
|
+
if on_state is None:
|
|
1374
|
+
raise ValueError(f"State '{state_name}' must have an @on_state decorated method")
|
|
1375
|
+
|
|
1376
|
+
# Create the StateSpec
|
|
1377
|
+
state_spec = StateSpec(
|
|
1378
|
+
name=state_name,
|
|
1379
|
+
qualname=qualname,
|
|
1380
|
+
module=module,
|
|
1381
|
+
config=config,
|
|
1382
|
+
on_state=on_state,
|
|
1383
|
+
on_enter=on_enter,
|
|
1384
|
+
on_leave=on_leave,
|
|
1385
|
+
on_timeout=on_timeout,
|
|
1386
|
+
on_fail=on_fail,
|
|
1387
|
+
transitions=transitions,
|
|
1388
|
+
Events=Events,
|
|
1389
|
+
group_name=group or "",
|
|
1390
|
+
)
|
|
1391
|
+
|
|
1392
|
+
# Register the state
|
|
1393
|
+
register_state(state_spec)
|
|
1394
|
+
|
|
1395
|
+
return state_spec
|
|
1396
|
+
|
|
1397
|
+
return decorator
|
|
1398
|
+
|
|
1399
|
+
def on_state(self, func: Callable) -> Callable:
|
|
1400
|
+
"""Decorator for the main state logic method"""
|
|
1401
|
+
func._enoki_on_state = func
|
|
1402
|
+
if self._tracking_stack:
|
|
1403
|
+
self._tracking_stack[-1].append((func.__name__, func))
|
|
1404
|
+
return func
|
|
1405
|
+
|
|
1406
|
+
def on_enter(self, func: Callable) -> Callable:
|
|
1407
|
+
"""Decorator for on_enter lifecycle method"""
|
|
1408
|
+
func._enoki_on_enter = func
|
|
1409
|
+
if self._tracking_stack:
|
|
1410
|
+
self._tracking_stack[-1].append((func.__name__, func))
|
|
1411
|
+
return func
|
|
1412
|
+
|
|
1413
|
+
def on_leave(self, func: Callable) -> Callable:
|
|
1414
|
+
"""Decorator for on_leave lifecycle method"""
|
|
1415
|
+
func._enoki_on_leave = func
|
|
1416
|
+
if self._tracking_stack:
|
|
1417
|
+
self._tracking_stack[-1].append((func.__name__, func))
|
|
1418
|
+
return func
|
|
1419
|
+
|
|
1420
|
+
def on_timeout(self, func: Callable) -> Callable:
|
|
1421
|
+
"""Decorator for on_timeout lifecycle method"""
|
|
1422
|
+
func._enoki_on_timeout = func
|
|
1423
|
+
if self._tracking_stack:
|
|
1424
|
+
self._tracking_stack[-1].append((func.__name__, func))
|
|
1425
|
+
return func
|
|
1426
|
+
|
|
1427
|
+
def on_fail(self, func: Callable) -> Callable:
|
|
1428
|
+
"""Decorator for on_fail lifecycle method"""
|
|
1429
|
+
func._enoki_on_fail = func
|
|
1430
|
+
if self._tracking_stack:
|
|
1431
|
+
self._tracking_stack[-1].append((func.__name__, func))
|
|
1432
|
+
return func
|
|
1433
|
+
|
|
1434
|
+
def transitions(self, func: Callable) -> Callable:
|
|
1435
|
+
"""Decorator for declaring transitions"""
|
|
1436
|
+
func._enoki_transitions = func
|
|
1437
|
+
if self._tracking_stack:
|
|
1438
|
+
self._tracking_stack[-1].append((func.__name__, func))
|
|
1439
|
+
return func
|
|
1440
|
+
|
|
1441
|
+
def events(self, events_class: type) -> type:
|
|
1442
|
+
"""
|
|
1443
|
+
Decorator for registering the Events enum.
|
|
1444
|
+
|
|
1445
|
+
Optional decorator that makes the Events enum accessible via state.Events
|
|
1446
|
+
for introspection and testing. The Events enum works via closure even
|
|
1447
|
+
without this decorator, but state.Events will be None.
|
|
1448
|
+
|
|
1449
|
+
Use this decorator when:
|
|
1450
|
+
- You need to access state.Events programmatically (e.g., in tests)
|
|
1451
|
+
- You want to inspect what events a state defines
|
|
1452
|
+
- Registry methods need to reference the Events enum
|
|
1453
|
+
|
|
1454
|
+
Skip this decorator when:
|
|
1455
|
+
- You only use Events within on_state/transitions (closure handles it)
|
|
1456
|
+
- You don't need external access to the Events enum
|
|
1457
|
+
|
|
1458
|
+
Example:
|
|
1459
|
+
@enoki.state()
|
|
1460
|
+
def MyState():
|
|
1461
|
+
@enoki.events # Optional - makes MyState.Events accessible
|
|
1462
|
+
class Events(Enum):
|
|
1463
|
+
GO = auto()
|
|
1464
|
+
|
|
1465
|
+
@enoki.on_state
|
|
1466
|
+
async def on_state(ctx):
|
|
1467
|
+
return Events.GO # Works via closure even without @enoki.events
|
|
1468
|
+
"""
|
|
1469
|
+
events_class._enoki_events = events_class
|
|
1470
|
+
if self._tracking_stack:
|
|
1471
|
+
self._tracking_stack[-1].append(("Events", events_class))
|
|
1472
|
+
return events_class
|
|
1473
|
+
|
|
1474
|
+
def root(self, func: Callable) -> Callable:
|
|
1475
|
+
"""Decorator to mark a state as the initial/root state"""
|
|
1476
|
+
func._enoki_is_root = True
|
|
1477
|
+
return func
|
|
1478
|
+
|
|
1479
|
+
|
|
1480
|
+
# Module-level decorator API instance backing every state defined via this module.
enoki = _EnokiDecoratorAPI()

# Bare aliases, so callers can write ``@state()`` / ``@on_state`` directly
# instead of going through the ``enoki`` instance.
(
    state,
    on_state,
    on_enter,
    on_leave,
    on_timeout,
    on_fail,
    transitions,
    events,
    root,
) = (
    enoki.state,
    enoki.on_state,
    enoki.on_enter,
    enoki.on_leave,
    enoki.on_timeout,
    enoki.on_fail,
    enoki.transitions,
    enoki.events,
    enoki.root,
)
|
|
1493
|
+
|
|
1494
|
+
|
|
1495
|
+
# ============================================================================
|
|
1496
|
+
# Export all public components
|
|
1497
|
+
# ============================================================================
|
|
1498
|
+
|
|
1499
|
+
__all__ = [
|
|
1500
|
+
# Decorators
|
|
1501
|
+
"enoki",
|
|
1502
|
+
"state",
|
|
1503
|
+
"on_state",
|
|
1504
|
+
"on_enter",
|
|
1505
|
+
"on_leave",
|
|
1506
|
+
"on_timeout",
|
|
1507
|
+
"on_fail",
|
|
1508
|
+
"transitions",
|
|
1509
|
+
"root",
|
|
1510
|
+
# Core types
|
|
1511
|
+
"StateSpec",
|
|
1512
|
+
"StateConfiguration",
|
|
1513
|
+
"SharedContext",
|
|
1514
|
+
# Transitions
|
|
1515
|
+
"TransitionType",
|
|
1516
|
+
"StateTransition",
|
|
1517
|
+
"StateContinuation",
|
|
1518
|
+
"StateRenewal",
|
|
1519
|
+
"Again",
|
|
1520
|
+
"Unhandled",
|
|
1521
|
+
"Retry",
|
|
1522
|
+
"Restart",
|
|
1523
|
+
"Repeat",
|
|
1524
|
+
"StateRef",
|
|
1525
|
+
"LabeledTransition",
|
|
1526
|
+
"Push",
|
|
1527
|
+
"Pop",
|
|
1528
|
+
# Registry
|
|
1529
|
+
"register_state",
|
|
1530
|
+
"get_state",
|
|
1531
|
+
"get_all_states",
|
|
1532
|
+
"StateRegistry",
|
|
1533
|
+
"ValidationError",
|
|
1534
|
+
"ValidationResult",
|
|
1535
|
+
# StateMachine
|
|
1536
|
+
"StateMachine",
|
|
1537
|
+
"StateMachineComplete",
|
|
1538
|
+
"BlockedInUntimedState",
|
|
1539
|
+
"PopFromEmptyStack",
|
|
1540
|
+
"NoStateToTick",
|
|
1541
|
+
# Messages
|
|
1542
|
+
"PrioritizedMessage",
|
|
1543
|
+
"EnokiInternalMessage",
|
|
1544
|
+
"TimeoutMessage",
|
|
1545
|
+
]
|