flatmachines 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. flatmachines/__init__.py +136 -0
  2. flatmachines/actions.py +408 -0
  3. flatmachines/adapters/__init__.py +38 -0
  4. flatmachines/adapters/flatagent.py +86 -0
  5. flatmachines/adapters/pi_agent_bridge.py +127 -0
  6. flatmachines/adapters/pi_agent_runner.mjs +99 -0
  7. flatmachines/adapters/smolagents.py +125 -0
  8. flatmachines/agents.py +144 -0
  9. flatmachines/assets/MACHINES.md +141 -0
  10. flatmachines/assets/README.md +11 -0
  11. flatmachines/assets/__init__.py +0 -0
  12. flatmachines/assets/flatagent.d.ts +219 -0
  13. flatmachines/assets/flatagent.schema.json +271 -0
  14. flatmachines/assets/flatagent.slim.d.ts +58 -0
  15. flatmachines/assets/flatagents-runtime.d.ts +523 -0
  16. flatmachines/assets/flatagents-runtime.schema.json +281 -0
  17. flatmachines/assets/flatagents-runtime.slim.d.ts +187 -0
  18. flatmachines/assets/flatmachine.d.ts +403 -0
  19. flatmachines/assets/flatmachine.schema.json +620 -0
  20. flatmachines/assets/flatmachine.slim.d.ts +106 -0
  21. flatmachines/assets/profiles.d.ts +140 -0
  22. flatmachines/assets/profiles.schema.json +93 -0
  23. flatmachines/assets/profiles.slim.d.ts +26 -0
  24. flatmachines/backends.py +222 -0
  25. flatmachines/distributed.py +835 -0
  26. flatmachines/distributed_hooks.py +351 -0
  27. flatmachines/execution.py +638 -0
  28. flatmachines/expressions/__init__.py +60 -0
  29. flatmachines/expressions/cel.py +101 -0
  30. flatmachines/expressions/simple.py +166 -0
  31. flatmachines/flatmachine.py +1263 -0
  32. flatmachines/hooks.py +381 -0
  33. flatmachines/locking.py +69 -0
  34. flatmachines/monitoring.py +505 -0
  35. flatmachines/persistence.py +213 -0
  36. flatmachines/run.py +117 -0
  37. flatmachines/utils.py +166 -0
  38. flatmachines/validation.py +79 -0
  39. flatmachines-1.0.0.dist-info/METADATA +390 -0
  40. flatmachines-1.0.0.dist-info/RECORD +41 -0
  41. flatmachines-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,638 @@
