flatmachines 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flatmachines/__init__.py +136 -0
- flatmachines/actions.py +408 -0
- flatmachines/adapters/__init__.py +38 -0
- flatmachines/adapters/flatagent.py +86 -0
- flatmachines/adapters/pi_agent_bridge.py +127 -0
- flatmachines/adapters/pi_agent_runner.mjs +99 -0
- flatmachines/adapters/smolagents.py +125 -0
- flatmachines/agents.py +144 -0
- flatmachines/assets/MACHINES.md +141 -0
- flatmachines/assets/README.md +11 -0
- flatmachines/assets/__init__.py +0 -0
- flatmachines/assets/flatagent.d.ts +219 -0
- flatmachines/assets/flatagent.schema.json +271 -0
- flatmachines/assets/flatagent.slim.d.ts +58 -0
- flatmachines/assets/flatagents-runtime.d.ts +523 -0
- flatmachines/assets/flatagents-runtime.schema.json +281 -0
- flatmachines/assets/flatagents-runtime.slim.d.ts +187 -0
- flatmachines/assets/flatmachine.d.ts +403 -0
- flatmachines/assets/flatmachine.schema.json +620 -0
- flatmachines/assets/flatmachine.slim.d.ts +106 -0
- flatmachines/assets/profiles.d.ts +140 -0
- flatmachines/assets/profiles.schema.json +93 -0
- flatmachines/assets/profiles.slim.d.ts +26 -0
- flatmachines/backends.py +222 -0
- flatmachines/distributed.py +835 -0
- flatmachines/distributed_hooks.py +351 -0
- flatmachines/execution.py +638 -0
- flatmachines/expressions/__init__.py +60 -0
- flatmachines/expressions/cel.py +101 -0
- flatmachines/expressions/simple.py +166 -0
- flatmachines/flatmachine.py +1263 -0
- flatmachines/hooks.py +381 -0
- flatmachines/locking.py +69 -0
- flatmachines/monitoring.py +505 -0
- flatmachines/persistence.py +213 -0
- flatmachines/run.py +117 -0
- flatmachines/utils.py +166 -0
- flatmachines/validation.py +79 -0
- flatmachines-1.0.0.dist-info/METADATA +390 -0
- flatmachines-1.0.0.dist-info/RECORD +41 -0
- flatmachines-1.0.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,638 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Execution Types for FlatMachine.
|
|
3
|
+
|
|
4
|
+
Provides different execution strategies for agent calls:
|
|
5
|
+
- Default: Single call
|
|
6
|
+
- Parallel: Multiple concurrent calls, all successful results aggregated
|
|
7
|
+
- Retry: Multiple attempts with backoff
|
|
8
|
+
- MDAP Voting: Multi-sampling with first-to-ahead-by-k voting (majority fallback)
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import json
|
|
13
|
+
import re
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from collections import Counter
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
18
|
+
|
|
19
|
+
from .monitoring import get_logger
|
|
20
|
+
from .agents import AgentExecutor, AgentResult, coerce_agent_result
|
|
21
|
+
|
|
22
|
+
# Module-level logger (obtained via the package's monitoring helpers) used to
# report failed agent attempts and samples.
logger = get_logger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _coerce_status_code(value: Any) -> Optional[int]:
|
|
26
|
+
if value is None:
|
|
27
|
+
return None
|
|
28
|
+
if isinstance(value, int):
|
|
29
|
+
return value
|
|
30
|
+
if isinstance(value, str) and value.isdigit():
|
|
31
|
+
return int(value)
|
|
32
|
+
return None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _extract_status_code(error: Optional[BaseException]) -> Optional[int]:
    """Find an HTTP status code associated with *error*, if any.

    Lookup order: status-like attributes on the exception itself, then the
    same attributes on ``error.response``, then matching keys when that
    response is a dict, and finally the first standalone number in the
    400-599 range found in ``str(error)``. The last step is a text heuristic
    and may pick up unrelated numbers that appear in the message.

    Returns:
        The status code, or ``None`` when nothing usable is found.
    """
    if error is None:
        return None

    names = ("status_code", "status", "http_status", "statusCode")

    def _first_code(lookup):
        # Return the first name whose looked-up value coerces to a status code.
        for name in names:
            candidate = _coerce_status_code(lookup(name))
            if candidate is not None:
                return candidate
        return None

    found = _first_code(lambda name: getattr(error, name, None))
    if found is not None:
        return found

    response = getattr(error, "response", None)
    if response is not None:
        found = _first_code(lambda name: getattr(response, name, None))
        if found is not None:
            return found
        if isinstance(response, dict):
            found = _first_code(response.get)
            if found is not None:
                return found

    # Last resort: scrape a 4xx/5xx-looking number out of the message text.
    hit = re.search(r"\b([4-5]\d{2})\b", str(error))
    return int(hit.group(1)) if hit else None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _normalize_headers(raw_headers: Any) -> Dict[str, str]:
