mycorrhizal 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. mycorrhizal/__init__.py +3 -0
  2. mycorrhizal/common/__init__.py +68 -0
  3. mycorrhizal/common/interface_builder.py +203 -0
  4. mycorrhizal/common/interfaces.py +412 -0
  5. mycorrhizal/common/timebase.py +99 -0
  6. mycorrhizal/common/wrappers.py +532 -0
  7. mycorrhizal/enoki/__init__.py +0 -0
  8. mycorrhizal/enoki/core.py +1545 -0
  9. mycorrhizal/enoki/testing_utils.py +529 -0
  10. mycorrhizal/enoki/util.py +220 -0
  11. mycorrhizal/hypha/__init__.py +0 -0
  12. mycorrhizal/hypha/core/__init__.py +107 -0
  13. mycorrhizal/hypha/core/builder.py +404 -0
  14. mycorrhizal/hypha/core/runtime.py +890 -0
  15. mycorrhizal/hypha/core/specs.py +234 -0
  16. mycorrhizal/hypha/util.py +38 -0
  17. mycorrhizal/rhizomorph/README.md +220 -0
  18. mycorrhizal/rhizomorph/__init__.py +0 -0
  19. mycorrhizal/rhizomorph/core.py +1729 -0
  20. mycorrhizal/rhizomorph/util.py +45 -0
  21. mycorrhizal/spores/__init__.py +124 -0
  22. mycorrhizal/spores/cache.py +208 -0
  23. mycorrhizal/spores/core.py +419 -0
  24. mycorrhizal/spores/dsl/__init__.py +48 -0
  25. mycorrhizal/spores/dsl/enoki.py +514 -0
  26. mycorrhizal/spores/dsl/hypha.py +399 -0
  27. mycorrhizal/spores/dsl/rhizomorph.py +351 -0
  28. mycorrhizal/spores/encoder/__init__.py +11 -0
  29. mycorrhizal/spores/encoder/base.py +42 -0
  30. mycorrhizal/spores/encoder/json.py +159 -0
  31. mycorrhizal/spores/extraction.py +484 -0
  32. mycorrhizal/spores/models.py +288 -0
  33. mycorrhizal/spores/transport/__init__.py +10 -0
  34. mycorrhizal/spores/transport/base.py +46 -0
  35. mycorrhizal-0.1.0.dist-info/METADATA +198 -0
  36. mycorrhizal-0.1.0.dist-info/RECORD +37 -0
  37. mycorrhizal-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,1545 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Enoki - Asyncio Finite State Machine Framework
