flatagents 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1176 @@
1
+ """
2
+ FlatMachine - State machine orchestration for FlatAgents.
3
+
4
+ A machine defines how agents are connected and executed:
5
+ states, transitions, conditions, and loops.
6
+
7
+ See local/flatmachines-plan.md for the full specification.
8
+ """
9
+
10
+ import asyncio
11
+ import importlib
12
+ import json
13
+ import os
14
+ import re
15
+ from typing import Any, Dict, Optional
16
+
17
+ try:
18
+ from jinja2 import Template
19
+ except ImportError:
20
+ Template = None
21
+
22
+ try:
23
+ import yaml
24
+ except ImportError:
25
+ yaml = None
26
+
27
+ from .monitoring import get_logger
28
+ from .expressions import get_expression_engine, ExpressionEngine
29
+ from .execution import get_execution_type, ExecutionType
30
+ from .hooks import MachineHooks, LoggingHooks
31
+ from .flatagent import FlatAgent
32
+
33
+ import uuid
34
+ from .persistence import (
35
+ PersistenceBackend,
36
+ LocalFileBackend,
37
+ MemoryBackend,
38
+ CheckpointManager,
39
+ MachineSnapshot
40
+ )
41
+ from .backends import (
42
+ ResultBackend,
43
+ InMemoryResultBackend,
44
+ LaunchIntent,
45
+ make_uri,
46
+ get_default_result_backend,
47
+ )
48
+ from .locking import ExecutionLock, LocalFileLock, NoOpLock
49
+ from .actions import (
50
+ Action,
51
+ HookAction,
52
+ MachineInvoker,
53
+ InlineInvoker,
54
+ QueueInvoker,
55
+ )
56
+
57
+ logger = get_logger(__name__)
58
+
59
+
60
+ class FlatMachine:
61
+ """
62
+ State machine orchestration for FlatAgents.
63
+
64
+ Executes a sequence of states, evaluating transitions and
65
+ managing context flow between agents.
66
+
67
+ Supports:
68
+ - Persistence (checkpoint/resume)
69
+ - Concurrency control (locking)
70
+ - Machine launching (peer machine execution)
71
+ """
72
+
73
+ SPEC_VERSION = "0.4.0"
74
+
75
    def __init__(
        self,
        config_file: Optional[str] = None,
        config_dict: Optional[Dict] = None,
        hooks: Optional[MachineHooks] = None,
        persistence: Optional[PersistenceBackend] = None,
        lock: Optional[ExecutionLock] = None,
        invoker: Optional[MachineInvoker] = None,
        result_backend: Optional[ResultBackend] = None,
        **kwargs
    ):
        """
        Initialize the machine.

        Args:
            config_file: Path to YAML/JSON config file
            config_dict: Configuration dictionary
            hooks: Custom hooks for extensibility (overrides config 'hooks:')
            persistence: Storage backend (overrides config)
            lock: Concurrency lock (overrides config)
            invoker: Strategy for invoking other machines (default: inline)
            result_backend: Backend for inter-machine result storage
            **kwargs: Override config values; keys prefixed with '_' are
                internal and are set by a parent machine when launching a child

        Raises:
            ImportError: jinja2 is not installed.
            ValueError / FileNotFoundError: invalid or missing configuration.
        """
        if Template is None:
            raise ImportError("jinja2 is required. Install with: pip install jinja2")

        # Extract execution identity first: launched machines inherit an
        # execution_id / parent id from their launcher; stand-alone machines
        # mint a fresh UUID.
        self.execution_id = kwargs.pop('_execution_id', None) or str(uuid.uuid4())
        self.parent_execution_id = kwargs.pop('_parent_execution_id', None)

        # Pop the _config_dir override BEFORE the kwargs merge below so the
        # internal key does not leak into the machine's data section.
        config_dir_override = kwargs.pop('_config_dir', None)

        self._load_config(config_file, config_dict)

        # Allow launcher to override config_dir for launched machines
        # (children created from config_dict would otherwise resolve relative
        # paths against the parent's directory).
        if config_dir_override:
            self._config_dir = config_dir_override

        # Merge remaining kwargs into config data (shallow merge)
        if kwargs and 'data' in self.config:
            self.config['data'].update(kwargs)

        self._validate_spec()
        self._parse_machine_config()

        # Set up Jinja2 environment with custom filters
        from jinja2 import Environment
        import json

        def _json_finalize(value):
            """Auto-serialize lists and dicts to JSON in Jinja2 output.

            This ensures {{ output.items }} renders as ["a", "b"] (valid JSON)
            instead of ['a', 'b'] (Python repr), allowing json.loads() to work.
            """
            if isinstance(value, (list, dict)):
                return json.dumps(value)
            return value

        self._jinja_env = Environment(finalize=_json_finalize)
        # Add fromjson filter for parsing JSON strings in templates
        # Usage: {% for item in context.items | fromjson %}
        self._jinja_env.filters['fromjson'] = json.loads

        # Expression engine for transition conditions (selectable via the
        # 'expression_engine' key in data; defaults to the simple engine).
        expression_mode = self.data.get("expression_engine", "simple")
        self._expression_engine = get_expression_engine(expression_mode)

        # Hooks - load from config or use provided/default
        self._hooks = self._load_hooks(hooks)

        # Agent cache (name -> FlatAgent), populated lazily by _get_agent
        self._agents: Dict[str, FlatAgent] = {}

        # Execution tracking, aggregated across agent calls
        self.total_api_calls = 0
        self.total_cost = 0.0

        # Persistence & Locking
        self._initialize_persistence(persistence, lock)

        # Result backend for inter-machine communication
        self.result_backend = result_backend or get_default_result_backend()

        # Pending launches (outbox pattern): intents are recorded before a
        # launch so interrupted launches can be retried on resume.
        self._pending_launches: list[LaunchIntent] = []

        # Background tasks for fire-and-forget launches; strong references
        # are kept here so asyncio does not garbage-collect running tasks.
        self._background_tasks: set[asyncio.Task] = set()

        # Invoker (for launching peer machines)
        self.invoker = invoker or InlineInvoker()

        logger.info(f"Initialized FlatMachine: {self.machine_name} (ID: {self.execution_id})")