1
+ """
2
+ Execution Types for FlatMachine.
3
+
4
+ Provides different execution strategies for agent calls:
5
+ - Default: Single call
6
+ - Parallel: Multiple concurrent calls, successful results aggregated
7
+ - Retry: Multiple attempts with backoff
8
+ - MDAP Voting: Multi-sampling with majority vote
9
+ """
10
+
11
+ import asyncio
12
+ import json
13
+ import re
14
+ from abc import ABC, abstractmethod
15
+ from collections import Counter
16
+ from dataclasses import dataclass, field
17
+ from typing import Any, Dict, List, Optional, Tuple
18
+
19
+ from .monitoring import get_logger
20
+ from .agents import AgentExecutor, AgentResult, coerce_agent_result
21
+
22
+ logger = get_logger(__name__)
23
+
24
+
25
+ def _coerce_status_code(value: Any) -> Optional[int]:
26
+ if value is None:
27
+ return None
28
+ if isinstance(value, int):
29
+ return value
30
+ if isinstance(value, str) and value.isdigit():
31
+ return int(value)
32
+ return None
33
+
34
+
35
+ def _extract_status_code(error: Optional[BaseException]) -> Optional[int]:
36
+ if error is None:
37
+ return None
38
+
39
+ for attr in ("status_code", "status", "http_status", "statusCode"):
40
+ code = _coerce_status_code(getattr(error, attr, None))
41
+ if code is not None:
42
+ return code
43
+
44
+ response = getattr(error, "response", None)
45
+ if response is not None:
46
+ for attr in ("status_code", "status", "http_status", "statusCode"):
47
+ code = _coerce_status_code(getattr(response, attr, None))
48
+ if code is not None:
49
+ return code
50
+ if isinstance(response, dict):
51
+ for key in ("status_code", "status", "http_status", "statusCode"):
52
+ code = _coerce_status_code(response.get(key))
53
+ if code is not None:
54
+ return code
55
+
56
+ match = re.search(r"\b([4-5]\d{2})\b", str(error))
57
+ if match:
58
+ return int(match.group(1))
59
+ return None
60
+
61
+
62
+ def _normalize_headers(raw_headers: Any) -> Dict[str, str]:
63
+ if raw_headers is None:
64
+ return {}
65
+
66
+ if isinstance(raw_headers, dict):
67
+ items = raw_headers.items()
68
+ elif hasattr(raw_headers, "items"):
69
+ items = raw_headers.items()
70
+ elif isinstance(raw_headers, (list, tuple)):
71
+ items = raw_headers
72
+ else:
73
+ return {}
74
+
75
+ normalized: Dict[str, str] = {}
76
+ for key, value in items:
77
+ if key is None:
78
+ continue
79
+ key_text = str(key).lower()
80
+ if isinstance(value, (list, tuple)):
81
+ value_text = ",".join(str(item) for item in value)
82
+ else:
83
+ value_text = str(value)
84
+ normalized[key_text] = value_text
85
+
86
+ return normalized
87
+
88
+
89
+ def _extract_error_headers(error: Optional[BaseException]) -> Dict[str, str]:
90
+ if error is None:
91
+ return {}
92
+
93
+ response = getattr(error, "response", None)
94
+ headers: Dict[str, str] = {}
95
+ if response is not None:
96
+ headers.update(_normalize_headers(getattr(response, "headers", None)))
97
+ if not headers and isinstance(response, dict):
98
+ headers.update(_normalize_headers(response.get("headers")))
99
+
100
+ headers.update(_normalize_headers(getattr(error, "headers", None)))
101
+ return headers
102
+
103
+
104
+ def _extract_api_calls(result: AgentResult) -> int:
105
+ usage = result.usage or {}
106
+ if isinstance(usage, dict):
107
+ return int(usage.get("api_calls") or usage.get("requests") or usage.get("calls") or 0)
108
+ return 0
109
+
110
+
111
+ def _extract_cost(result: AgentResult) -> float:
112
+ if result.cost is not None:
113
+ try:
114
+ return float(result.cost)
115
+ except (TypeError, ValueError):
116
+ return 0.0
117
+ usage = result.usage or {}
118
+ if isinstance(usage, dict):
119
+ cost = usage.get("cost")
120
+ if isinstance(cost, (int, float)):
121
+ return float(cost)
122
+ if isinstance(cost, dict):
123
+ total = cost.get("total")
124
+ if isinstance(total, (int, float)):
125
+ return float(total)
126
+ return 0.0
127
+
128
+
129
+ def _merge_usage(result: AgentResult, api_calls: int) -> Optional[Dict[str, Any]]:
130
+ if api_calls == 0 and result.usage is None:
131
+ return result.usage
132
+ usage: Dict[str, Any] = {}
133
+ if isinstance(result.usage, dict):
134
+ usage.update(result.usage)
135
+ if api_calls:
136
+ usage["api_calls"] = api_calls
137
+ return usage
138
+
139
+
140
+ # Registry of execution types
141
+ _EXECUTION_TYPES: Dict[str, type] = {}
142
+
143
+
144
+ def register_execution_type(name: str):
145
+ """Decorator to register an execution type."""
146
+ def decorator(cls):
147
+ _EXECUTION_TYPES[name] = cls
148
+ return cls
149
+ return decorator
150
+
151
+
152
def get_execution_type(config: Optional[Dict[str, Any]] = None) -> "ExecutionType":
    """Build the execution type described by ``config``.

    Args:
        config: Execution configuration. Its ``"type"`` key selects the
            registered class (``"default"`` when absent). ``None`` yields a
            ``DefaultExecution``.

    Returns:
        An instance created through the selected class's ``from_config``.

    Raises:
        ValueError: If the requested type name is not registered.
    """
    if config is None:
        return DefaultExecution()

    requested = config.get("type", "default")
    if requested in _EXECUTION_TYPES:
        return _EXECUTION_TYPES[requested].from_config(config)

    raise ValueError(f"Unknown execution type: {requested}")