4
+
5
+ A decorator-based DSL for defining and executing state machines with support for
6
+ asyncio, timeouts, message passing, and hierarchical state composition.
7
+
8
+ Usage:
9
+ from mycorrhizal.enoki.core import enoki, StateMachine, LabeledTransition
10
+ from enum import Enum, auto
11
+
12
+ @enoki.state()
13
+ def IdleState():
14
+ class Events(Enum):
15
+ START = auto()
16
+ QUIT = auto()
17
+
18
+ @enoki.on_state
19
+ async def on_state(ctx):
20
+ if ctx.msg == "start":
21
+ return Events.START
22
+ return None
23
+
24
+ @enoki.transitions
25
+ def transitions():
26
+ return [
27
+ LabeledTransition(Events.START, ProcessingState),
28
+ LabeledTransition(Events.QUIT, DoneState),
29
+ ]
30
+
31
+ # Create and run the FSM
32
+ fsm = StateMachine(initial_state=IdleState, common_data={})
33
+ await fsm.initialize()
34
+ fsm.send_message("start")
35
+ await fsm.tick()
36
+
37
+ Key Classes:
38
+ StateMachine - Main FSM runtime with message queue and tick-based execution
39
+ StateConfiguration - Configuration for states (timeout, retries, terminal, can_dwell)
40
+ SharedContext - Context object passed to state handlers with msg and common data
41
+ LabeledTransition - Maps events to target states
42
+
43
+ Transition Types:
44
+ State references - Direct transition to another state
45
+ Again - Re-execute current state immediately
46
+ Unhandled - Wait for next message
47
+ Retry - Re-enter state with retry counter
48
+ Restart - Reset retry counter and wait for message
49
+ Repeat - Re-enter state from on_enter
50
+ Push(state1, state2, ...) - Push states onto stack
51
+ Pop - Pop and return to previous state
52
+
53
+ State Handlers:
54
+ on_state(ctx) - Main state logic, return transition or None to wait
55
+ on_enter(ctx) - Called when entering state
56
+ on_leave(ctx) - Called when leaving state
57
+ on_timeout(ctx) - Called if timeout expires
58
+ on_fail(ctx) - Called if exception occurs
59
+ """
60
+
61
+ from __future__ import annotations
62
+ import inspect
63
+ import sys
64
+ import os
65
+ import time
66
+ import traceback
67
+ import asyncio
68
+ from pathlib import Path
69
+ from asyncio import PriorityQueue
70
+ from dataclasses import dataclass, field
71
+ from enum import Enum, auto, IntEnum
72
+ from typing import (
73
+ Any,
74
+ Callable,
75
+ Dict,
76
+ List,
77
+ Optional,
78
+ Union,
79
+ Awaitable,
80
+ TypeVar,
81
+ Generic,
82
+ )
83
+ from functools import cache
84
+
85
+
86
# Type variable for the shape of ``SharedContext.common``.
T = TypeVar("T")
87
+
88
+
89
+ # ============================================================================
90
+ # Interface Integration Helper
91
+ # ============================================================================
92
+
93
+
94
+ def _create_interface_view_for_context(context: SharedContext, handler: Callable) -> SharedContext:
95
+ """
96
+ Create a constrained view of context.common if the handler has an interface type hint.
97
+
98
+ This enables type-safe, constrained access to blackboard state based on
99
+ interface definitions created with @blackboard_interface.
100
+
101
+ Args:
102
+ context: The SharedContext object
103
+ handler: The state handler function to check for interface type hints
104
+
105
+ Returns:
106
+ Either the original context or a new context with constrained common field
107
+ """
108
+ from typing import get_type_hints
109
+
110
+ try:
111
+ sig = inspect.signature(handler)
112
+ params = list(sig.parameters.values())
113
+
114
+ # Check first parameter (should be SharedContext)
115
+ if params and params[0].name == 'ctx':
116
+ ctx_type = get_type_hints(handler).get('ctx')
117
+
118
+ # If type hint exists and is a generic SharedContext with interface
119
+ if ctx_type and hasattr(ctx_type, '__args__'):
120
+ # Extract the type argument from SharedContext[InterfaceType]
121
+ interface_type = ctx_type.__args__[0]
122
+
123
+ # If it has interface metadata, create constrained view
124
+ if hasattr(interface_type, '_readonly_fields'):
125
+ from mycorrhizal.common.wrappers import create_view_from_protocol
126
+ from dataclasses import replace
127
+
128
+ # Create constrained view of common
129
+ constrained_common = create_view_from_protocol(context.common, interface_type)
130
+
131
+ # Return new context with constrained common
132
+ return replace(context, common=constrained_common)
133
+ except Exception:
134
+ # If anything goes wrong with type inspection, fall back to original context
135
+ pass
136
+
137
+ return context
138
+
139
+
140
@dataclass
class StateConfiguration:
    """Per-state options read by the state machine and its validator.

    Args:
        timeout: Seconds a state may wait for a message before its on_timeout
            handler is invoked. ``None`` means no timeout is configured; when
            set, validation requires a positive int/float.
        retries: Number of retries the state is allowed. Each Retry returned
            by the state counts against it; once exhausted the state fails.
            ``None`` when no limit is configured.
        terminal: When True, reaching this state completes the FSM
            (StateMachineComplete is raised). Terminal states need not define
            transitions.
        can_dwell: When True, the state may wait indefinitely, i.e. on_state
            is allowed to return None.

    Example:
        @enoki.state(config=StateConfiguration(timeout=5.0, retries=3))
        def MyState():
            ...  # state with timeout and retry handling
    """

    # Optional wait limit, in seconds.
    timeout: Optional[float] = None
    # Optional retry budget.
    retries: Optional[int] = None
    # Entering this state finishes the machine.
    terminal: bool = False
    # The state may block without a timeout.
    can_dwell: bool = False
165
+
166
+
167
@dataclass
class SharedContext(Generic[T]):
    """Object handed to every state handler.

    It bundles the message currently being processed with the machine-wide
    shared data and two helper callables.

    Attributes:
        send_message: Callable that queues a message on the owning state machine.
        log: Callable used for log output. This class provides no default;
            StateMachine passes its own ``log`` method, which prints with an
            ``[FSM]`` prefix.
        common: Shared data given to StateMachine, reachable as ``ctx.common``.
        msg: The message being processed, or None.

    The type parameter ``T`` is the type of ``common``, for type checking.

    Example:
        @enoki.on_state
        async def on_state(ctx: SharedContext):
            counter = ctx.common.get("counter", 0)  # read shared data
            if ctx.msg == "start":                  # react to a message
                return Events.START
            ctx.send_message("ping")                # queue another message
    """

    send_message: Callable  # enqueue a message on the machine
    log: Callable  # logging hook
    common: T  # shared, machine-wide data
    msg: Optional[Any] = None  # current message, if any
200
+
201
+
202
+ # ============================================================================
203
+ # Transition Type Hierarchy
204
+ # ============================================================================
205
+
206
@dataclass
class TransitionType:
    """Base type for all transitions.

    Two transitions are equal when their class names match. ``__hash__`` is
    consistent with that equality, so field-less transitions such as
    ``Again()`` can be used as dict keys and set members.
    """

    # True on Unhandled and Restart (documented as waiting for the next
    # message); False for the other transition types defined here.
    AWAITS: bool = False

    @property
    def name(self) -> str:
        """Class name of the transition, e.g. ``"Again"``."""
        return self.__class__.__name__

    @property
    def base_name(self) -> str:
        """Name of the first base class, e.g. ``"StateContinuation"`` for Again."""
        if hasattr(self.__class__, "__bases__") and self.__class__.__bases__:
            return self.__class__.__bases__[0].__name__
        return self.name

    def __hash__(self):
        # Hash exactly what __eq__ compares, so equal transitions hash equally.
        # (base_name is a property and must not be called.)
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, TransitionType):
            return self.name == other.name
        return False


# NOTE: the field-less transition classes below use ``@dataclass(eq=False)``.
# A plain ``@dataclass`` would generate a new ``__eq__`` and, because these
# subclasses do not define ``__hash__`` themselves, set ``__hash__ = None``.
# That would make ``Again()``, ``Retry()``, ``Pop()`` etc. unhashable. With
# ``eq=False`` they inherit TransitionType's name-based ``__eq__`` and
# ``__hash__``.


@dataclass(eq=False)
class StateTransition(TransitionType):
    """Represents a transition to a different state"""


@dataclass(eq=False)
class StateContinuation(TransitionType):
    """Represents staying in the same state"""


@dataclass(eq=False)
class StateRenewal(TransitionType):
    """Represents a transition back into the current state"""


@dataclass(eq=False)
class Again(StateContinuation):
    """Execute current state immediately without affecting retry counter"""


@dataclass(eq=False)
class Unhandled(StateContinuation):
    """Wait for next message/event without affecting retry counter"""

    AWAITS: bool = True


@dataclass(eq=False)
class Retry(StateContinuation):
    """Retry current state, decrementing retry counter, starting from on_enter"""


@dataclass(eq=False)
class Restart(StateRenewal):
    """Restart current state, reset retry counter, await message"""

    AWAITS: bool = True


@dataclass(eq=False)
class Repeat(StateRenewal):
    """Repeat current state, reset retry counter, execute on_enter immediately"""


@dataclass
class StateRef(StateTransition):
    """String-based state reference for breaking circular imports.

    Equality and hashing use only the referenced state name. The explicit
    ``__eq__``/``__hash__`` below are left untouched by ``@dataclass``.
    """

    state: str = ""

    def __init__(self, state: str):
        self.state = state
        # Pin AWAITS on the instance as well; a StateRef never waits.
        object.__setattr__(self, "AWAITS", False)

    def __hash__(self):
        return hash(self.state)

    def __eq__(self, other):
        if isinstance(other, StateRef):
            return self.state == other.state
        return False


@dataclass
class LabeledTransition:
    """A transition with a label (typically an enum value)"""

    label: Enum
    transition: TransitionType


# ============================================================================
# Push/Pop Transitions (stack-based state navigation)
# ============================================================================

@dataclass
class Push(StateTransition):
    """Push one or more states onto the stack.

    Keeps the dataclass-generated ``__eq__``, which compares ``push_states``,
    so pushes of different states are not equal. ``__hash__`` is defined
    explicitly from the pushed transitions' names.
    """

    push_states: List[StateTransition] = field(default_factory=list)

    def __init__(self, *states: StateTransition):
        object.__setattr__(self, "push_states", list(states))

    def __hash__(self):
        return hash(".".join([s.name for s in self.push_states]))


@dataclass(eq=False)
class Pop(StateTransition):
    """Pop from the stack to return to previous state"""
334
+
335
+
336
+ # ============================================================================
337
+ # StateSpec - The heart of the decorator system
338
+ # ============================================================================
339
+
340
+ @dataclass
341
+ class StateSpec(StateTransition):
342
+ """
343
+ A state defined via decorators.
344
+
345
+ This is a StateTransition, so it can be used directly in transition lists.
346
+ It has the same interface as the old State class but is constructed by
347
+ decorators rather than metaclass magic.
348
+ """
349
+
350
+ # Core identity
351
+ name: str = ""
352
+ qualname: str = ""
353
+ module: str = ""
354
+
355
+ # Configuration
356
+ config: StateConfiguration = field(default_factory=StateConfiguration)
357
+
358
+ # Lifecycle methods (all optional except on_state)
359
+ on_state: Optional[Callable[[SharedContext], Awaitable[TransitionType]]] = None
360
+ on_enter: Optional[Callable[[SharedContext], Awaitable[None]]] = None
361
+ on_leave: Optional[Callable[[SharedContext], Awaitable[None]]] = None
362
+ on_timeout: Optional[Callable[[SharedContext], Awaitable[TransitionType]]] = None
363
+ on_fail: Optional[Callable[[SharedContext], Awaitable[TransitionType]]] = None
364
+
365
+ # Transitions
366
+ transitions: Optional[Callable[[], List[Union[LabeledTransition, TransitionType]]]] = None
367
+
368
+ # Nested event enum (optional)
369
+ Events: Optional[type[Enum]] = None
370
+
371
+ # Group information (for state groups/namespaces)
372
+ group_name: str = field(default="")
373
+ parent_state: Optional[StateSpec] = None
374
+
375
+ # Metadata
376
+ _is_root: bool = field(default=False, repr=False)
377
+
378
+ @property
379
+ def base_name(self) -> str:
380
+ """Get just the state name without module path"""
381
+ return self.qualname.split(".")[-1]
382
+
383
+ @property
384
+ def CONFIG(self) -> StateConfiguration:
385
+ """Alias for config (for compatibility with State class API)"""
386
+ return self.config
387
+
388
+ def __hash__(self):
389
+ return hash(self.name)
390
+
391
+ def __eq__(self, other):
392
+ if isinstance(other, StateSpec):
393
+ return self.name == other.name
394
+ if isinstance(other, str):
395
+ return self.name == other
396
+ return False
397
+
398
+ def get_transitions(self) -> List[Union[LabeledTransition, TransitionType]]:
399
+ """Get the transition list for this state"""
400
+ if self.transitions is None:
401
+ return []
402
+ result = self.transitions()
403
+ # Handle single transition returned
404
+ if isinstance(result, (TransitionType, LabeledTransition)):
405
+ return [result]
406
+ return result
407
+
408
+
409
+ # ============================================================================
410
+ # Global Registry
411
+ # ============================================================================
412
+
413
# Global table of registered states, keyed by state name. StateRegistry
# resolves StateRef / name strings through get_state().
_state_registry: Dict[str, StateSpec] = {}


def register_state(state: StateSpec) -> None:
    """Store ``state`` in the global registry under ``state.name``.

    Registering another state with the same name replaces the earlier entry.
    """
    key = state.name
    _state_registry[key] = state


def get_state(name: str) -> Optional[StateSpec]:
    """Look up a registered state by name; return None when it is unknown."""
    try:
        return _state_registry[name]
    except KeyError:
        return None


def get_all_states() -> Dict[str, StateSpec]:
    """Return a shallow copy of the registry; mutating it leaves the registry intact."""
    return dict(_state_registry)
429
+
430
+
431
+ # ============================================================================
432
+ # StateRegistry - Validation and Resolution
433
+ # ============================================================================
434
+
435
class ValidationError(Exception):
    """Signals that a state machine failed structural validation."""


@dataclass
class ValidationResult:
    """Outcome of ``StateRegistry.validate_all_states()``."""

    valid: bool  # True when no errors were recorded
    errors: list[str]  # problems that make the machine invalid
    warnings: list[str]  # non-fatal findings
    discovered_states: set[str]  # names of the states reached during validation
447
+
448
+
449
+ class StateRegistry:
450
+ """
451
+ Manages state resolution and validation for StateSpec-based states.
452
+
453
+ This registry handles:
454
+ - Resolving StateRef strings to StateSpec objects
455
+ - Validating all states in a state machine
456
+ - Getting transition mappings for states
457
+ - Detecting circular references and invalid transitions
458
+ """
459
+
460
+ def __init__(self):
461
+ self._state_cache: Dict[str, StateSpec] = {}
462
+ self._transition_cache: Dict[str, Dict] = {}
463
+ self._resolved_modules: Dict[str, Any] = {}
464
+ self._validation_errors: list[str] = []
465
+ self._validation_warnings: list[str] = []
466
+
467
+ def resolve_state(
468
+ self, state_ref: Union[str, StateRef, StateSpec], validate_only: bool = False
469
+ ) -> Optional[StateSpec]:
470
+ """
471
+ Resolve a state reference to an actual StateSpec object.
472
+
473
+ Args:
474
+ state_ref: Can be a StateSpec, a StateRef, or a fully qualified state name string
475
+ validate_only: If True, don't cache the result (used during validation)
476
+
477
+ Returns:
478
+ The resolved StateSpec, or None if not found (when validate_only=True)
479
+ """
480
+ # If it's already a StateSpec, return it
481
+ if isinstance(state_ref, StateSpec):
482
+ # Track resolved module
483
+ if state_ref.module and state_ref.module not in self._resolved_modules:
484
+ mod = sys.modules.get(state_ref.module)
485
+ self._resolved_modules[state_ref.module] = mod
486
+
487
+ # Store short name for easier local file lookup
488
+ if mod and hasattr(mod, "__file__"):
489
+ fp = Path(getattr(mod, "__file__"))
490
+ self._resolved_modules[fp.stem] = mod
491
+
492
+ return state_ref
493
+
494
+ # If it's a StateRef, resolve the string
495
+ if isinstance(state_ref, StateRef):
496
+ state_name = state_ref.state
497
+ elif isinstance(state_ref, str):
498
+ state_name = state_ref
499
+ else:
500
+ return None
501
+
502
+ # Check cache first
503
+ if state_name in self._state_cache and not validate_only:
504
+ return self._state_cache[state_name]
505
+
506
+ # Try to get from global registry
507
+ state = get_state(state_name)
508
+ if state:
509
+ if not validate_only:
510
+ self._state_cache[state_name] = state
511
+ return state
512
+
513
+ # If not in global registry, we can't resolve it
514
+ # (In the old system, it would try dynamic imports, but with
515
+ # decorator-based states, everything should be pre-registered)
516
+ if validate_only:
517
+ self._validation_errors.append(
518
+ f"State '{state_name}' not found in registry"
519
+ )
520
+ return None
521
+ else:
522
+ raise ValidationError(f"State '{state_name}' not found in registry")
523
+
524
+ def validate_all_states(
525
+ self,
526
+ initial_state: Union[str, StateRef, StateSpec],
527
+ error_state: Optional[Union[str, StateRef, StateSpec]] = None,
528
+ ) -> ValidationResult:
529
+ """
530
+ Validate all states reachable from the initial state.
531
+
532
+ Args:
533
+ initial_state: The starting state
534
+ error_state: Optional error state
535
+
536
+ Returns:
537
+ ValidationResult with errors, warnings, and discovered states
538
+ """
539
+ self._validation_errors = []
540
+ self._validation_warnings = []
541
+
542
+ # Discover all reachable states
543
+ discovered_states = set()
544
+ state_references = set()
545
+ resolved_references = set()
546
+ states_to_visit = []
547
+
548
+ # Resolve initial state
549
+ initial_resolved = self.resolve_state(initial_state, validate_only=True)
550
+ if initial_resolved:
551
+ states_to_visit.append(initial_resolved)
552
+ discovered_states.add(initial_resolved.name)
553
+
554
+ # Resolve error state if provided
555
+ if error_state:
556
+ error_resolved = self.resolve_state(error_state, validate_only=True)
557
+ if error_resolved:
558
+ states_to_visit.append(error_resolved)
559
+ discovered_states.add(error_resolved.name)
560
+
561
+ # Traverse all reachable states
562
+ visited = set()
563
+ while states_to_visit:
564
+ current_state = states_to_visit.pop()
565
+ state_name = current_state.name
566
+
567
+ if state_name in visited:
568
+ continue
569
+ visited.add(state_name)
570
+
571
+ def handle_state(s):
572
+ discovered_states.add(s.name if hasattr(s, 'name') else str(s))
573
+ states_to_visit.append(s)
574
+
575
+ def handle_reference(r):
576
+ state_references.add(r)
577
+
578
+ def handle_push(p):
579
+ for state in p.push_states:
580
+ match state:
581
+ case s if isinstance(s, StateSpec):
582
+ handle_state(s)
583
+ case r if isinstance(r, StateRef):
584
+ handle_reference(r)
585
+ case invalid:
586
+ self._validation_errors.append(
587
+ f"State '{state_name}' attempted to push an invalid transition: {invalid}"
588
+ )
589
+
590
+ try:
591
+ # Get transitions
592
+ raw_transitions = current_state.get_transitions()
593
+
594
+ # Handle different return types
595
+ match raw_transitions:
596
+ case list() | tuple():
597
+ pass # Already a list
598
+ case _ if isinstance(raw_transitions, (TransitionType, LabeledTransition)):
599
+ raw_transitions = [raw_transitions]
600
+ case _:
601
+ raw_transitions = []
602
+
603
+ # Validate each transition
604
+ for transition in raw_transitions:
605
+ match transition:
606
+ case _ if isinstance(transition, Push):
607
+ # Push transitions must have labels
608
+ self._validation_errors.append(
609
+ f"State '{state_name}' has a push transition without a label"
610
+ )
611
+ handle_push(transition)
612
+ case LabeledTransition(label, s):
613
+ match s:
614
+ case st if isinstance(st, StateSpec):
615
+ handle_state(st)
616
+ case r if isinstance(r, StateRef):
617
+ handle_reference(r)
618
+ case p if isinstance(p, Push):
619
+ handle_push(p)
620
+ case s if isinstance(s, StateSpec):
621
+ handle_state(s)
622
+ case r if isinstance(r, StateRef):
623
+ self._validation_warnings.append(
624
+ f"State '{state_name}' has a StateReference transition without a label"
625
+ )
626
+ handle_reference(r)
627
+ case _ if isinstance(transition, (StateContinuation, StateRenewal)):
628
+ # Valid continuation/renewal
629
+ pass
630
+ case invalid:
631
+ self._validation_errors.append(
632
+ f"State '{state_name}' has an invalid transition: {invalid}"
633
+ )
634
+
635
+ # Validate that the state has an on_state handler
636
+ if current_state.on_state is None:
637
+ self._validation_errors.append(
638
+ f"State {state_name} does not have an on_state handler defined"
639
+ )
640
+
641
+ # Make sure non-terminal states define transitions
642
+ if not current_state.config.terminal:
643
+ if current_state.transitions is None:
644
+ self._validation_errors.append(
645
+ f"Non-terminal state {state_name} does not define any transitions"
646
+ )
647
+
648
+ # Validate timeout configuration
649
+ if current_state.config.timeout is not None:
650
+ if (
651
+ not isinstance(current_state.config.timeout, (int, float))
652
+ or current_state.config.timeout <= 0
653
+ ):
654
+ self._validation_errors.append(
655
+ f"State '{state_name}' has invalid timeout: {current_state.config.timeout}"
656
+ )
657
+
658
+ has_timeout_handler = current_state.on_timeout is not None
659
+ if not has_timeout_handler:
660
+ self._validation_warnings.append(
661
+ f"State '{state_name}' has timeout but no on_timeout handler defined"
662
+ )
663
+
664
+ # Pass 2: Validate all string references
665
+ for state_ref in list(state_references - resolved_references):
666
+ resolved = self.resolve_state(state_ref, validate_only=True)
667
+ if resolved:
668
+ resolved_references.add(state_ref)
669
+ if resolved.name not in discovered_states:
670
+ discovered_states.add(resolved.name)
671
+ states_to_visit.append(resolved)
672
+
673
+ except Exception as e:
674
+ self._validation_errors.append(
675
+ f"Error processing transitions for state '{state_name}': {traceback.format_exc()}"
676
+ )
677
+
678
+ return ValidationResult(
679
+ valid=len(self._validation_errors) == 0,
680
+ errors=self._validation_errors.copy(),
681
+ warnings=self._validation_warnings.copy(),
682
+ discovered_states=discovered_states,
683
+ )
684
+
685
+ def get_transitions(self, state: StateSpec) -> Dict:
686
+ """
687
+ Get the transition mapping for a state, with caching.
688
+
689
+ Returns a dict mapping event labels/enum values to transition targets.
690
+ """
691
+ state_name = state.name
692
+
693
+ if state_name in self._transition_cache:
694
+ return self._transition_cache[state_name]
695
+
696
+ resolved_transitions = {}
697
+
698
+ raw_transitions = state.get_transitions()
699
+ match raw_transitions:
700
+ case list() | tuple():
701
+ pass # Already a list
702
+ case _ if isinstance(raw_transitions, (TransitionType, LabeledTransition)):
703
+ raw_transitions = [raw_transitions]
704
+ case _:
705
+ raw_transitions = []
706
+
707
+ for transition in raw_transitions:
708
+ match transition:
709
+ case LabeledTransition(label, target):
710
+ # Use the label enum itself as the key
711
+ resolved_transitions[label] = target
712
+ # Also allow lookup by label name
713
+ resolved_transitions[label.name] = target
714
+ # Also allow lookup by label value
715
+ if hasattr(label, 'value'):
716
+ resolved_transitions[label.value] = target
717
+ case s if isinstance(s, StateSpec):
718
+ # Direct state reference
719
+ resolved_transitions[s.name] = s
720
+ resolved_transitions[s] = s
721
+ case r if isinstance(r, StateRef):
722
+ # String reference - resolve it
723
+ resolved = self.resolve_state(r)
724
+ if resolved:
725
+ resolved_transitions[r.state] = resolved
726
+ case _ if isinstance(transition, (StateContinuation, StateRenewal)):
727
+ # Special transitions - allow lookup by type
728
+ resolved_transitions[transition] = transition
729
+
730
+ self._transition_cache[state_name] = resolved_transitions
731
+ return resolved_transitions
732
+
733
+
734
+ # ============================================================================
735
+ # Message Types for StateMachine
736
+ # ============================================================================
737
+
738
@dataclass
class PrioritizedMessage:
    """Queue entry ordered by priority, then by creation order.

    asyncio.PriorityQueue is backed by heapq, which is not stable for equal
    keys. Without a tie-breaker, messages of the same priority could be
    dequeued out of the order they were sent. Each entry therefore records a
    monotonically increasing ``sequence`` number, used only for ordering.
    Lower priority values are dequeued first.
    """

    priority: int
    item: Any = field(compare=False)
    # Creation order; assigned in __post_init__, excluded from __init__/repr/eq.
    sequence: int = field(init=False, repr=False, compare=False)

    # Plain (unannotated) class attribute, so dataclass does not treat it as a
    # field: the next sequence number to hand out.
    _next_sequence = 0

    def __post_init__(self):
        self.sequence = PrioritizedMessage._next_sequence
        PrioritizedMessage._next_sequence += 1

    def __lt__(self, other):
        """Order by priority, breaking ties by creation order (FIFO)."""
        if not isinstance(other, PrioritizedMessage):
            return NotImplemented
        return (self.priority, self.sequence) < (other.priority, other.sequence)
748
+
749
+
750
@dataclass
class EnokiInternalMessage:
    """Marker base class for messages the FSM posts to itself.

    StateMachine queues instances of this class and its subclasses at
    INTERNAL_MESSAGE priority, ahead of ordinary user messages.
    """


@dataclass
class TimeoutMessage(EnokiInternalMessage):
    """Internal message announcing that a state's timeout fired."""

    state_name: str  # name of the state whose timeout fired
    timeout_id: int  # presumably compared with the current timeout id to drop stale timeouts — TODO confirm
    timeout_duration: float  # configured timeout, in seconds