|
|
63
|
+
if raw_headers is None:
|
|
64
|
+
return {}
|
|
65
|
+
|
|
66
|
+
if isinstance(raw_headers, dict):
|
|
67
|
+
items = raw_headers.items()
|
|
68
|
+
elif hasattr(raw_headers, "items"):
|
|
69
|
+
items = raw_headers.items()
|
|
70
|
+
elif isinstance(raw_headers, (list, tuple)):
|
|
71
|
+
items = raw_headers
|
|
72
|
+
else:
|
|
73
|
+
return {}
|
|
74
|
+
|
|
75
|
+
normalized: Dict[str, str] = {}
|
|
76
|
+
for key, value in items:
|
|
77
|
+
if key is None:
|
|
78
|
+
continue
|
|
79
|
+
key_text = str(key).lower()
|
|
80
|
+
if isinstance(value, (list, tuple)):
|
|
81
|
+
value_text = ",".join(str(item) for item in value)
|
|
82
|
+
else:
|
|
83
|
+
value_text = str(value)
|
|
84
|
+
normalized[key_text] = value_text
|
|
85
|
+
|
|
86
|
+
return normalized
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _extract_error_headers(error: Optional[BaseException]) -> Dict[str, str]:
    """Collect the headers attached to *error* as a normalized dict.

    Headers come from ``error.response.headers``. If that yields nothing and
    the response is a dict, ``error.response["headers"]`` is used instead.
    Any ``error.headers`` are then layered on top, so they win on conflicts.
    """
    if error is None:
        return {}

    collected: Dict[str, str] = {}
    resp = getattr(error, "response", None)
    if resp is not None:
        collected.update(_normalize_headers(getattr(resp, "headers", None)))
        if isinstance(resp, dict) and not collected:
            collected.update(_normalize_headers(resp.get("headers")))

    collected.update(_normalize_headers(getattr(error, "headers", None)))
    return collected
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _extract_api_calls(result: AgentResult) -> int:
|
|
105
|
+
usage = result.usage or {}
|
|
106
|
+
if isinstance(usage, dict):
|
|
107
|
+
return int(usage.get("api_calls") or usage.get("requests") or usage.get("calls") or 0)
|
|
108
|
+
return 0
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _extract_cost(result: AgentResult) -> float:
    """Return the monetary cost recorded on *result* as a float.

    ``result.cost`` takes precedence when set; if it cannot be converted to a
    float, the cost counts as ``0.0``. Otherwise ``usage["cost"]`` is
    consulted, either as a plain number or as a dict with a numeric
    ``"total"``. Defaults to ``0.0``.
    """
    explicit = result.cost
    if explicit is not None:
        try:
            return float(explicit)
        except (TypeError, ValueError):
            return 0.0

    usage = result.usage or {}
    if not isinstance(usage, dict):
        return 0.0

    reported = usage.get("cost")
    if isinstance(reported, dict):
        reported = reported.get("total")
    if isinstance(reported, (int, float)):
        return float(reported)
    return 0.0
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _merge_usage(result: AgentResult, api_calls: int) -> Optional[Dict[str, Any]]:
    """Build the usage dict for *result*, overriding ``api_calls`` if non-zero.

    Returns ``None`` when there is nothing to report (no usage and no calls).
    Otherwise returns a new dict: a shallow copy of ``result.usage`` when it
    is a dict (else an empty dict), with ``"api_calls"`` set when
    *api_calls* is non-zero.
    """
    original = result.usage
    if original is None and api_calls == 0:
        return original
    merged: Dict[str, Any] = dict(original) if isinstance(original, dict) else {}
    if api_calls:
        merged["api_calls"] = api_calls
    return merged
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# Registry of execution types
|
|
141
|
+
# Maps an execution type name (the ``type`` key of an ``execution`` config) to
# its ExecutionType subclass; populated by @register_execution_type and read
# by get_execution_type.
_EXECUTION_TYPES: Dict[str, type] = {}
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def register_execution_type(name: str):
    """Return a class decorator that registers an execution type under *name*.

    The decorated class is stored in the module registry, replacing any
    earlier entry with the same name. It is returned unchanged, so the
    decorator can be stacked on a class definition, e.g.
    ``@register_execution_type("retry")``.
    """
    def _register(execution_cls):
        _EXECUTION_TYPES[name] = execution_cls
        return execution_cls

    return _register
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def get_execution_type(config: Optional[Dict[str, Any]] = None) -> "ExecutionType":
    """Instantiate the execution strategy described by *config*.

    ``None`` yields a :class:`DefaultExecution`. Otherwise ``config["type"]``
    (default ``"default"``) selects a registered class, whose ``from_config``
    receives the whole mapping.

    Raises:
        ValueError: If the requested type has not been registered.
    """
    if config is None:
        return DefaultExecution()

    requested = config.get("type", "default")
    execution_cls = _EXECUTION_TYPES.get(requested)
    if execution_cls is None:
        raise ValueError(f"Unknown execution type: {requested}")
    return execution_cls.from_config(config)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class ExecutionType(ABC):
    """Abstract strategy describing how an agent call is carried out.

    Concrete strategies are built from an ``execution`` config mapping via
    :meth:`from_config` and perform the agent call(s) in :meth:`execute`.
    """

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "ExecutionType":
        """Build an instance from the parsed YAML ``execution`` mapping."""

    @abstractmethod
    async def execute(
        self,
        executor: AgentExecutor,
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> AgentResult:
        """Run the agent according to this strategy.

        Args:
            executor: Agent executor invoked for each call.
            input_data: Input passed to the agent.
            context: Current machine context, forwarded to the executor.

        Returns:
            The resulting AgentResult.
        """
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@register_execution_type("default")
class DefaultExecution(ExecutionType):
    """Plain strategy: exactly one agent call, no retries or sampling."""

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "DefaultExecution":
        # This strategy has no options, so *config* is ignored.
        return cls()

    async def execute(
        self,
        executor: AgentExecutor,
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> AgentResult:
        """Call the agent once and normalize its result to an AgentResult."""
        return coerce_agent_result(
            await executor.execute(input_data, context=context)
        )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Parallel Execution Type
|
|
215
|
+
|
|
216
|
+
@register_execution_type("parallel")
class ParallelExecution(ExecutionType):
    """
    Run N samples concurrently and aggregate every successful result.

    Useful for getting multiple diverse responses to compare or aggregate.
    Failed samples are logged and dropped; if none succeed, an empty
    AgentResult is returned.

    Example YAML:
        execution:
          type: parallel
          n_samples: 5
    """

    def __init__(self, n_samples: int = 3):
        self.n_samples = n_samples

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ParallelExecution":
        return cls(n_samples=config.get("n_samples", 3))

    async def execute(
        self,
        executor: AgentExecutor,
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> AgentResult:
        """Run ``n_samples`` agent calls concurrently.

        Returns:
            An AgentResult whose ``output`` is ``{"results": [...], "count": n}``
            (one payload per successful sample), whose ``raw["results"]`` holds
            the successful AgentResults, and whose usage/cost sum the api-call
            and cost figures of those samples.
        """
        async def single_call() -> AgentResult:
            result = await executor.execute(input_data, context=context)
            return coerce_agent_result(result)

        results = await asyncio.gather(
            *(single_call() for _ in range(self.n_samples)),
            return_exceptions=True,
        )

        valid_results: List[AgentResult] = []
        for index, outcome in enumerate(results):
            # gather(return_exceptions=True) may also return BaseException
            # instances such as asyncio.CancelledError (not an Exception
            # subclass); never treat those as results.
            if isinstance(outcome, BaseException):
                logger.warning(
                    f"Parallel sample {index + 1}/{self.n_samples} failed: {outcome!r}"
                )
                continue
            valid_results.append(outcome)

        if not valid_results:
            return AgentResult()

        payloads = [result.output_payload() for result in valid_results]
        total_api_calls = sum(_extract_api_calls(result) for result in valid_results)
        total_cost = sum(_extract_cost(result) for result in valid_results)

        return AgentResult(
            output={"results": payloads, "count": len(payloads)},
            raw={"results": valid_results},
            usage={"api_calls": total_api_calls} if total_api_calls else None,
            cost=total_cost if total_cost else None,
        )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# Retry Execution Type
|
|
275
|
+
|
|
276
|
+
@register_execution_type("retry")
class RetryExecution(ExecutionType):
    """
    Retry on failure with configurable backoff delays and jitter.

    Default backoffs [2, 8, 16, 35] total 61 seconds, intended to wait
    for a fresh RPM (requests per minute) bucket.

    Example YAML:
        execution:
          type: retry
          backoffs: [2, 8, 16, 35]  # Backoff delays in seconds
          jitter: 0.1               # Random jitter factor (0.1 = ±10%)
          retry_on_empty: false     # Treat an empty payload as a failure
    """

    # Default backoffs: 2 + 8 + 16 + 35 = 61 seconds (wait for fresh RPM bucket)
    DEFAULT_BACKOFFS = [2, 8, 16, 35]

    def __init__(
        self,
        backoffs: Optional[List[float]] = None,
        jitter: float = 0.1,
        retry_on_empty: bool = False
    ):
        # Copy, so instances never share (and cannot mutate) the class-level
        # default list or a caller-owned list.
        self.backoffs = list(backoffs if backoffs is not None else self.DEFAULT_BACKOFFS)
        self.jitter = jitter
        self.retry_on_empty = retry_on_empty

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "RetryExecution":
        return cls(
            backoffs=config.get("backoffs"),
            jitter=config.get("jitter", 0.1),
            retry_on_empty=config.get("retry_on_empty", False)
        )

    def _apply_jitter(self, delay: float) -> float:
        """Return *delay* shifted by a uniform random amount within ±``jitter``·delay."""
        import random
        jitter_range = delay * self.jitter
        return delay + random.uniform(-jitter_range, jitter_range)

    async def execute(
        self,
        executor: AgentExecutor,
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> AgentResult:
        """Call the agent, retrying failures after each configured backoff.

        An attempt fails when the executor raises or, with ``retry_on_empty``,
        when the result has an empty payload. API calls and cost are summed
        across all attempts.

        Returns:
            The first successful AgentResult, carrying the accumulated
            usage/cost. If every attempt fails, an AgentResult whose ``output``
            holds ``_error`` and ``_error_type``, plus ``_error_status_code`` and
            ``_error_headers`` when these can be extracted from the last error.
        """
        last_error: Optional[BaseException] = None
        max_attempts = len(self.backoffs) + 1  # Initial attempt + retries
        total_api_calls = 0
        total_cost = 0.0

        for attempt in range(max_attempts):
            try:
                result = await executor.execute(input_data, context=context)
                agent_result = coerce_agent_result(result)
                total_api_calls += _extract_api_calls(agent_result)
                total_cost += _extract_cost(agent_result)
                payload = agent_result.output_payload()

                if not payload and self.retry_on_empty:
                    # Routed through the except clause below, i.e. retried
                    # like any other failure.
                    raise ValueError("Empty response from agent")

                return AgentResult(
                    output=agent_result.output,
                    content=agent_result.content,
                    raw=agent_result.raw,
                    usage=_merge_usage(agent_result, total_api_calls),
                    cost=total_cost if total_cost else agent_result.cost,
                    metadata=agent_result.metadata,
                )

            except Exception as e:
                last_error = e
                logger.warning(
                    f"Attempt {attempt + 1}/{max_attempts} failed: {e}"
                )

                # If we have more retries, wait with jitter
                if attempt < len(self.backoffs):
                    delay = self._apply_jitter(self.backoffs[attempt])
                    logger.info(f"Retrying in {delay:.1f}s...")
                    await asyncio.sleep(delay)

        # All retries exhausted
        logger.error(f"All {max_attempts} attempts failed. Last error: {last_error}")
        error_payload: Dict[str, Any] = {
            "_error": str(last_error) if last_error else "LLM call failed",
            "_error_type": type(last_error).__name__ if last_error else "UnknownError",
        }
        status_code = _extract_status_code(last_error)
        if status_code is not None:
            error_payload["_error_status_code"] = status_code
        headers = _extract_error_headers(last_error)
        if headers:
            error_payload["_error_headers"] = headers

        usage = {"api_calls": total_api_calls} if total_api_calls else None
        cost = total_cost if total_cost else None

        return AgentResult(output=error_payload, raw=last_error, usage=usage, cost=cost)
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
# MDAP Voting Execution Type
|
|
395
|
+
|
|
396
|
+
@dataclass
class MDAPMetrics:
    """Counters accumulated by the MDAP voting strategy across its runs."""

    # Agent calls that completed (successfully returned a result) while voting.
    total_samples: int = 0
    # Samples discarded by a red-flag check.
    total_red_flags: int = 0
    # Red-flag count keyed by reason, e.g. "format_error".
    red_flags_by_reason: Dict[str, int] = field(default_factory=dict)
    # Number of samples drawn for each finished voting run.
    samples_per_step: List[int] = field(default_factory=list)

    def record_red_flag(self, reason: str):
        """Count one discarded sample under *reason*."""
        self.total_red_flags += 1
        previous = self.red_flags_by_reason.setdefault(reason, 0)
        self.red_flags_by_reason[reason] = previous + 1
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
@register_execution_type("mdap_voting")
class MDAPVotingExecution(ExecutionType):
    """
    Multi-sample with first-to-ahead-by-k voting.

    Implements the voting algorithm from the MAKER paper. The agent is sampled
    repeatedly and red-flagged samples are discarded. The first candidate
    whose vote count leads every other candidate by at least ``k_margin``
    wins. If no candidate reaches that lead within ``max_candidates`` samples,
    the most-voted candidate is returned (majority fallback).

    Example YAML:
        execution:
          type: mdap_voting
          k_margin: 3
          max_candidates: 10
          max_response_tokens: 2048   # optional
    """

    def __init__(
        self,
        k_margin: int = 3,
        max_candidates: int = 10,
        max_response_tokens: Optional[int] = None
    ):
        self.k_margin = k_margin
        self.max_candidates = max_candidates
        self.max_response_tokens = max_response_tokens
        self.metrics = MDAPMetrics()

        # Loaded from agent metadata at the start of each execute()
        self._patterns: Dict[str, Tuple[re.Pattern, str]] = {}
        self._validation_schema: Optional[Dict] = None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MDAPVotingExecution":
        return cls(
            k_margin=config.get("k_margin", 3),
            max_candidates=config.get("max_candidates", 10),
            max_response_tokens=config.get("max_response_tokens")
        )

    def _configure_from_executor(self, executor: AgentExecutor):
        """Load overrides, parsing patterns and the validation schema from
        ``executor.metadata``.

        Recognized metadata keys:
            mdap: optional overrides for ``k_margin``, ``max_candidates`` and
                ``max_response_tokens``; falsy values are ignored.
            parsing: ``{field: {"pattern": regex, "type": "str"|"int"|"json"}}``.
            validation: JSON Schema applied to each parsed candidate.

        Non-dict values for these keys are treated as absent.

        Raises:
            ImportError: If a validation schema is configured but the
                ``jsonschema`` package is not installed. This is checked up
                front, before any agent call, so a missing dependency is not
                misreported as every sample failing validation.
        """
        metadata = getattr(executor, "metadata", {}) or {}
        if not isinstance(metadata, dict):
            metadata = {}

        # Metadata may override the execution config.
        mdap_config = metadata.get('mdap')
        if not isinstance(mdap_config, dict):
            mdap_config = {}
        if mdap_config.get('k_margin'):
            self.k_margin = mdap_config['k_margin']
        if mdap_config.get('max_candidates'):
            self.max_candidates = mdap_config['max_candidates']
        if mdap_config.get('max_response_tokens'):
            self.max_response_tokens = mdap_config['max_response_tokens']

        # Load parsing patterns
        parsing_config = metadata.get('parsing')
        if not isinstance(parsing_config, dict):
            parsing_config = {}
        self._patterns = {}
        for field_name, field_config in parsing_config.items():
            if not isinstance(field_config, dict):
                continue
            pattern = field_config.get('pattern')
            if pattern:
                self._patterns[field_name] = (
                    re.compile(pattern, re.DOTALL),
                    field_config.get('type', 'str'),
                )

        # Load validation schema
        self._validation_schema = metadata.get('validation', None)
        if self._validation_schema:
            try:
                import jsonschema  # noqa: F401  (availability check only)
            except ImportError as exc:
                raise ImportError(
                    "MDAP response validation requires jsonschema: pip install jsonschema"
                ) from exc

    def _parse_response(self, content: str) -> Optional[Dict[str, Any]]:
        """Extract every configured field from *content* using its regex.

        The first capture group of each pattern is used, or the whole match
        when the pattern has no groups. The value is then converted per the
        field's type: ``json`` -> ``json.loads``, ``int`` -> ``int``, and any
        other type keeps the string.

        Returns:
            The field dict. ``None`` if no patterns are configured, if any
            field does not match, or if any conversion fails.
        """
        if not self._patterns:
            return None

        parsed: Dict[str, Any] = {}
        for field_name, (pattern, field_type) in self._patterns.items():
            match = pattern.search(content)
            if match is None:
                return None
            # group(1) on a pattern without groups would raise IndexError.
            value = match.group(1) if pattern.groups else match.group(0)
            if field_type == 'json':
                try:
                    parsed[field_name] = json.loads(value)
                except (json.JSONDecodeError, TypeError):
                    return None
            elif field_type == 'int':
                try:
                    parsed[field_name] = int(value)
                except (TypeError, ValueError):
                    return None
            else:
                parsed[field_name] = value

        return parsed

    def _validate_parsed(self, parsed: Dict[str, Any]) -> bool:
        """Return whether *parsed* satisfies the configured JSON Schema.

        Always ``True`` when no schema is configured. Any error raised by
        ``jsonschema.validate`` counts as a failed validation. A missing
        ``jsonschema`` package is not swallowed here.
        """
        if not self._validation_schema:
            return True

        import jsonschema  # presence is checked in _configure_from_executor

        try:
            jsonschema.validate(instance=parsed, schema=self._validation_schema)
        except Exception:
            return False
        return True

    def _check_red_flags(self, content: str, parsed: Optional[Dict[str, Any]]) -> Optional[str]:
        """Return the red-flag reason for a sample, or ``None`` if it is clean.

        Reasons: ``format_error`` (nothing parsed), ``validation_failed``
        (schema mismatch), ``length_exceeded`` (estimated tokens above
        ``max_response_tokens``, when that limit is set).
        """
        if parsed is None:
            return "format_error"

        if not self._validate_parsed(parsed):
            return "validation_failed"

        # Only check response length if max_response_tokens is set.
        if self.max_response_tokens is not None:
            # Rough estimate: about 4 characters per token.
            estimated_tokens = len(content) // 4
            if estimated_tokens > self.max_response_tokens:
                return "length_exceeded"

        return None

    def _extract_candidate(self, result: AgentResult) -> Tuple[Optional[Dict[str, Any]], str]:
        """Extract the candidate payload to vote on, plus the text used for
        the length check.

        With parsing patterns configured, the candidate is the regex-parsed
        content. Otherwise it is ``result.output``, or ``{"content": ...}``
        for content-only results. The candidate is ``None`` when nothing
        usable exists.
        """
        content = result.content or ""

        if self._patterns:
            parsed = self._parse_response(content) if content else None
            return parsed, content

        if result.output is not None:
            payload = result.output
            content_for_length = content or json.dumps(payload, sort_keys=True)
            return payload, content_for_length

        if content:
            return {"content": content}, content

        return None, content

    def _finish(
        self,
        num_samples: int,
        total_api_calls: int,
        total_cost: float,
        **fields: Any,
    ) -> AgentResult:
        """Record the sample count for this run and build the final AgentResult."""
        self.metrics.samples_per_step.append(num_samples)
        return AgentResult(
            usage={"api_calls": total_api_calls} if total_api_calls else None,
            cost=total_cost if total_cost else None,
            **fields,
        )

    async def execute(
        self,
        executor: AgentExecutor,
        input_data: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None,
    ) -> AgentResult:
        """
        Multi-sample with voting - replaces a single agent call.

        Each sample is red-flag checked: parse failure, schema validation, and
        estimated length. Surviving candidates are keyed by their canonical
        JSON form and voted on. Sampling stops as soon as the leading
        candidate's votes exceed the runner-up's by at least ``k_margin``
        (with a single candidate the runner-up count is 0). Otherwise it stops
        after ``max_candidates`` attempts, and the most-voted candidate wins.

        Returns:
            AgentResult with the winning candidate as ``output`` and ``raw``
            taken from the most recent sample that voted for it. If no sample
            survived, an empty AgentResult carrying only usage/cost.
        """
        self._configure_from_executor(executor)

        votes: Counter = Counter()
        responses: Dict[str, Any] = {}
        raws: Dict[str, Any] = {}
        num_samples = 0
        total_api_calls = 0
        total_cost = 0.0

        for _ in range(self.max_candidates):
            try:
                result = await executor.execute(input_data, context=context)
                agent_result = coerce_agent_result(result)
                num_samples += 1
                self.metrics.total_samples += 1
                total_api_calls += _extract_api_calls(agent_result)
                total_cost += _extract_cost(agent_result)

                candidate, content = self._extract_candidate(agent_result)
                if candidate is None:
                    flag_reason = "format_error"
                else:
                    flag_reason = self._check_red_flags(content, candidate)

                if flag_reason:
                    self.metrics.record_red_flag(flag_reason)
                    continue

                key = candidate if isinstance(candidate, str) else json.dumps(candidate, sort_keys=True)
                votes[key] += 1
                responses[key] = candidate
                raws[key] = agent_result.raw

                # First-to-ahead-by-k: the leader must beat the runner-up by
                # k_margin. Reaching k votes while tied is not enough.
                ranked = votes.most_common(2)
                leader_key, leader_votes = ranked[0]
                runner_up_votes = ranked[1][1] if len(ranked) > 1 else 0
                if leader_votes - runner_up_votes >= self.k_margin:
                    return self._finish(
                        num_samples,
                        total_api_calls,
                        total_cost,
                        output=responses[leader_key],
                        raw=raws[leader_key],
                    )

            except Exception as e:
                logger.warning(f"Sample failed: {e}")
                continue

        # Majority fallback
        if votes:
            winner_key = votes.most_common(1)[0][0]
            return self._finish(
                num_samples,
                total_api_calls,
                total_cost,
                output=responses[winner_key],
                raw=raws[winner_key],
            )

        return self._finish(num_samples, total_api_calls, total_cost)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the metrics collected so far (live references, not copies)."""
        return {
            "total_samples": self.metrics.total_samples,
            "total_red_flags": self.metrics.total_red_flags,
            "red_flags_by_reason": self.metrics.red_flags_by_reason,
            "samples_per_step": self.metrics.samples_per_step,
        }
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
# Public API of the execution-strategy module (star-import surface).
__all__ = [
    "ExecutionType", "DefaultExecution", "ParallelExecution",
    "RetryExecution", "MDAPVotingExecution",
    "get_execution_type", "register_execution_type",
]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Expression engines for flatmachines.
|
|
3
|
+
|
|
4
|
+
Provides two modes:
|
|
5
|
+
- simple: Built-in parser for basic comparisons and boolean logic (default)
|
|
6
|
+
- cel: Full CEL support via cel-python (optional extra)
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any, Dict, Protocol, runtime_checkable
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@runtime_checkable
class ExpressionEngine(Protocol):
    """Structural interface shared by all expression engines.

    Any object with an ``evaluate`` method qualifies. Because the protocol is
    ``@runtime_checkable``, ``isinstance`` can be used to check it; that
    check only verifies the method exists, not its signature.
    """

    def evaluate(self, expression: str, variables: Dict[str, Any]) -> Any:
        """Evaluate *expression* against *variables* and return the result.

        Args:
            expression: Expression source text.
            variables: Name-to-value bindings visible to the expression,
                e.g. ``{"context": {...}, "input": {...}, "output": {...}}``.

        Returns:
            Whatever value the expression evaluates to.
        """
        ...
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_expression_engine(mode: str = "simple") -> "ExpressionEngine":
    """
    Get an expression engine by mode.

    Args:
        mode: "simple" (default) or "cel"

    Returns:
        ExpressionEngine instance

    Raises:
        ImportError: If CEL mode is requested but its optional dependency
            cannot be imported. The original import failure is chained as
            ``__cause__``, so the real reason stays visible in tracebacks.
        ValueError: If *mode* is not a known engine name.
    """
    if mode == "simple":
        from .simple import SimpleExpressionEngine
        return SimpleExpressionEngine()

    if mode == "cel":
        try:
            from .cel import CELExpressionEngine
            return CELExpressionEngine()
        except ImportError as exc:
            raise ImportError(
                "CEL expression engine requires: pip install flatmachines[cel]"
            ) from exc

    raise ValueError(f"Unknown expression engine: {mode}")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Names exported by ``from flatmachines.expressions import *``.
__all__ = ["ExpressionEngine", "get_expression_engine"]
|