163
+
164
+
165
class ExecutionType(ABC):
    """Abstract interface shared by every execution strategy."""

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "ExecutionType":
        """Create an instance from a YAML-derived config dict."""

    @abstractmethod
    async def execute(
        self,
        executor: "AgentExecutor",
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> "AgentResult":
        """
        Run the agent according to this execution strategy.

        Args:
            executor: The AgentExecutor to call
            input_data: Input data for the agent
            context: Current machine context

        Returns:
            AgentResult
        """
193
+
194
+
195
@register_execution_type("default")
class DefaultExecution(ExecutionType):
    """Execution strategy that calls the agent exactly once."""

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "DefaultExecution":
        """Create an instance; ``config`` carries no options for this type."""
        return cls()

    async def execute(
        self,
        executor: "AgentExecutor",
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> "AgentResult":
        """Call the executor once and coerce its return value to an AgentResult."""
        return coerce_agent_result(
            await executor.execute(input_data, context=context)
        )
212
+
213
+
214
+ # Parallel Execution Type
215
+
216
@register_execution_type("parallel")
class ParallelExecution(ExecutionType):
    """
    Run N samples in parallel, return all results.

    Useful for getting multiple diverse responses to compare or aggregate.

    Example YAML:
        execution:
          type: parallel
          n_samples: 5
    """

    def __init__(self, n_samples: int = 3):
        self.n_samples = n_samples

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ParallelExecution":
        """Create an instance from config (``n_samples``, default 3)."""
        return cls(n_samples=config.get("n_samples", 3))

    async def execute(
        self,
        executor: "AgentExecutor",
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> "AgentResult":
        """
        Run ``n_samples`` agent calls concurrently and aggregate the successes.

        Args:
            executor: The AgentExecutor to call
            input_data: Input data for the agent
            context: Current machine context

        Returns:
            An AgentResult whose output is
            ``{"results": [...payloads...], "count": <number of successes>}``,
            with summed api_calls and cost when non-zero. An empty
            AgentResult when every sample failed.
        """
        async def single_call() -> "AgentResult":
            result = await executor.execute(input_data, context=context)
            return coerce_agent_result(result)

        outcomes = await asyncio.gather(
            *(single_call() for _ in range(self.n_samples)),
            return_exceptions=True,
        )

        valid_results = []
        for outcome in outcomes:
            # gather() can hand back CancelledError, which derives from
            # BaseException rather than Exception, so filter on BaseException
            # to keep exception objects out of the results.
            if isinstance(outcome, BaseException):
                logger.warning(f"Parallel sample failed: {outcome!r}")
                continue
            valid_results.append(outcome)

        if not valid_results:
            return AgentResult()

        payloads = [result.output_payload() for result in valid_results]
        total_api_calls = sum(_extract_api_calls(result) for result in valid_results)
        total_cost = sum(_extract_cost(result) for result in valid_results)

        # Zero totals are reported as "unknown" (None) rather than 0.
        usage = {"api_calls": total_api_calls} if total_api_calls else None
        cost = total_cost if total_cost else None

        return AgentResult(
            output={"results": payloads, "count": len(payloads)},
            raw={"results": valid_results},
            usage=usage,
            cost=cost,
        )
272
+
273
+
274
+ # Retry Execution Type
275
+
276
@register_execution_type("retry")
class RetryExecution(ExecutionType):
    """
    Retry on failure with configurable backoff delays and jitter.

    Default backoffs [2, 8, 16, 35] total 61 seconds, intended to wait
    for a fresh RPM (requests per minute) bucket.

    Example YAML:
        execution:
          type: retry
          backoffs: [2, 8, 16, 35]  # Backoff delays in seconds
          jitter: 0.1               # Random jitter factor (0.1 = ±10%)
    """

    # Default backoffs: 2 + 8 + 16 + 35 = 61 seconds (wait for fresh RPM bucket)
    DEFAULT_BACKOFFS = [2, 8, 16, 35]

    def __init__(
        self,
        backoffs: Optional[List[float]] = None,
        jitter: float = 0.1,
        retry_on_empty: bool = False
    ):
        """
        Args:
            backoffs: Delay in seconds before each retry; the number of
                entries is the number of retries. Defaults to
                ``DEFAULT_BACKOFFS``.
            jitter: Fraction of each delay used as the random +/- range.
            retry_on_empty: When True, an empty output payload is treated
                as a failure and retried.
        """
        # Copy the list so instances never alias the class-level default
        # (or the caller's list) and cannot mutate it for everyone else.
        self.backoffs = list(backoffs if backoffs is not None else self.DEFAULT_BACKOFFS)
        self.jitter = jitter
        self.retry_on_empty = retry_on_empty

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "RetryExecution":
        """Create an instance from config (``backoffs``, ``jitter``, ``retry_on_empty``)."""
        return cls(
            backoffs=config.get("backoffs"),
            jitter=config.get("jitter", 0.1),
            retry_on_empty=config.get("retry_on_empty", False)
        )

    def _apply_jitter(self, delay: float) -> float:
        """Return ``delay`` shifted by a random amount within +/- ``delay * jitter``."""
        import random
        jitter_range = delay * self.jitter
        # Clamp at zero: a jitter factor above 1 could otherwise produce a
        # negative delay.
        return max(0.0, delay + random.uniform(-jitter_range, jitter_range))

    async def execute(
        self,
        executor: "AgentExecutor",
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> "AgentResult":
        """
        Execute with retries on failure.

        Args:
            executor: The AgentExecutor to call
            input_data: Input data for the agent
            context: Current machine context

        Returns:
            The first successful AgentResult, with usage/cost accumulated
            over the attempts that returned a result. When every attempt
            fails, an AgentResult whose output holds ``_error`` and
            ``_error_type`` (plus ``_error_status_code`` / ``_error_headers``
            when they can be extracted) and whose ``raw`` is the last error.
        """
        last_error: Optional[BaseException] = None
        max_attempts = len(self.backoffs) + 1  # Initial attempt + retries
        total_api_calls = 0
        total_cost = 0.0

        for attempt in range(max_attempts):
            try:
                result = await executor.execute(input_data, context=context)
                agent_result = coerce_agent_result(result)
                total_api_calls += _extract_api_calls(agent_result)
                total_cost += _extract_cost(agent_result)
                payload = agent_result.output_payload()

                # Raising inside the try block sends an empty response
                # through the same retry/backoff path as a real failure.
                if not payload and self.retry_on_empty:
                    raise ValueError("Empty response from agent")

                return AgentResult(
                    output=agent_result.output,
                    content=agent_result.content,
                    raw=agent_result.raw,
                    usage=_merge_usage(agent_result, total_api_calls),
                    cost=total_cost if total_cost else agent_result.cost,
                    metadata=agent_result.metadata,
                )

            except Exception as e:
                last_error = e
                logger.warning(
                    f"Attempt {attempt + 1}/{max_attempts} failed: {e}"
                )

                # If we have more retries, wait with jitter
                if attempt < len(self.backoffs):
                    delay = self._apply_jitter(self.backoffs[attempt])
                    logger.info(f"Retrying in {delay:.1f}s...")
                    await asyncio.sleep(delay)

        # All retries exhausted
        logger.error(f"All {max_attempts} attempts failed. Last error: {last_error}")
        error_payload: Dict[str, Any] = {
            "_error": str(last_error) if last_error else "LLM call failed",
            "_error_type": type(last_error).__name__ if last_error else "UnknownError",
        }
        status_code = _extract_status_code(last_error)
        if status_code is not None:
            error_payload["_error_status_code"] = status_code
        headers = _extract_error_headers(last_error)
        if headers:
            error_payload["_error_headers"] = headers

        usage = {"api_calls": total_api_calls} if total_api_calls else None
        cost = total_cost if total_cost else None

        return AgentResult(output=error_payload, raw=last_error, usage=usage, cost=cost)
392
+
393
+
394
+ # MDAP Voting Execution Type
395
+
396
@dataclass
class MDAPMetrics:
    """Counters accumulated while running MDAP voting."""
    total_samples: int = 0
    total_red_flags: int = 0
    red_flags_by_reason: Dict[str, int] = field(default_factory=dict)
    samples_per_step: List[int] = field(default_factory=list)

    def record_red_flag(self, reason: str):
        """Count one red flag, both overall and under ``reason``."""
        tally = self.red_flags_by_reason
        tally[reason] = 1 + tally.get(reason, 0)
        self.total_red_flags += 1
407
+
408
+
409
@register_execution_type("mdap_voting")
class MDAPVotingExecution(ExecutionType):
    """
    Multi-sample with first-to-ahead-by-k voting.

    Implements the voting algorithm from the MAKER paper.
    """

    def __init__(
        self,
        k_margin: int = 3,
        max_candidates: int = 10,
        max_response_tokens: Optional[int] = None
    ):
        """
        Args:
            k_margin: Vote threshold used to declare a winner.
            max_candidates: Maximum number of samples drawn per execute().
            max_response_tokens: Optional cap on the estimated response
                length; longer responses are red-flagged.
        """
        self.k_margin = k_margin
        self.max_candidates = max_candidates
        self.max_response_tokens = max_response_tokens
        self.metrics = MDAPMetrics()

        # Loaded from agent metadata
        self._patterns: Dict[str, Tuple[re.Pattern, str]] = {}
        self._validation_schema: Optional[Dict] = None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MDAPVotingExecution":
        """Create an instance from config (``k_margin``, ``max_candidates``, ``max_response_tokens``)."""
        return cls(
            k_margin=config.get("k_margin", 3),
            max_candidates=config.get("max_candidates", 10),
            max_response_tokens=config.get("max_response_tokens")
        )

    @staticmethod
    def _usage_and_cost(total_api_calls: int, total_cost: float):
        """Return ``(usage, cost)`` for a result; zero totals become ``None``."""
        usage = {"api_calls": total_api_calls} if total_api_calls else None
        cost = total_cost if total_cost else None
        return usage, cost

    def _configure_from_executor(self, executor: "AgentExecutor"):
        """Load parsing and validation config from executor metadata."""
        metadata = getattr(executor, "metadata", {}) or {}
        if not isinstance(metadata, dict):
            metadata = {}

        # Truthy values under metadata["mdap"] override the execution config.
        mdap_config = metadata.get('mdap', {})
        if mdap_config.get('k_margin'):
            self.k_margin = mdap_config['k_margin']
        if mdap_config.get('max_candidates'):
            self.max_candidates = mdap_config['max_candidates']
        if mdap_config.get('max_response_tokens'):
            self.max_response_tokens = mdap_config['max_response_tokens']

        # Load parsing patterns
        self._patterns = {}
        for field_name, field_config in metadata.get('parsing', {}).items():
            pattern = field_config.get('pattern')
            if pattern:
                self._patterns[field_name] = (
                    re.compile(pattern, re.DOTALL),
                    field_config.get('type', 'str')
                )

        # Load validation schema
        self._validation_schema = metadata.get('validation', None)

    def _parse_response(self, content: str) -> Optional[Dict[str, Any]]:
        """
        Parse LLM response using regex patterns.

        Returns:
            A dict with one entry per configured field, or ``None`` when no
            patterns are configured, any pattern fails to match, or a
            ``json``/``int`` field cannot be converted.
        """
        if not self._patterns:
            return None

        result = {}
        for field_name, (pattern, field_type) in self._patterns.items():
            match = pattern.search(content)
            if not match:
                return None

            # A pattern with no capture group would raise IndexError on
            # group(1); use the whole match in that case.
            value = match.group(1) if pattern.groups else match.group(0)
            if field_type == 'json':
                try:
                    result[field_name] = json.loads(value)
                except (TypeError, ValueError):
                    # ValueError covers JSONDecodeError; TypeError covers a
                    # capture group that did not participate (value is None).
                    return None
            elif field_type == 'int':
                try:
                    result[field_name] = int(value)
                except (TypeError, ValueError):
                    return None
            else:
                result[field_name] = value

        return result

    def _validate_parsed(self, parsed: Dict[str, Any]) -> bool:
        """
        Validate parsed result against JSON Schema.

        Returns:
            True when no schema is configured or validation passes; False
            when validation fails or cannot be performed.
        """
        if not self._validation_schema:
            return True

        try:
            import jsonschema
        except ImportError:
            # Without this log a missing dependency looks the same as every
            # sample failing validation.
            logger.warning(
                "jsonschema is not installed; cannot validate MDAP responses"
            )
            return False

        try:
            jsonschema.validate(instance=parsed, schema=self._validation_schema)
            return True
        except Exception:
            return False

    def _check_red_flags(self, content: str, parsed: Optional[Dict[str, Any]]) -> Optional[str]:
        """
        Check response for red flags per MAKER paper.

        Returns:
            ``"format_error"``, ``"validation_failed"`` or
            ``"length_exceeded"``; ``None`` when the response is acceptable.
        """
        if parsed is None:
            return "format_error"

        if not self._validate_parsed(parsed):
            return "validation_failed"

        # Only check response length if max_response_tokens is set
        if self.max_response_tokens is not None:
            estimated_tokens = len(content) // 4  # rough estimate: 4 characters per token
            if estimated_tokens > self.max_response_tokens:
                return "length_exceeded"

        return None

    def _extract_candidate(self, result: "AgentResult") -> Tuple[Optional[Dict[str, Any]], str]:
        """
        Extract candidate payload and content for voting.

        Returns:
            ``(candidate, content)``. With parsing patterns configured the
            candidate is the parsed content. Otherwise it is
            ``result.output`` when set, else ``{"content": content}``.
            ``candidate`` is ``None`` when nothing usable is available.
        """
        content = result.content or ""

        if self._patterns:
            parsed = self._parse_response(content) if content else None
            return parsed, content

        if result.output is not None:
            payload = result.output
            # With no text content, use the serialized payload for the length check.
            content_for_length = content or json.dumps(payload, sort_keys=True)
            return payload, content_for_length

        if content:
            return {"content": content}, content

        return None, content

    async def execute(
        self,
        executor: "AgentExecutor",
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> "AgentResult":
        """
        Multi-sample with voting - replaces single agent call.

        Args:
            executor: The AgentExecutor to call
            input_data: Input data for the agent
            context: Current machine context

        Returns:
            The winning parsed response, or an AgentResult without output
            when no sample produced a valid candidate.
        """
        self._configure_from_executor(executor)

        votes: Counter = Counter()
        responses: Dict[str, Any] = {}
        num_samples = 0
        total_api_calls = 0
        total_cost = 0.0

        for _ in range(self.max_candidates):
            try:
                result = await executor.execute(input_data, context=context)
                agent_result = coerce_agent_result(result)
                num_samples += 1
                self.metrics.total_samples += 1
                total_api_calls += _extract_api_calls(agent_result)
                total_cost += _extract_cost(agent_result)

                candidate, content = self._extract_candidate(agent_result)
                if candidate is None:
                    flag_reason = "format_error"
                else:
                    flag_reason = self._check_red_flags(content, candidate)

                if flag_reason:
                    self.metrics.record_red_flag(flag_reason)
                    continue

                # Canonical JSON makes equal payloads vote for the same key.
                key = json.dumps(candidate, sort_keys=True) if not isinstance(candidate, str) else candidate
                votes[key] += 1
                responses[key] = candidate

                if votes[key] >= self.k_margin:
                    self.metrics.samples_per_step.append(num_samples)
                    usage, cost = self._usage_and_cost(total_api_calls, total_cost)
                    return AgentResult(
                        output=responses[key],
                        raw=agent_result.raw,
                        usage=usage,
                        cost=cost,
                    )

                # NOTE(review): the absolute-count check above returns as
                # soon as any candidate reaches k_margin votes, so this
                # lead-based check appears unreachable. The loop then behaves
                # as first-to-k rather than first-to-ahead-by-k. Confirm
                # which rule is intended.
                if len(votes) >= 2:
                    top = votes.most_common(2)
                    if top[0][1] - top[1][1] >= self.k_margin:
                        self.metrics.samples_per_step.append(num_samples)
                        usage, cost = self._usage_and_cost(total_api_calls, total_cost)
                        return AgentResult(
                            output=responses[top[0][0]],
                            usage=usage,
                            cost=cost,
                        )

            except Exception as e:
                logger.warning(f"Sample failed: {e}")
                continue

        # Majority fallback
        self.metrics.samples_per_step.append(num_samples)
        usage, cost = self._usage_and_cost(total_api_calls, total_cost)

        if votes:
            winner_key = votes.most_common(1)[0][0]
            return AgentResult(output=responses[winner_key], usage=usage, cost=cost)

        return AgentResult(usage=usage, cost=cost)

    def get_metrics(self) -> Dict[str, Any]:
        """Get collected metrics as a plain dict."""
        return {
            "total_samples": self.metrics.total_samples,
            "total_red_flags": self.metrics.total_red_flags,
            "red_flags_by_reason": self.metrics.red_flags_by_reason,
            "samples_per_step": self.metrics.samples_per_step,
        }
628
+
629
+
630
+ __all__ = [
631
+ "ExecutionType",
632
+ "DefaultExecution",
633
+ "ParallelExecution",
634
+ "RetryExecution",
635
+ "MDAPVotingExecution",
636
+ "get_execution_type",
637
+ "register_execution_type",
638
+ ]
@@ -0,0 +1,60 @@
1
+ """
2
+ Expression engines for flatmachines.
3
+
4
+ Provides two modes:
5
+ - simple: Built-in parser for basic comparisons and boolean logic (default)
6
+ - cel: Full CEL support via cel-python (optional extra)
7
+ """
8
+
9
+ from typing import Any, Dict, Protocol, runtime_checkable
10
+
11
+
12
@runtime_checkable
class ExpressionEngine(Protocol):
    """Structural interface every expression engine must satisfy."""

    def evaluate(self, expression: str, variables: Dict[str, Any]) -> Any:
        """
        Compute the value of ``expression`` against ``variables``.

        Args:
            expression: The expression string to evaluate
            variables: Dictionary of variable names to values
                (e.g., {"context": {...}, "input": {...}, "output": {...}})

        Returns:
            The result of evaluating the expression
        """
        ...
29
+
30
+
31
def get_expression_engine(mode: str = "simple") -> "ExpressionEngine":
    """
    Get an expression engine by mode.

    Engine modules are imported lazily so the optional CEL dependency is
    only needed when CEL mode is requested.

    Args:
        mode: "simple" (default) or "cel"

    Returns:
        ExpressionEngine instance

    Raises:
        ImportError: If CEL mode requested but cel-python not installed
        ValueError: If unknown mode
    """
    if mode == "simple":
        from .simple import SimpleExpressionEngine
        return SimpleExpressionEngine()

    if mode == "cel":
        try:
            from .cel import CELExpressionEngine
            return CELExpressionEngine()
        except ImportError as exc:
            # Chain the original error so the traceback shows which import
            # failed.
            raise ImportError(
                "CEL expression engine requires: pip install flatmachines[cel]"
            ) from exc

    raise ValueError(f"Unknown expression engine: {mode}")
58
+
59
+
60
+ __all__ = ["ExpressionEngine", "get_expression_engine"]