762
+
763
+
764
+ # ============================================================================
765
+ # Exceptions for StateMachine
766
+ # ============================================================================
767
+
768
class StateMachineComplete(Exception):
    """The state machine reached a terminal state and has finished."""


class BlockedInUntimedState(Exception):
    """A state that may not dwell blocked waiting without a timeout."""


class PopFromEmptyStack(Exception):
    """A Pop transition was taken while the state stack was empty."""


class NoStateToTick(Exception):
    """tick() was called while the machine has no current state."""
786
+
787
+
788
+ # ============================================================================
789
+ # StateMachine Runtime
790
+ # ============================================================================
791
+
792
+ class StateMachine:
793
+ """Asyncio-native finite state machine for executing StateSpec-based states.
794
+
795
+ The StateMachine manages state execution, transitions, message passing,
796
+ and lifecycle handling. States are defined using the @enoki.state decorator
797
+ and should define on_state, on_enter, on_leave, on_timeout, or on_fail handlers.
798
+
799
+ Args:
800
+ initial_state: The initial state (state function, StateRef, or StateSpec)
801
+ error_state: Optional error state for handling exceptions
802
+ filter_fn: Optional function to filter messages (ctx, msg) -> bool
803
+ trap_fn: Optional function to trap exceptions (exc) -> None
804
+ on_error_fn: Optional function called on errors (ctx, exc) -> None
805
+ common_data: Shared data accessible via ctx.common in all states
806
+
807
+ Attributes:
808
+ current_state: The currently executing state
809
+ context: SharedContext containing msg and common data
810
+ state_stack: Stack for Push/Pop hierarchical state management
811
+
812
+ Methods:
813
+ initialize(): Initialize the state machine (must be called before tick)
814
+ tick(timeout): Process one state machine tick
815
+ send_message(msg): Send a message to the state machine
816
+ reset(): Reset to initial state
817
+
818
+ Example:
819
+ fsm = StateMachine(
820
+ initial_state=IdleState,
821
+ error_state=ErrorState,
822
+ common_data={"counter": 0}
823
+ )
824
+ await fsm.initialize()
825
+ fsm.send_message("start")
826
+ await fsm.tick()
827
+ """
828
+
829
class MessagePriorities(IntEnum):
    """Queue priorities for StateMachine messages; lower values are dequeued first.

    Errors preempt internal messages (such as timeouts), which preempt user
    messages.
    """

    ERROR = 0  # exceptions routed through the queue
    INTERNAL_MESSAGE = 1  # EnokiInternalMessage instances
    MESSAGE = 2  # ordinary user messages