171
+
172
    def _initialize_persistence(
        self,
        persistence: Optional[PersistenceBackend],
        lock: Optional[ExecutionLock]
    ) -> None:
        """Initialize persistence and locking components.

        Resolution order for both the backend and the lock:
        explicit constructor argument > persistence disabled in config >
        backend type from config ('local' or 'memory', default 'memory').

        A MemoryBackend is still installed when persistence is disabled so
        callers get a uniform API; nothing is durable in that case.
        """
        # Per-machine persistence settings from data.persistence in the config.
        p_config = self.data.get('persistence', {})

        # Persistence defaults to enabled with an in-process memory backend
        # unless the config says otherwise.
        enabled = p_config.get('enabled', True)

        backend_type = p_config.get('backend', 'memory')

        # Persistence Backend
        if persistence:
            self.persistence = persistence
        elif not enabled:
            self.persistence = MemoryBackend()  # API-compatible, non-durable
        elif backend_type == 'local':
            self.persistence = LocalFileBackend()
        elif backend_type == 'memory':
            self.persistence = MemoryBackend()
        else:
            logger.warning(f"Unknown backend '{backend_type}', using memory")
            self.persistence = MemoryBackend()

        # Lock: only the 'local' backend has a real lock implementation;
        # every other configuration gets a no-op lock (no mutual exclusion).
        if lock:
            self.lock = lock
        elif not enabled:
            self.lock = NoOpLock()
        elif backend_type == 'local':
            self.lock = LocalFileLock()
        else:
            self.lock = NoOpLock()

        # Lifecycle events on which a snapshot is taken; config
        # 'checkpoint_on' may narrow or extend this default set.
        default_events = ['machine_start', 'state_enter', 'execute', 'state_exit', 'machine_end']
        self.checkpoint_events = set(p_config.get('checkpoint_on', default_events))