833
+
834
def __init__(
    self,
    initial_state: Union[str, StateRef, StateSpec],
    error_state: Optional[Union[str, StateRef, StateSpec]] = None,
    filter_fn: Optional[Callable] = None,
    trap_fn: Optional[Callable] = None,
    on_error_fn: Optional[Callable] = None,
    common_data: Optional[Any] = None,
):
    """Build a state machine rooted at ``initial_state``.

    Every state reachable from ``initial_state`` (and ``error_state``, when
    given) is validated up front. The machine is started later by
    ``initialize()``, or lazily by the first ``tick()``.

    Args:
        initial_state: State to start in (StateSpec, StateRef or registered name).
        error_state: Optional error state; validated together with the initial state.
        filter_fn: ``(ctx, msg)`` hook; a truthy result makes send_message drop the message.
        trap_fn: Optional ``(exc)`` hook; a no-op when omitted.
        on_error_fn: Optional ``(ctx, exc)`` hook; a no-op when omitted.
        common_data: Shared data exposed as ``ctx.common`` (a new dict when None).

    Raises:
        ValidationError: when validation reports errors. The message lists the
            errors, followed by any warnings.
    """
    self.registry = StateRegistry()

    # User hooks, each defaulting to a no-op that returns None.
    self._filter_fn = filter_fn if filter_fn else (lambda x, y: None)
    self._trap_fn = trap_fn if trap_fn else (lambda x: None)
    self._on_err_fn = on_error_fn if on_error_fn else (lambda x, y: None)

    report = self.registry.validate_all_states(initial_state, error_state)
    if not report.valid:
        message = "State machine validation failed:\n" + "\n".join(report.errors)
        if report.warnings:
            message += "\n\nWarnings:\n" + "\n".join(report.warnings)
        raise ValidationError(message)

    # Validation passed; surface any warnings through the logger.
    for warning in report.warnings:
        self.log(f"WARNING: {warning}")

    # Validation succeeded, so these lookups are expected to resolve.
    self.initial_state = self.registry.resolve_state(initial_state)
    self.error_state = (
        self.registry.resolve_state(error_state) if error_state else None
    )

    # asyncio plumbing: prioritized inbox plus timeout bookkeeping.
    self._message_queue = PriorityQueue()
    self._timeout_task = None
    self._timeout_counter = 0
    self._current_timeout_id = None

    # Runtime state.
    self.current_state: Optional[StateSpec] = None
    self.state_stack: List[StateSpec] = []
    shared_data = dict() if common_data is None else common_data
    self.context = SharedContext(
        send_message=self.send_message,
        log=self.log,
        common=shared_data,
    )

    self.retry_counters: Dict[str, int] = {}
    self.state_enter_times: Dict[str, float] = {}

    # The first reset() happens asynchronously in initialize().
    self._initialized = False
895
+
896
async def initialize(self):
    """Perform the initial reset once; subsequent calls do nothing."""
    if self._initialized:
        return
    await self.reset()
    self._initialized = True
901
+
902
def send_message(self, message: Any):
    """Queue ``message`` for the state machine.

    A truthy message is dropped when the filter function returns a truthy
    value for it; falsy messages bypass the filter. If filtering or queueing
    raises, the exception itself is queued in place of the message.
    """
    try:
        if not message or not self._filter_fn(self.context, message):
            self._send_message_internal(message)
    except Exception as exc:
        self._send_message_internal(exc)
910
+
911
def _send_message_internal(self, message: Any):
    """Enqueue ``message`` with a priority derived from its type.

    EnokiInternalMessage instances get INTERNAL_MESSAGE priority, exceptions
    get ERROR priority, and everything else gets MESSAGE priority. Any
    failure while enqueueing is swallowed.
    """
    try:
        if isinstance(message, EnokiInternalMessage):
            priority = self.MessagePriorities.INTERNAL_MESSAGE
        elif isinstance(message, Exception):
            priority = self.MessagePriorities.ERROR
        else:
            priority = self.MessagePriorities.MESSAGE
        self._message_queue.put_nowait(PrioritizedMessage(priority, message))
    except Exception:
        # Best effort, e.g. a full bounded queue (asyncio.QueueFull).
        pass