217
+
218
+
219
+ def _load_config(
220
+ self,
221
+ config_file: Optional[str],
222
+ config_dict: Optional[Dict]
223
+ ) -> None:
224
+ """Load configuration from file or dict."""
225
+ config = {}
226
+
227
+ if config_file is not None:
228
+ if not os.path.exists(config_file):
229
+ raise FileNotFoundError(f"Config file not found: {config_file}")
230
+
231
+ with open(config_file, 'r') as f:
232
+ if config_file.endswith('.json'):
233
+ config = json.load(f) or {}
234
+ else:
235
+ if yaml is None:
236
+ raise ImportError("pyyaml required for YAML files")
237
+ config = yaml.safe_load(f) or {}
238
+
239
+ # Store config file path for relative agent references
240
+ self._config_dir = os.path.dirname(os.path.abspath(config_file))
241
+ elif config_dict is not None:
242
+ config = config_dict
243
+ self._config_dir = os.getcwd()
244
+ else:
245
+ raise ValueError("Must provide config_file or config_dict")
246
+
247
+ self.config = config
248
+
249
+ def _validate_spec(self) -> None:
250
+ """Validate the spec envelope."""
251
+ spec = self.config.get('spec')
252
+ if spec != 'flatmachine':
253
+ raise ValueError(
254
+ f"Invalid spec: expected 'flatmachine', got '{spec}'"
255
+ )
256
+
257
+ if 'data' not in self.config:
258
+ raise ValueError("Config missing 'data' section")
259
+
260
+ # Version check with warning
261
+ self.spec_version = self.config.get('spec_version', '0.1.0')
262
+ major_minor = '.'.join(self.spec_version.split('.')[:2])
263
+ if major_minor not in ['0.1']:
264
+ logger.warning(
265
+ f"Config version {self.spec_version} may not be fully supported. "
266
+ f"Current SDK supports 0.1.x."
267
+ )
268
+
269
+ # Schema validation (warnings only, non-blocking)
270
+ try:
271
+ from .validation import validate_flatmachine_config
272
+ validate_flatmachine_config(self.config, warn=True, strict=False)
273
+ except ImportError:
274
+ pass # jsonschema not installed, skip validation
275
+
276
+ def _parse_machine_config(self) -> None:
277
+ """Parse the machine configuration."""
278
+ self.data = self.config['data']
279
+ self.metadata = self.config.get('metadata', {})
280
+
281
+ self.machine_name = self.data.get('name', 'unnamed-machine')
282
+ self.initial_context = self.data.get('context', {})
283
+ self.agent_refs = self.data.get('agents', {})
284
+ self.machine_refs = self.data.get('machines', {})
285
+ self.states = self.data.get('states', {})
286
+ self.settings = self.data.get('settings', {})
287
+
288
+ # Find initial and final states
289
+ self._initial_state = None
290
+ self._final_states = set()
291
+
292
+ for name, state in self.states.items():
293
+ if state.get('type') == 'initial':
294
+ if self._initial_state is not None:
295
+ raise ValueError("Multiple initial states defined")
296
+ self._initial_state = name
297
+ if state.get('type') == 'final':
298
+ self._final_states.add(name)
299
+
300
+ if self._initial_state is None:
301
+ # Default to 'start' if exists, otherwise first state
302
+ if 'start' in self.states:
303
+ self._initial_state = 'start'
304
+ elif self.states:
305
+ self._initial_state = next(iter(self.states))
306
+ else:
307
+ raise ValueError("No states defined")
308
+
309
+ def _load_hooks(self, hooks: Optional[MachineHooks]) -> MachineHooks:
310
+ """
311
+ Load hooks from config or use provided/default.
312
+
313
+ Config format (file-based, preferred for self-contained skills):
314
+ hooks:
315
+ file: "./hooks.py"
316
+ class: "MyHooks"
317
+ args:
318
+ working_dir: "."
319
+
320
+ Config format (module-based, for installed packages):
321
+ hooks:
322
+ module: "mypackage.hooks"
323
+ class: "MyHooks"
324
+ args:
325
+ working_dir: "{{ input.working_dir }}"
326
+
327
+ Priority:
328
+ 1. Explicitly passed hooks argument (for programmatic use)
329
+ 2. file: in config (file-based loading)
330
+ 3. module: in config (Python import)
331
+ 4. Default MachineHooks()
332
+ """
333
+ # If hooks explicitly passed, use them
334
+ if hooks is not None:
335
+ return hooks
336
+
337
+ # Check for hooks config
338
+ hooks_config = self.data.get('hooks')
339
+ if not hooks_config:
340
+ return MachineHooks()
341
+
342
+ class_name = hooks_config.get('class')
343
+ if not class_name:
344
+ logger.warning(
345
+ f"Hooks config missing 'class', using default. Config: {hooks_config}"
346
+ )
347
+ return MachineHooks()
348
+
349
+ hooks_class = None
350
+
351
+ # Try file-based loading first (for self-contained skills)
352
+ file_path = hooks_config.get('file')
353
+ if file_path:
354
+ if not os.path.isabs(file_path):
355
+ file_path = os.path.join(self._config_dir, file_path)
356
+ try:
357
+ spec = importlib.util.spec_from_file_location("hooks", file_path)
358
+ if spec and spec.loader:
359
+ module = importlib.util.module_from_spec(spec)
360
+ spec.loader.exec_module(module)
361
+ hooks_class = getattr(module, class_name)
362
+ except Exception as e:
363
+ logger.error(f"Failed to load hooks from file {file_path}: {e}")
364
+
365
+ # Fall back to module import
366
+ if hooks_class is None:
367
+ module_name = hooks_config.get('module')
368
+ if not module_name:
369
+ logger.warning(
370
+ f"Hooks config has no 'file' or 'module', using default. "
371
+ f"Config: {hooks_config}"
372
+ )
373
+ return MachineHooks()
374
+ try:
375
+ module = importlib.import_module(module_name)
376
+ hooks_class = getattr(module, class_name)
377
+ except (ImportError, AttributeError) as e:
378
+ logger.error(
379
+ f"Failed to load hooks class {class_name} from {module_name}: {e}"
380
+ )
381
+ return MachineHooks()
382
+
383
+ # Get args (note: can't render templates here as input not yet available)
384
+ # Args are passed raw - the hooks class should handle any needed resolution
385
+ hooks_args = hooks_config.get('args', {})
386
+
387
+ try:
388
+ return hooks_class(**hooks_args)
389
+ except Exception as e:
390
+ logger.error(f"Failed to instantiate hooks class {class_name}: {e}")
391
+ return MachineHooks()
392
+
393
+ def _get_agent(self, agent_name: str) -> FlatAgent:
394
+ """Get or load an agent by name."""
395
+ if agent_name in self._agents:
396
+ return self._agents[agent_name]
397
+
398
+ if agent_name not in self.agent_refs:
399
+ raise ValueError(f"Unknown agent: {agent_name}")
400
+
401
+ agent_ref = self.agent_refs[agent_name]
402
+
403
+ # Handle file path reference
404
+ if isinstance(agent_ref, str):
405
+ if not os.path.isabs(agent_ref):
406
+ agent_ref = os.path.join(self._config_dir, agent_ref)
407
+ agent = FlatAgent(config_file=agent_ref)
408
+ # Handle inline config (dict)
409
+ elif isinstance(agent_ref, dict):
410
+ agent = FlatAgent(config_dict=agent_ref)
411
+ else:
412
+ raise ValueError(f"Invalid agent reference: {agent_ref}")
413
+
414
+ self._agents[agent_name] = agent
415
+ return agent
416
+
417
+ # Pattern for simple path references: output.foo, context.bar.baz, input.x
418
+ _PATH_PATTERN = re.compile(r'^(output|context|input)(\.[a-zA-Z_][a-zA-Z0-9_]*)+$')
419
+
420
+ def _resolve_path(self, path: str, variables: Dict[str, Any]) -> Any:
421
+ """Resolve a dotted path like 'output.chapters' to its value."""
422
+ parts = path.split('.')
423
+ value = variables
424
+ for part in parts:
425
+ if isinstance(value, dict):
426
+ value = value.get(part)
427
+ else:
428
+ return None
429
+ return value
430
+
431
+ def _render_template(self, template_str: str, variables: Dict[str, Any]) -> Any:
432
+ """Render a Jinja2 template string or resolve a simple path reference."""
433
+ if not isinstance(template_str, str):
434
+ return template_str
435
+
436
+ # Check if it's a template ({{ for expressions, {% for control flow)
437
+ if '{{' not in template_str and '{%' not in template_str:
438
+ # Check if it's a simple path reference like "output.chapters"
439
+ # This allows direct value passing without Jinja2 string conversion
440
+ if self._PATH_PATTERN.match(template_str.strip()):
441
+ return self._resolve_path(template_str.strip(), variables)
442
+ return template_str
443
+
444
+ template = self._jinja_env.from_string(template_str)
445
+ result = template.render(**variables)
446
+
447
+ # Try to parse as JSON for complex types
448
+ try:
449
+ return json.loads(result)
450
+ except (json.JSONDecodeError, TypeError):
451
+ return result
452
+
453
+ def _render_dict(self, data: Dict[str, Any], variables: Dict[str, Any]) -> Dict[str, Any]:
454
+ """Recursively render all template strings in a dict."""
455
+ result = {}
456
+ for key, value in data.items():
457
+ if isinstance(value, str):
458
+ result[key] = self._render_template(value, variables)
459
+ elif isinstance(value, dict):
460
+ result[key] = self._render_dict(value, variables)
461
+ elif isinstance(value, list):
462
+ result[key] = [
463
+ self._render_template(v, variables) if isinstance(v, str) else v
464
+ for v in value
465
+ ]
466
+ else:
467
+ result[key] = value
468
+ return result
469
+
470
+ def _evaluate_condition(self, condition: str, context: Dict[str, Any]) -> bool:
471
+ """Evaluate a transition condition."""
472
+ variables = {"context": context}
473
+ return bool(self._expression_engine.evaluate(condition, variables))
474
+
475
+ def _get_error_recovery_state(
476
+ self,
477
+ state_config: Dict[str, Any],
478
+ error: Exception
479
+ ) -> Optional[str]:
480
+ """
481
+ Get recovery state from on_error config.
482
+
483
+ Supports two formats:
484
+ - Simple: on_error: "error_state"
485
+ - Granular: on_error: {default: "error_state", RateLimitError: "retry_state"}
486
+ """
487
+ on_error = state_config.get('on_error')
488
+ if not on_error:
489
+ return None
490
+
491
+ # Simple format: on_error: "state_name"
492
+ if isinstance(on_error, str):
493
+ return on_error
494
+
495
+ # Granular format: on_error: {error_type: state_name, default: fallback}
496
+ error_type = type(error).__name__
497
+ return on_error.get(error_type) or on_error.get('default')
498
+
499
+ def _find_next_state(
500
+ self,
501
+ state_name: str,
502
+ context: Dict[str, Any]
503
+ ) -> Optional[str]:
504
+ """Find the next state based on transitions."""
505
+ state = self.states.get(state_name, {})
506
+ transitions = state.get('transitions', [])
507
+
508
+ for transition in transitions:
509
+ condition = transition.get('condition', '')
510
+ to_state = transition.get('to')
511
+
512
+ if not to_state:
513
+ continue
514
+
515
+ # No condition = default transition
516
+ if not condition:
517
+ return to_state
518
+
519
+ # Evaluate condition
520
+ if self._evaluate_condition(condition, context):
521
+ return to_state
522
+
523
+ return None
524
+
525
+ def _resolve_config(self, name: str) -> Dict[str, Any]:
526
+ """Resolve a component reference (agent/machine) to a config dict."""
527
+ ref = self.agent_refs.get(name)
528
+ if not ref:
529
+ raise ValueError(f"Unknown component reference: {name}")
530
+
531
+ if isinstance(ref, dict):
532
+ return ref
533
+
534
+ if isinstance(ref, str):
535
+ path = ref
536
+ if not os.path.isabs(path):
537
+ path = os.path.join(self._config_dir, path)
538
+
539
+ if not os.path.exists(path):
540
+ raise FileNotFoundError(f"Component file not found: {path}")
541
+
542
+ with open(path, 'r') as f:
543
+ if path.endswith('.json'):
544
+ return json.load(f) or {}
545
+ # Assume yaml
546
+ if yaml:
547
+ return yaml.safe_load(f) or {}
548
+ raise ImportError("pyyaml required for YAML files")
549
+
550
+ raise ValueError(f"Invalid reference type: {type(ref)}")
551
+
552
+ def _resolve_machine_config(self, name: str) -> tuple[Dict[str, Any], str]:
553
+ """
554
+ Resolve a machine reference to a config dict and its config directory.
555
+
556
+ Returns:
557
+ Tuple of (config_dict, config_dir) where config_dir is the directory
558
+ containing the machine config file (for resolving relative paths).
559
+ """
560
+ ref = self.machine_refs.get(name)
561
+ if not ref:
562
+ raise ValueError(f"Unknown machine reference: {name}. Check 'machines:' section in config.")
563
+
564
+ if isinstance(ref, dict):
565
+ # Inline config - use parent's config_dir
566
+ return ref, self._config_dir
567
+
568
+ if isinstance(ref, str):
569
+ path = ref
570
+ if not os.path.isabs(path):
571
+ path = os.path.join(self._config_dir, path)
572
+
573
+ if not os.path.exists(path):
574
+ raise FileNotFoundError(f"Machine config file not found: {path}")
575
+
576
+ # The peer's config_dir is the directory containing its config file
577
+ peer_config_dir = os.path.dirname(os.path.abspath(path))
578
+
579
+ with open(path, 'r') as f:
580
+ if path.endswith('.json'):
581
+ config = json.load(f) or {}
582
+ elif yaml:
583
+ config = yaml.safe_load(f) or {}
584
+ else:
585
+ raise ImportError("pyyaml required for YAML files")
586
+
587
+ return config, peer_config_dir
588
+
589
+ raise ValueError(f"Invalid machine reference type: {type(ref)}")
590
+
591
+ # =========================================================================
592
+ # Pending Launches (Outbox Pattern) - v0.4.0
593
+ # =========================================================================
594
+
595
+ def _add_pending_launch(
596
+ self,
597
+ execution_id: str,
598
+ machine: str,
599
+ input_data: Dict[str, Any]
600
+ ) -> LaunchIntent:
601
+ """Add a launch intent to the pending list (outbox pattern)."""
602
+ intent = LaunchIntent(
603
+ execution_id=execution_id,
604
+ machine=machine,
605
+ input=input_data,
606
+ launched=False
607
+ )
608
+ self._pending_launches.append(intent)
609
+ return intent
610
+
611
+ def _mark_launched(self, execution_id: str) -> None:
612
+ """Mark a pending launch as launched."""
613
+ for intent in self._pending_launches:
614
+ if intent.execution_id == execution_id:
615
+ intent.launched = True
616
+ break
617
+
618
+ def _clear_pending_launch(self, execution_id: str) -> None:
619
+ """Remove a completed launch from pending list."""
620
+ self._pending_launches = [
621
+ i for i in self._pending_launches
622
+ if i.execution_id != execution_id
623
+ ]
624
+
625
+ def _get_pending_intents(self) -> list[Dict[str, Any]]:
626
+ """Get pending launches as dicts for snapshot."""
627
+ return [intent.to_dict() for intent in self._pending_launches]
628
+
629
    async def _resume_pending_launches(self) -> None:
        """Re-launch pending intents restored from a checkpoint.

        For each intent never marked launched: skip it if the child already
        wrote a result to the result backend (the launch actually completed
        before the interruption); otherwise fire it again as a background
        task. Re-launched tasks are fire-and-forget here.
        """
        for intent in self._pending_launches:
            if intent.launched:
                continue
            # Check if child already has a result
            uri = make_uri(intent.execution_id, "result")
            if await self.result_backend.exists(uri):
                continue
            # Re-launch
            logger.info(f"Resuming launch: {intent.machine} (ID: {intent.execution_id})")
            task = asyncio.create_task(
                self._launch_and_write(intent.machine, intent.execution_id, intent.input)
            )
            # Keep a strong reference so asyncio doesn't GC the running task.
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
645
+
646
+ # =========================================================================
647
+ # Machine Invocation - v0.4.0
648
+ # =========================================================================
649
+
650
    async def _launch_and_write(
        self,
        machine_name: str,
        child_id: str,
        input_data: Dict[str, Any]
    ) -> Any:
        """Run a peer machine to completion and publish its result.

        The result is written to the result backend under the child's
        "result" URI so waiting readers can observe completion. On failure
        an error sentinel {"_error": str, "_error_type": class name} is
        written instead, and the exception is re-raised.

        Args:
            machine_name: Key into this machine's 'machines:' section.
            child_id: Execution id assigned to the child.
            input_data: Input dict passed to the child's execute().

        Returns:
            The child machine's result.
        """
        target_config, peer_config_dir = self._resolve_machine_config(machine_name)

        # Peer machines are independent - they load their own hooks from config
        # (via the hooks: section in their machine.yml)
        # Use peer's config_dir so relative paths (e.g., ./agents/judge.yml) resolve correctly
        peer = FlatMachine(
            config_dict=target_config,
            result_backend=self.result_backend,
            _config_dir=peer_config_dir,
            _execution_id=child_id,
            _parent_execution_id=self.execution_id,
        )

        try:
            result = await peer.execute(input=input_data)
            # Write result to backend
            uri = make_uri(child_id, "result")
            await self.result_backend.write(uri, result)
            return result
        except Exception as e:
            # Write error to backend so parent knows
            uri = make_uri(child_id, "result")
            await self.result_backend.write(uri, {"_error": str(e), "_error_type": type(e).__name__})
            raise