938
+
939
async def reset(self):
    """Return the machine to its initial state.

    Cancels any pending timeout, clears the push stack, the retry counters
    and the state-entry timestamps, drains queued messages, and finally
    enters the initial state.
    """
    self._cancel_timeout()
    self.state_stack, self.retry_counters, self.state_enter_times = [], {}, {}

    # Drain anything still queued; give up quietly if the queue refuses.
    queue = self._message_queue
    while not queue.empty():
        try:
            queue.get_nowait()
        except Exception:
            break

    await self._transition_to_state(self.initial_state)

def log(self, message: str):
    """Print *message* to stdout, tagged with an ``[FSM]`` prefix."""
    line = f"[FSM] {message}"
    print(line)

+ async def tick(self, timeout: Optional[Union[float, int]] = 0):
961
+ """Process one state machine tick"""
962
+
963
+ if not self._initialized:
964
+ await self.initialize()
965
+
966
+ if not self.current_state:
967
+ raise NoStateToTick()
968
+
969
+ # Validate timeout parameter
970
+ match timeout:
971
+ case x if x and not isinstance(x, (int, float)):
972
+ raise ValueError(f"Tick timeout must be None, or an int/float >= 0")
973
+ case 0:
974
+ block = False
975
+ case None:
976
+ block = True
977
+ case x if x > 0:
978
+ block = True
979
+ case _:
980
+ raise ValueError("Tick timeout must be None or >= 0")
981
+
982
+ while True:
983
+ try:
984
+ if isinstance(self.context.msg, Exception):
985
+ raise self.context.msg
986
+
987
+ # Handle timeout messages specially
988
+ if isinstance(self.context.msg, TimeoutMessage):
989
+ if not self._handle_timeout_message():
990
+ return
991
+ transition = await self._call_on_timeout(self.current_state)
992
+ else:
993
+ transition = await self._call_on_state(self.current_state)
994
+
995
+ if self.current_state.config.terminal:
996
+ raise StateMachineComplete()
997
+
998
+ if transition in (None, Unhandled):
999
+ self._trap_fn(self.context)
1000
+
1001
+ self.context.msg = None
1002
+
1003
+ should_check_queue = await self._process_transition(transition, block)
1004
+
1005
+ # In manual mode (block=False), stop after processing one transition if not checking queue
1006
+ if not block and not should_check_queue:
1007
+ break
1008
+
1009
+ if should_check_queue:
1010
+ if block:
1011
+ message = await asyncio.wait_for(
1012
+ self._message_queue.get(), timeout=timeout
1013
+ )
1014
+ message = message.item
1015
+ else:
1016
+ message = self._message_queue.get_nowait().item
1017
+ self.context.msg = message
1018
+ except asyncio.QueueEmpty:
1019
+ break
1020
+ except asyncio.TimeoutError:
1021
+ break
1022
+ except Exception as e:
1023
+ if (
1024
+ isinstance(e, StateMachineComplete)
1025
+ or self.current_state.config.terminal
1026
+ ):
1027
+ raise
1028
+
1029
+ next_transition = self._on_err_fn(self.context, e)
1030
+ if next_transition:
1031
+ await self._process_transition(next_transition, block)
1032
+ elif self.error_state and self.current_state != self.error_state:
1033
+ # Clear the exception message before transitioning
1034
+ self.context.msg = None
1035
+ await self._transition_to_state(self.error_state)
1036
+ elif self.current_state == self.error_state:
1037
+ # Already in error state and got another exception - stop
1038
+ break
1039
+ else:
1040
+ raise
1041
+
1042
+ return None
1043
+
1044
async def _call_on_state(self, state: StateSpec):
    """Run *state*'s ``on_state`` handler and return the transition it yields.

    Raises:
        ValueError: if the state has no ``on_state`` handler.
    """
    handler = state.on_state
    if handler is None:
        raise ValueError(f"State {state.name} has no on_state handler")
    # Hand the handler an interface view when its type hints request one.
    view = _create_interface_view_for_context(self.context, handler)
    return await handler(view)

async def _call_on_enter(self, state: StateSpec):
    """Run *state*'s ``on_enter`` handler, if defined, logging before and after."""
    handler = state.on_enter
    if handler is None:
        return
    self.log(f" [DEBUG] Calling on_enter for {state.name}")
    # Hand the handler an interface view when its type hints request one.
    view = _create_interface_view_for_context(self.context, handler)
    await handler(view)
    self.log(f" [DEBUG] on_enter completed for {state.name}")

async def _call_on_leave(self, state: StateSpec):
    """Run *state*'s ``on_leave`` handler when it has one; otherwise do nothing."""
    handler = state.on_leave
    if handler is None:
        return
    # Hand the handler an interface view when its type hints request one.
    view = _create_interface_view_for_context(self.context, handler)
    await handler(view)

async def _call_on_timeout(self, state: StateSpec):
    """Run *state*'s ``on_timeout`` handler and return the transition it yields.

    Raises:
        ValueError: if the state has no ``on_timeout`` handler.
    """
    handler = state.on_timeout
    if handler is None:
        raise ValueError(f"State {state.name} has no on_timeout handler")
    # Hand the handler an interface view when its type hints request one.
    view = _create_interface_view_for_context(self.context, handler)
    return await handler(view)

async def _call_on_fail(self, state: StateSpec):
    """Run *state*'s ``on_fail`` handler and return the transition it yields.

    Raises:
        ValueError: if the state has no ``on_fail`` handler.
    """
    handler = state.on_fail
    if handler is None:
        raise ValueError(f"State {state.name} has no on_fail handler")
    # Hand the handler an interface view when its type hints request one.
    view = _create_interface_view_for_context(self.context, handler)
    return await handler(view)

async def _process_transition(self, transition: Any, block: bool = False) -> bool:
    """Apply *transition* to the current state.

    Args:
        transition: value returned by a state handler. It may be a key of the
            state's transition table, one of the table's target values, or
            ``None`` / ``Unhandled`` meaning "no transition".
        block: whether the caller ticks in blocking mode; only affects the
            result for ``Again``.

    Returns:
        True when the caller should go on to read the message queue (the
        state dwells, ``Restart``, or ``Again`` in blocking mode), else False.

    Raises:
        BlockedInUntimedState: if no transition was returned but the state can
            neither dwell nor time out.
        ValueError: if *transition* is unknown to the current state.
    """
    state = self.current_state
    valid_transitions = self.registry.get_transitions(state)

    # "No transition" must be handled before the table lookup. Otherwise an
    # ``Unhandled`` that is not listed in the table is reported as an unknown
    # transition and never reaches the dwell/timeout check.
    if transition in (None, Unhandled):
        if state.config.can_dwell or state.config.timeout:
            return True
        raise BlockedInUntimedState(
            f"{state.name} cannot dwell and does not have a timeout"
        )

    # Look the transition up as an event key first, then accept it directly
    # if it is itself one of the declared targets.
    target = valid_transitions.get(transition)
    if target is None:
        if transition not in valid_transitions.values():
            raise ValueError(f"Unknown transition {transition}: {valid_transitions}")
        target = transition

    if isinstance(target, StateSpec):
        await self._transition_to_state(target)
    elif isinstance(target, Push):
        await self._handle_push(target)
    elif isinstance(target, Pop) or target is Pop:
        await self._handle_pop()
    elif target == Again:
        # Re-execute the current state: keep looping in automatic (blocking)
        # mode, stop after this execution in manual mode.
        return block
    elif target == Repeat:
        await self._transition_to_state(state)
    elif target == Restart:
        self._reset_state_context()
        await self._transition_to_state(state)
        return True
    elif target == Retry:
        await self._handle_retry()

    return False

async def _transition_to_state(self, next_state: StateSpec):
    """Leave the current state (if any) and enter *next_state*.

    The outgoing state's ``on_leave`` handler runs and its timeout is
    cancelled. Then *next_state* becomes current, its entry time is
    recorded, its timeout is armed and its ``on_enter`` handler runs.
    Debug lines are logged throughout.
    """
    previous = self.current_state
    if previous:
        await self._call_on_leave(previous)
        self._cancel_timeout()

    self.current_state = next_state
    self.state_enter_times[next_state.name] = time.time()

    self.log(f"Transitioned to {next_state.name}")
    self.log(f" [DEBUG] common object id: {id(self.context.common)}")
    self.log(f" [DEBUG] on_enter is None? {next_state.on_enter is None}")

    # The timeout is armed before on_enter runs, so on_enter time counts.
    self._start_timeout(next_state)
    await self._call_on_enter(next_state)
    self.log(f" [DEBUG] After on_enter, common: {self.context.common}")

def _start_timeout(self, state: StateSpec) -> Optional[int]:
    """Arm a background timeout task for *state* if its config asks for one.

    Any previously armed timeout is cancelled first. Each armed timeout gets
    a fresh id so stale timeout messages can be recognised later.

    Returns:
        The new timeout id, or None when the state has no timeout.
    """
    if state.config.timeout is None:
        return None

    self._cancel_timeout()

    self._timeout_counter += 1
    timeout_id = self._current_timeout_id = self._timeout_counter

    async def _fire_after_delay():
        # Sleep for the configured duration, then enqueue a TimeoutMessage;
        # cancellation (state left early) simply ends the task.
        try:
            await asyncio.sleep(state.config.timeout)
            self._send_timeout_message(state.name, timeout_id, state.config.timeout)
        except asyncio.CancelledError:
            pass

    self._timeout_task = asyncio.create_task(_fire_after_delay())
    return timeout_id

def _cancel_timeout(self):
    """Cancel the pending timeout task, if any, and forget the active timeout id."""
    task, self._timeout_task = self._timeout_task, None
    if task is not None:
        task.cancel()
    self._current_timeout_id = None

def _send_timeout_message(self, state_name: str, timeout_id: int, duration: float):
    """Enqueue a ``TimeoutMessage`` for *state_name*, bypassing the user filter."""
    self._send_message_internal(TimeoutMessage(state_name, timeout_id, duration))

def _handle_timeout_message(self) -> bool:
    """Consume the pending timeout message and report whether it is still live.

    The message is always removed from the context. It is ignored (False)
    when its id is not the currently armed timeout, or when it targets a
    state other than the current one. Otherwise the timeout is cancelled
    and True is returned so the caller runs ``on_timeout``.
    """
    timeout_msg, self.context.msg = self.context.msg, None

    if timeout_msg.timeout_id != self._current_timeout_id:
        self.log(f"Ignoring stale timeout {timeout_msg.timeout_id}")
        return False

    current_name = self.current_state.name
    if timeout_msg.state_name != current_name:
        self.log(
            f"Ignoring timeout for {timeout_msg.state_name}, current state is {current_name}"
        )
        return False

    self.log(
        f"Handling timeout for {timeout_msg.state_name} ({timeout_msg.timeout_duration}s)"
    )
    self._cancel_timeout()
    return True

async def _handle_retry(self):
    """Re-enter the current state, or run ``on_fail`` once retries are exhausted.

    Increments the current state's retry counter. When the state's configured
    ``retries`` limit is set and exceeded, the ``on_fail`` handler is called
    and its transition processed; otherwise the state is re-entered.
    """
    state = self.current_state
    attempts = self.retry_counters.get(state.name, 0) + 1
    self.retry_counters[state.name] = attempts

    limit = state.config.retries
    if limit is not None and attempts > limit:
        self.log(f"Retry limit exceeded for {state.name}")
        transition = await self._call_on_fail(state)
        await self._process_transition(transition)
        return

    await self._transition_to_state(state)

def _reset_state_context(self):
    """Forget the retry counter kept for the current state (no-op if absent)."""
    self.retry_counters.pop(self.current_state.name, None)

async def _handle_push(self, push: Push):
    """Enter the first state named by *push* and stack the rest for later pops.

    The remaining states are pushed in reverse order, so successive ``Pop``
    transitions visit them in the order they were listed.

    Raises:
        ValueError: if *push* names no states. Previously this surfaced as a
            bare ``IndexError`` on an empty list.
    """
    resolved_states = [self.registry.resolve_state(s) for s in push.push_states]
    if not resolved_states:
        raise ValueError("Push transition must name at least one state")

    first, *rest = resolved_states
    self.state_stack.extend(reversed(rest))

    await self._transition_to_state(first)

async def _handle_pop(self):
    """Return to the most recently pushed state.

    Raises:
        PopFromEmptyStack: if no state is waiting on the push stack.
    """
    stack = self.state_stack
    if not stack:
        raise PopFromEmptyStack()
    await self._transition_to_state(stack.pop())

async def run(
    self, max_iterations: Optional[int] = None, timeout: Optional[float] = 1
):
    """Tick the machine repeatedly until a terminal state is current.

    Args:
        max_iterations: stop after this many successful ticks (None = no limit).
        timeout: forwarded to ``tick()`` on every iteration.

    Raises:
        StateMachineComplete: propagated unchanged from ``tick()``.
        Exception: any other tick failure when no error state is configured.
    """
    if not self._initialized:
        await self.initialize()

    completed = 0
    while True:
        if max_iterations is not None and completed >= max_iterations:
            break

        state = self.current_state
        if state.config.terminal:
            self.log(f"Reached terminal state: {state.name}")
            break

        try:
            await self.tick(timeout=timeout)
        except StateMachineComplete:
            raise
        except Exception as e:
            self.log(f"Error in FSM: {e}")
            if not self.error_state:
                raise
            await self._transition_to_state(self.error_state)
        else:
            # Only successful ticks count towards max_iterations.
            completed += 1

# ============================================================================
# Decorators
# ============================================================================