681
+
682
+ async def _invoke_machine_single(
683
+ self,
684
+ machine_name: str,
685
+ input_data: Dict[str, Any],
686
+ timeout: Optional[float] = None
687
+ ) -> Any:
688
+ """Invoke a single peer machine with blocking read."""
689
+ child_id = str(uuid.uuid4())
690
+
691
+ # Checkpoint intent (outbox pattern)
692
+ self._add_pending_launch(child_id, machine_name, input_data)
693
+
694
+ # Launch and execute
695
+ result = await self._launch_and_write(machine_name, child_id, input_data)
696
+
697
+ # Mark completed and clear
698
+ self._mark_launched(child_id)
699
+ self._clear_pending_launch(child_id)
700
+
701
+ return result
702
+
703
    async def _invoke_machines_parallel(
        self,
        machines: list[str],
        input_data: Dict[str, Any],
        mode: str = "settled",
        timeout: Optional[float] = None
    ) -> Dict[str, Any]:
        """Invoke multiple machines in parallel.

        Args:
            machines: Machine reference names; every one receives input_data.
            mode: "settled" waits for all machines; "any" returns after the
                first completes and leaves the rest running in the background.
            timeout: Passed to asyncio.wait in "any" mode only.
                NOTE(review): "settled" mode waits indefinitely; confirm
                whether timeout should apply there too.

        Returns:
            Mapping of machine name -> result. A failed machine maps to an
            error sentinel {"_error": str, "_error_type": class name}. In
            "any" mode only the first finisher appears (empty on timeout).
        """
        child_ids = {m: str(uuid.uuid4()) for m in machines}

        # Checkpoint all intents (outbox pattern) before launching anything.
        for machine_name, child_id in child_ids.items():
            self._add_pending_launch(child_id, machine_name, input_data)

        # Launch all children concurrently.
        tasks = {}
        for machine_name, child_id in child_ids.items():
            task = asyncio.create_task(
                self._launch_and_write(machine_name, child_id, input_data)
            )
            tasks[machine_name] = task

        results = {}
        errors = {}  # NOTE(review): collected but never returned or read

        if mode == "settled":
            # Wait for all to complete; exceptions become error sentinels.
            gathered = await asyncio.gather(*tasks.values(), return_exceptions=True)
            for machine_name, result in zip(tasks.keys(), gathered):
                if isinstance(result, Exception):
                    errors[machine_name] = result
                    results[machine_name] = {"_error": str(result), "_error_type": type(result).__name__}
                else:
                    results[machine_name] = result

        elif mode == "any":
            # Wait for first to complete (or timeout, leaving done empty).
            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=timeout
            )
            # Find which machine finished first; record only that result.
            for machine_name, task in tasks.items():
                if task in done:
                    try:
                        results[machine_name] = task.result()
                    except Exception as e:
                        results[machine_name] = {"_error": str(e), "_error_type": type(e).__name__}
                    break
            # Let pending tasks continue in background (keep refs alive).
            for task in pending:
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

        # Clear pending launches from the outbox.
        for child_id in child_ids.values():
            self._mark_launched(child_id)
            self._clear_pending_launch(child_id)

        return results