class _EnokiDecoratorAPI:
    """Decorator API for defining states.

    ``state()`` executes the decorated function once. Handler decorators used
    inside that function (``on_state``, ``on_enter``, ...) tag the handler and
    report it to the innermost active tracking list, from which ``state()``
    assembles and registers a ``StateSpec``.
    """

    # (marker attribute, StateSpec field) pairs. A tracked object carrying
    # several markers is assigned to the first matching field in this order.
    _HANDLER_MARKERS = (
        ("_enoki_on_state", "on_state"),
        ("_enoki_on_enter", "on_enter"),
        ("_enoki_on_leave", "on_leave"),
        ("_enoki_on_timeout", "on_timeout"),
        ("_enoki_on_fail", "on_fail"),
        ("_enoki_transitions", "transitions"),
        ("_enoki_events", "Events"),
    )

    def __init__(self):
        # One list per state() body currently executing (nesting is allowed);
        # handler decorators append (name, object) pairs to the innermost one.
        self._tracking_stack: List[List[Tuple[str, Any]]] = []

    def _track(self, item_name: str, item: Any) -> None:
        """Record *item* with the innermost state definition being executed, if any."""
        if self._tracking_stack:
            self._tracking_stack[-1].append((item_name, item))

    def _mark(self, func: Callable, marker: str) -> Callable:
        """Tag *func* with *marker* and report it to the enclosing state definition."""
        setattr(func, marker, func)
        self._track(func.__name__, func)
        return func

    def state(
        self,
        config: Optional[StateConfiguration] = None,
        name: Optional[str] = None,
        group: Optional[str] = None,
    ) -> Callable[[Callable], StateSpec]:
        """
        Decorator to define a state.

        Args:
            config: StateConfiguration for this state. When omitted (or None)
                a fresh default ``StateConfiguration()`` is created for this
                state, so states never share one mutable default instance.
            name: Optional name (auto-generated as "<module>.<qualname>" if
                not provided; scripts run as __main__ use their file name)
            group: Optional group name for organization

        The decorated function should contain nested decorated functions
        (@on_state, @on_enter, @transitions, etc.). No need to return a dict!

        Raises (from the returned decorator):
            ValueError: if the body defines no @on_state handler.
        """
        state_config = config if config is not None else StateConfiguration()

        def decorator(func: Callable[..., Any]) -> StateSpec:
            module = func.__module__
            qualname = func.__qualname__

            if name is None:
                # Scripts run directly report "__main__"; use the file name
                # instead so generated state names stay meaningful.
                if module == "__main__":
                    filename = inspect.getfile(func)
                    module = os.path.splitext(os.path.basename(filename))[0]
                state_name = f"{module}.{qualname}"
            else:
                state_name = name

            # Execute the body so the nested decorators can register
            # themselves; always unwind the tracking stack.
            tracked_items: List[Tuple[str, Any]] = []
            self._tracking_stack.append(tracked_items)
            try:
                func()
            finally:
                self._tracking_stack.pop()

            # Later definitions of the same handler kind win.
            handlers = dict.fromkeys(field for _, field in self._HANDLER_MARKERS)
            for _item_name, item in tracked_items:
                for marker, field in self._HANDLER_MARKERS:
                    if hasattr(item, marker):
                        handlers[field] = item
                        break

            if handlers["on_state"] is None:
                raise ValueError(f"State '{state_name}' must have an @on_state decorated method")

            state_spec = StateSpec(
                name=state_name,
                qualname=qualname,
                module=module,
                config=state_config,
                group_name=group or "",
                **handlers,
            )

            register_state(state_spec)
            return state_spec

        return decorator

    def on_state(self, func: Callable) -> Callable:
        """Decorator for the main state logic method"""
        return self._mark(func, "_enoki_on_state")

    def on_enter(self, func: Callable) -> Callable:
        """Decorator for on_enter lifecycle method"""
        return self._mark(func, "_enoki_on_enter")

    def on_leave(self, func: Callable) -> Callable:
        """Decorator for on_leave lifecycle method"""
        return self._mark(func, "_enoki_on_leave")

    def on_timeout(self, func: Callable) -> Callable:
        """Decorator for on_timeout lifecycle method"""
        return self._mark(func, "_enoki_on_timeout")

    def on_fail(self, func: Callable) -> Callable:
        """Decorator for on_fail lifecycle method"""
        return self._mark(func, "_enoki_on_fail")

    def transitions(self, func: Callable) -> Callable:
        """Decorator for declaring transitions"""
        return self._mark(func, "_enoki_transitions")

    def events(self, events_class: type) -> type:
        """
        Decorator for registering the Events enum.

        Optional decorator that makes the Events enum accessible via state.Events
        for introspection and testing. The Events enum works via closure even
        without this decorator, but state.Events will be None.

        Use this decorator when:
        - You need to access state.Events programmatically (e.g., in tests)
        - You want to inspect what events a state defines
        - Registry methods need to reference the Events enum

        Skip this decorator when:
        - You only use Events within on_state/transitions (closure handles it)
        - You don't need external access to the Events enum

        Example:
            @enoki.state()
            def MyState():
                @enoki.events  # Optional - makes MyState.Events accessible
                class Events(Enum):
                    GO = auto()

                @enoki.on_state
                async def on_state(ctx):
                    return Events.GO  # Works via closure even without @enoki.events
        """
        events_class._enoki_events = events_class
        # Tracked under the fixed name "Events" rather than the class name.
        self._track("Events", events_class)
        return events_class

    def root(self, func: Callable) -> Callable:
        """Decorator to mark a state as the initial/root state"""
        func._enoki_is_root = True
        return func

# Create the decorator API instance shared by every module-level decorator.
enoki = _EnokiDecoratorAPI()

# Export decorators for convenience (bound methods of the shared instance,
# so `@state()` and `@enoki.state()` are interchangeable).
state = enoki.state
on_state = enoki.on_state
on_enter = enoki.on_enter
on_leave = enoki.on_leave
on_timeout = enoki.on_timeout
on_fail = enoki.on_fail
transitions = enoki.transitions
events = enoki.events
root = enoki.root


# ============================================================================
# Export all public components
# ============================================================================

__all__ = [
    # Decorators
    "enoki",
    "state",
    "on_state",
    "on_enter",
    "on_leave",
    "on_timeout",
    "on_fail",
    "transitions",
    "events",  # exported alias above; was missing from __all__
    "root",
    # Core types
    "StateSpec",
    "StateConfiguration",
    "SharedContext",
    # Transitions
    "TransitionType",
    "StateTransition",
    "StateContinuation",
    "StateRenewal",
    "Again",
    "Unhandled",
    "Retry",
    "Restart",
    "Repeat",
    "StateRef",
    "LabeledTransition",
    "Push",
    "Pop",
    # Registry
    "register_state",
    "get_state",
    "get_all_states",
    "StateRegistry",
    "ValidationError",
    "ValidationResult",
    # StateMachine
    "StateMachine",
    "StateMachineComplete",
    "BlockedInUntimedState",
    "PopFromEmptyStack",
    "NoStateToTick",
    # Messages
    "PrioritizedMessage",
    "EnokiInternalMessage",
    "TimeoutMessage",
]