764
+
765
    async def _invoke_foreach(
        self,
        items: list,
        as_var: str,
        key_expr: Optional[str],
        machine_name: str,
        input_template: Dict[str, Any],
        mode: str = "settled",
        timeout: Optional[float] = None
    ) -> Any:
        """Invoke one child machine per item (dynamic fan-out).

        Args:
            items: List to fan out over.
            as_var: Template variable name each item is bound to.
            key_expr: Optional template producing a result key per item; when
                omitted, the positional index keys the results and a list is
                returned instead of a dict.
            machine_name: Child machine reference invoked for every item.
            input_template: Input dict rendered once per item.
            mode: "settled" waits for all children; "any" returns after the
                first completes (rest continue in the background).
            timeout: Used only by "any" mode's asyncio.wait.

        Returns:
            dict keyed by key_expr results when key_expr is given, else a
            list ordered by item index. Failed children yield the error
            sentinel {"_error": str, "_error_type": class name}.

        NOTE(review): templates here see only {as_var, context={}, input={}};
        the live machine context is not exposed — confirm this is intended.
        """
        child_ids = {}
        item_inputs = {}

        for i, item in enumerate(items):
            # Compute key
            if key_expr:
                variables = {as_var: item, "context": {}, "input": {}}
                item_key = self._render_template(key_expr, variables)
            else:
                item_key = i

            child_id = str(uuid.uuid4())
            child_ids[item_key] = child_id

            # Render input for this item
            variables = {as_var: item, "context": {}, "input": {}}
            item_input = self._render_dict(input_template, variables)
            item_inputs[item_key] = item_input

            # Outbox: record intent before launching.
            self._add_pending_launch(child_id, machine_name, item_input)

        # Launch all
        tasks = {}
        for item_key, child_id in child_ids.items():
            task = asyncio.create_task(
                self._launch_and_write(machine_name, child_id, item_inputs[item_key])
            )
            tasks[item_key] = task

        results = {}

        if mode == "settled":
            # Wait for everything; failures become error sentinels.
            gathered = await asyncio.gather(*tasks.values(), return_exceptions=True)
            for item_key, result in zip(tasks.keys(), gathered):
                if isinstance(result, Exception):
                    results[item_key] = {"_error": str(result), "_error_type": type(result).__name__}
                else:
                    results[item_key] = result

        elif mode == "any":
            # First completion wins; remaining children keep running.
            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=timeout
            )
            for item_key, task in tasks.items():
                if task in done:
                    try:
                        results[item_key] = task.result()
                    except Exception as e:
                        results[item_key] = {"_error": str(e), "_error_type": type(e).__name__}
                    break
            for task in pending:
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)

        # Clear pending launches
        for child_id in child_ids.values():
            self._mark_launched(child_id)
            self._clear_pending_launch(child_id)

        # Return dict if key_expr provided, else index-ordered list.
        if key_expr:
            return results
        else:
            return [results[i] for i in sorted(results.keys()) if isinstance(i, int)]
842
+
843
    async def _launch_fire_and_forget(
        self,
        machines: list[str],
        input_data: Dict[str, Any]
    ) -> None:
        """Launch machines without waiting for results (fire-and-forget).

        Delegates to self.invoker.launch() for cloud-agnostic execution.
        The invoker determines HOW the launch happens (inline task, queue, etc).

        Each launch is first recorded as a pending intent (outbox pattern)
        and marked launched once the invoker accepts it; any results land
        in the result backend keyed by the child's execution_id.
        """
        for machine_name in machines:
            child_id = str(uuid.uuid4())
            # Resolve eagerly so an unknown machine name fails here rather
            # than later; the peer's config_dir is not part of the invoker
            # contract and is discarded.
            target_config, _ = self._resolve_machine_config(machine_name)

            # Record intent before launch (outbox pattern)
            self._add_pending_launch(child_id, machine_name, input_data)

            # Delegate to invoker
            await self.invoker.launch(
                caller_machine=self,
                target_config=target_config,
                input_data=input_data,
                execution_id=child_id
            )

            self._mark_launched(child_id)
869
+
870
+ async def _run_hook(self, method_name: str, *args) -> Any:
871
+ """Run a hook method, awaiting if it's a coroutine."""
872
+ method = getattr(self._hooks, method_name)
873
+ result = method(*args)
874
+ if asyncio.iscoroutine(result):
875
+ return await result
876
+ return result
877
+
878
+ async def _execute_state(
879
+ self,
880
+ state_name: str,
881
+ context: Dict[str, Any]
882
+ ) -> tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
883
+ """
884
+ Execute a single state.
885
+
886
+ Returns:
887
+ Tuple of (updated_context, agent_output)
888
+ """
889
+ state = self.states.get(state_name, {})
890
+ output = None
891
+
892
+ # 1. Handle 'action' (hooks/custom actions)
893
+ action_name = state.get('action')
894
+ if action_name:
895
+ action_impl = HookAction(self._hooks)
896
+ context = await action_impl.execute(action_name, context, config={})
897
+
898
+ # 2. Handle 'launch' (fire-and-forget machine execution)
899
+ launch_spec = state.get('launch')
900
+ if launch_spec:
901
+ launch_input_spec = state.get('launch_input', {})
902
+ variables = {"context": context, "input": context}
903
+ launch_input = self._render_dict(launch_input_spec, variables)
904
+
905
+ # Normalize to list
906
+ machines_to_launch = [launch_spec] if isinstance(launch_spec, str) else launch_spec
907
+ await self._launch_fire_and_forget(machines_to_launch, launch_input)
908
+
909
+ # 3. Handle 'machine' (peer machine execution with blocking read)
910
+ machine_spec = state.get('machine')
911
+ foreach_expr = state.get('foreach')
912
+
913
+ if machine_spec or foreach_expr:
914
+ input_spec = state.get('input', {})
915
+ variables = {"context": context, "input": context}
916
+ mode = state.get('mode', 'settled')
917
+ timeout = state.get('timeout')
918
+
919
+ if foreach_expr:
920
+ # Dynamic parallelism: foreach
921
+ items = self._render_template(foreach_expr, variables)
922
+ if not isinstance(items, list):
923
+ raise ValueError(f"foreach expression must yield a list, got {type(items)}")
924
+
925
+ as_var = state.get('as', 'item')
926
+ key_expr = state.get('key')
927
+ machine_name = machine_spec if isinstance(machine_spec, str) else machine_spec[0]
928
+
929
+ output = await self._invoke_foreach(
930
+ items=items,
931
+ as_var=as_var,
932
+ key_expr=key_expr,
933
+ machine_name=machine_name,
934
+ input_template=input_spec,
935
+ mode=mode,
936
+ timeout=timeout
937
+ )
938
+
939
+ elif isinstance(machine_spec, list):
940
+ # Parallel execution: machine: [a, b, c]
941
+ machine_input = self._render_dict(input_spec, variables)
942
+
943
+ # Handle MachineInput objects (with per-machine inputs)
944
+ if machine_spec and isinstance(machine_spec[0], dict):
945
+ # machine: [{name: a, input: {...}}, ...]
946
+ machine_names = [m['name'] for m in machine_spec]
947
+ # TODO: Support per-machine inputs
948
+ output = await self._invoke_machines_parallel(
949
+ machines=machine_names,
950
+ input_data=machine_input,
951
+ mode=mode,
952
+ timeout=timeout
953
+ )
954
+ else:
955
+ # machine: [a, b, c]
956
+ output = await self._invoke_machines_parallel(
957
+ machines=machine_spec,
958
+ input_data=machine_input,
959
+ mode=mode,
960
+ timeout=timeout
961
+ )
962
+
963
+ else:
964
+ # Single machine: machine: child
965
+ machine_input = self._render_dict(input_spec, variables)
966
+ output = await self._invoke_machine_single(
967
+ machine_name=machine_spec,
968
+ input_data=machine_input,
969
+ timeout=timeout
970
+ )
971
+
972
+ output_mapping = state.get('output_to_context', {})
973
+ if output_mapping:
974
+ safe_output = output or {}
975
+ variables = {"context": context, "output": safe_output, "input": context}
976
+ for ctx_key, template in output_mapping.items():
977
+ context[ctx_key] = self._render_template(template, variables)
978
+
979
+ # 4. Handle 'agent' (LLM execution)
980
+ agent_name = state.get('agent')
981
+ if agent_name:
982
+ agent = self._get_agent(agent_name)
983
+ input_spec = state.get('input', {})
984
+ variables = {"context": context, "input": context}
985
+ agent_input = self._render_dict(input_spec, variables)
986
+
987
+ pre_calls = agent.total_api_calls
988
+ pre_cost = agent.total_cost
989
+
990
+ execution_config = state.get('execution')
991
+ execution_type = get_execution_type(execution_config)
992
+ output = await execution_type.execute(agent, agent_input)
993
+
994
+ self.total_api_calls += agent.total_api_calls - pre_calls
995
+ self.total_cost += agent.total_cost - pre_cost
996
+
997
+ if output is None:
998
+ output = {}
999
+
1000
+ output_mapping = state.get('output_to_context', {})
1001
+ if output_mapping:
1002
+ variables = {"context": context, "output": output, "input": context}
1003
+ for ctx_key, template in output_mapping.items():
1004
+ context[ctx_key] = self._render_template(template, variables)
1005
+
1006
+ # Handle final state output
1007
+ if state.get('type') == 'final':
1008
+ output_spec = state.get('output', {})
1009
+ if output_spec:
1010
+ variables = {"context": context}
1011
+ output = self._render_dict(output_spec, variables)
1012
+
1013
+ return context, output
1014
+
1015
+ async def _save_checkpoint(
1016
+ self,
1017
+ event: str,
1018
+ state_name: str,
1019
+ step: int,
1020
+ context: Dict[str, Any],
1021
+ output: Optional[Dict[str, Any]] = None
1022
+ ) -> None:
1023
+ """Save a checkpoint if configured."""
1024
+ if event not in self.checkpoint_events:
1025
+ return
1026
+
1027
+ snapshot = MachineSnapshot(
1028
+ execution_id=self.execution_id,
1029
+ machine_name=self.machine_name,
1030
+ spec_version=self.SPEC_VERSION,
1031
+ current_state=state_name,
1032
+ context=context,
1033
+ step=step,
1034
+ event=event,
1035
+ output=output,
1036
+ total_api_calls=self.total_api_calls,
1037
+ total_cost=self.total_cost,
1038
+ parent_execution_id=self.parent_execution_id,
1039
+ pending_launches=self._get_pending_intents() if self._pending_launches else None,
1040
+ )
1041
+
1042
+ manager = CheckpointManager(self.persistence, self.execution_id)
1043
+ await manager.save_checkpoint(snapshot)
1044
+
1045
    async def execute(
        self,
        input: Optional[Dict[str, Any]] = None,
        max_steps: int = 1000,
        resume_from: Optional[str] = None
    ) -> Dict[str, Any]:
        """Execute the machine's state loop until a final state or max_steps.

        Args:
            input: Initial input; rendered into the machine's initial context
                on a fresh start. Ignored when a snapshot restores context.
            max_steps: Safety cap on the number of executed steps.
            resume_from: Execution id of a previous run; the latest checkpoint
                for that id is restored if one exists.

        Returns:
            The final state's output (possibly transformed by the
            ``on_machine_end`` hook), or the snapshot's output when the
            resumed execution had already completed.

        Raises:
            RuntimeError: If the execution lock cannot be acquired.
        """
        if resume_from:
            # Adopt the prior execution id so checkpoints/locks line up.
            self.execution_id = resume_from
            logger.info(f"Resuming execution: {self.execution_id}")

        # Guard against two concurrent runs of the same execution id.
        if not await self.lock.acquire(self.execution_id):
            raise RuntimeError(f"Could not acquire lock for execution {self.execution_id}")

        try:
            context = {}
            current_state = None
            step = 0
            final_output = {}
            manager = CheckpointManager(self.persistence, self.execution_id)

            if resume_from:
                snapshot = await manager.load_latest()
                if snapshot:
                    context = snapshot.context
                    step = snapshot.step
                    current_state = snapshot.current_state
                    # Restore execution metrics
                    self.total_api_calls = snapshot.total_api_calls or 0
                    self.total_cost = snapshot.total_cost or 0.0
                    # Restore pending launches (outbox pattern)
                    if snapshot.pending_launches:
                        self._pending_launches = [
                            LaunchIntent.from_dict(intent)
                            for intent in snapshot.pending_launches
                        ]
                        await self._resume_pending_launches()
                    # A 'machine_end' snapshot means the run already finished;
                    # return its recorded output instead of re-executing.
                    if snapshot.event == 'machine_end':
                        logger.info("Execution already completed.")
                        return snapshot.output or {}
                    logger.info(f"Restored from snapshot: step={step}, state={current_state}")
                else:
                    logger.warning(f"No snapshot found for {resume_from}, starting fresh.")

            if not current_state:
                # Fresh start: render the initial context template from input.
                current_state = self._initial_state
                input = input or {}
                variables = {"input": input}
                context = self._render_dict(self.initial_context, variables)

            await self._save_checkpoint('machine_start', 'start', step, context)
            context = await self._run_hook('on_machine_start', context)

            logger.info(f"Starting execution loop at: {current_state}")

            while current_state and step < max_steps:
                step += 1
                is_final = current_state in self._final_states

                await self._save_checkpoint('state_enter', current_state, step, context)
                context = await self._run_hook('on_state_enter', current_state, context)

                # Checkpoint just before running the state body so a crash
                # mid-state can resume from here.
                await self._save_checkpoint('execute', current_state, step, context)

                try:
                    context, output = await self._execute_state(current_state, context)
                    if output and is_final:
                        final_output = output
                except Exception as e:
                    # Record the failure in context so recovery states (and
                    # templates) can inspect it.
                    context['last_error'] = str(e)
                    context['last_error_type'] = type(e).__name__

                    state_config = self.states.get(current_state, {})
                    recovery_state = self._get_error_recovery_state(state_config, e)

                    # State-level recovery config wins; the on_error hook is
                    # the fallback.
                    if not recovery_state:
                        recovery_state = await self._run_hook('on_error', current_state, e, context)

                    if recovery_state:
                        logger.warning(f"Error in {current_state}, transitioning to {recovery_state}: {e}")
                        current_state = recovery_state
                        continue
                    # No recovery path: propagate (finally still releases the lock).
                    raise

                await self._save_checkpoint(
                    'state_exit',
                    current_state,
                    step,
                    context,
                    output=output if is_final else None
                )

                # NOTE(review): final_output was captured before this hook runs,
                # so an output modified by on_state_exit is not reflected in the
                # machine's final output — confirm this is intentional.
                output = await self._run_hook('on_state_exit', current_state, context, output)

                if is_final:
                    logger.info(f"Reached final state: {current_state}")
                    break

                next_state = self._find_next_state(current_state, context)

                # The on_transition hook may veto or redirect the transition.
                if next_state:
                    next_state = await self._run_hook('on_transition', current_state, next_state, context)

                # A None next_state ends the loop on the next iteration.
                logger.debug(f"Transition: {current_state} -> {next_state}")
                current_state = next_state

            if step >= max_steps:
                logger.warning(f"Machine hit max_steps limit ({max_steps})")

            await self._save_checkpoint('machine_end', 'end', step, context, output=final_output)
            final_output = await self._run_hook('on_machine_end', context, final_output)

            return final_output

        finally:
            # Wait for any launched peer machines to complete
            # This ensures peer equality - launched machines have equal right to finish
            if self._background_tasks:
                await asyncio.gather(*self._background_tasks, return_exceptions=True)
            await self.lock.release(self.execution_id)
1165
+
1166
+ def execute_sync(
1167
+ self,
1168
+ input: Optional[Dict[str, Any]] = None,
1169
+ max_steps: int = 1000
1170
+ ) -> Dict[str, Any]:
1171
+ """Synchronous wrapper for execute()."""
1172
+ import asyncio
1173
+ return asyncio.run(self.execute(input=input, max_steps=max_steps))
1174
+
1175
+
1176
# Explicit public API: only FlatMachine is exported by ``from ... import *``.
__all__ = ["FlatMachine"]