flatagents-0.4.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flatagents/__init__.py +136 -0
- flatagents/actions.py +239 -0
- flatagents/assets/__init__.py +0 -0
- flatagents/assets/flatagent.d.ts +189 -0
- flatagents/assets/flatagent.schema.json +210 -0
- flatagents/assets/flatagent.slim.d.ts +52 -0
- flatagents/assets/flatmachine.d.ts +363 -0
- flatagents/assets/flatmachine.schema.json +515 -0
- flatagents/assets/flatmachine.slim.d.ts +94 -0
- flatagents/backends.py +222 -0
- flatagents/baseagent.py +814 -0
- flatagents/execution.py +462 -0
- flatagents/expressions/__init__.py +60 -0
- flatagents/expressions/cel.py +101 -0
- flatagents/expressions/simple.py +166 -0
- flatagents/flatagent.py +735 -0
- flatagents/flatmachine.py +1176 -0
- flatagents/gcp/__init__.py +25 -0
- flatagents/gcp/firestore.py +227 -0
- flatagents/hooks.py +380 -0
- flatagents/locking.py +69 -0
- flatagents/monitoring.py +373 -0
- flatagents/persistence.py +200 -0
- flatagents/utils.py +46 -0
- flatagents/validation.py +141 -0
- flatagents-0.4.1.dist-info/METADATA +310 -0
- flatagents-0.4.1.dist-info/RECORD +28 -0
- flatagents-0.4.1.dist-info/WHEEL +4 -0
flatagents/monitoring.py
ADDED
@@ -0,0 +1,373 @@

```python
"""
Monitoring and observability utilities for FlatAgents.

Provides standardized logging configuration and OpenTelemetry-based metrics.
"""

import logging
import os
import sys
import time
from contextlib import contextmanager
from typing import Any, Dict, Optional

# ─────────────────────────────────────────────────────────────────────────────
# Logging Configuration
# ─────────────────────────────────────────────────────────────────────────────

# Global logger registry
_loggers: Dict[str, logging.Logger] = {}
_logging_configured = False


def setup_logging(
    level: Optional[str] = None,
    format: Optional[str] = None,
    force: bool = False
) -> None:
    """
    Configure SDK-wide logging with sensible defaults.

    Args:
        level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
            Defaults to FLATAGENTS_LOG_LEVEL env var or INFO.
        format: Log format style. Options:
            - 'standard': Human-readable with timestamps
            - 'json': Structured JSON logging
            - 'simple': Just level and message
            - Custom format string
            Defaults to FLATAGENTS_LOG_FORMAT env var or 'standard'.
        force: If True, reconfigure even if already configured.

    Example:
        >>> from flatagents import setup_logging
        >>> setup_logging(level='DEBUG')
        >>> # Or via environment:
        >>> # export FLATAGENTS_LOG_LEVEL=DEBUG
        >>> # export FLATAGENTS_LOG_FORMAT=json
    """
    global _logging_configured

    if _logging_configured and not force:
        return

    # Determine log level
    if level is None:
        level = os.getenv('FLATAGENTS_LOG_LEVEL', 'INFO').upper()

    log_level = getattr(logging, level.upper(), logging.INFO)

    # Determine format
    if format is None:
        format = os.getenv('FLATAGENTS_LOG_FORMAT', 'standard')

    if format == 'json':
        # Structured JSON logging - note: message content should be escaped by caller for true JSON safety
        log_format = '{"time":"%(asctime)s","name":"%(name)s","level":"%(levelname)s","message":%(message)s}'
    elif format == 'simple':
        log_format = '%(levelname)s - %(message)s'
    elif format == 'standard':
        log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    else:
        # Custom format string
        log_format = format

    # Configure root logger for the SDK
    logging.basicConfig(
        level=log_level,
        format=log_format,
        datefmt='%Y-%m-%d %H:%M:%S',
        stream=sys.stdout,
        force=force
    )

    _logging_configured = True


def get_logger(name: str) -> logging.Logger:
    """
    Get a properly configured logger for a module.

    Args:
        name: Logger name (typically __name__ from the calling module)

    Returns:
        Configured logger instance

    Example:
        >>> from flatagents import get_logger
        >>> logger = get_logger(__name__)
        >>> logger.info("Agent started")
    """
    if name not in _loggers:
        # Ensure logging is configured
        if not _logging_configured:
            setup_logging()

        logger = logging.getLogger(name)
        _loggers[name] = logger

    return _loggers[name]


# ─────────────────────────────────────────────────────────────────────────────
# Metrics with OpenTelemetry
# ─────────────────────────────────────────────────────────────────────────────

# Lazy imports for OpenTelemetry (optional dependency)
_otel_available = False
_meter = None
_metrics_enabled = False
_cached_histograms: Dict[str, Any] = {}  # Cache for histogram instruments

try:
    from opentelemetry import metrics
    from opentelemetry.sdk.metrics import MeterProvider
    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
    from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
    from opentelemetry.sdk.resources import Resource, SERVICE_NAME
    _otel_available = True
except ImportError:
    _otel_available = False


def _init_metrics() -> None:
    """Initialize OpenTelemetry metrics if enabled and available."""
    global _meter, _metrics_enabled

    # Check if metrics should be enabled
    enabled = os.getenv('FLATAGENTS_METRICS_ENABLED', 'false').lower() in ('true', '1', 'yes')

    if not enabled:
        _metrics_enabled = False
        return

    if not _otel_available:
        logger = get_logger(__name__)
        logger.warning(
            "Metrics enabled but OpenTelemetry not available. "
            "Install with: pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp"
        )
        _metrics_enabled = False
        return

    try:
        # Get service name from environment or use default
        service_name = os.getenv('OTEL_SERVICE_NAME', 'flatagents')

        # Create resource with service name
        resource = Resource(attributes={
            SERVICE_NAME: service_name
        })

        # Check which exporter to use
        exporter_type = os.getenv('OTEL_METRICS_EXPORTER', 'otlp').lower()

        if exporter_type == 'console':
            # Use console exporter for testing/debugging
            try:
                from opentelemetry.sdk.metrics.export import ConsoleMetricExporter
                exporter = ConsoleMetricExporter()
            except ImportError:
                logger = get_logger(__name__)
                logger.warning("Console exporter not available, falling back to OTLP")
                exporter = OTLPMetricExporter(
                    endpoint=os.getenv('OTEL_EXPORTER_OTLP_ENDPOINT'),
                )
        else:
            # Configure OTLP exporter (supports Datadog, Honeycomb, etc.)
            exporter = OTLPMetricExporter(
                endpoint=os.getenv('OTEL_EXPORTER_OTLP_ENDPOINT'),
                # Headers can be set via OTEL_EXPORTER_OTLP_HEADERS env var
            )

        # Create meter provider with periodic export
        reader = PeriodicExportingMetricReader(
            exporter=exporter,
            export_interval_millis=int(os.getenv('OTEL_METRIC_EXPORT_INTERVAL', '5000' if exporter_type == 'console' else '60000'))
        )

        provider = MeterProvider(
            resource=resource,
            metric_readers=[reader]
        )

        metrics.set_meter_provider(provider)
        _meter = metrics.get_meter(__name__)
        _metrics_enabled = True

        logger = get_logger(__name__)
        logger.info(f"OpenTelemetry metrics enabled for service: {service_name}")

    except Exception as e:
        logger = get_logger(__name__)
        logger.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
        _metrics_enabled = False


def get_meter():
    """
    Get the OpenTelemetry meter for creating custom metrics.

    Returns:
        OpenTelemetry Meter instance or None if metrics disabled

    Example:
        >>> from flatagents import get_meter
        >>> meter = get_meter()
        >>> if meter:
        ...     counter = meter.create_counter("my_custom_metric")
        ...     counter.add(1, {"attribute": "value"})
    """
    global _meter

    if _meter is None and not _metrics_enabled:
        _init_metrics()

    return _meter


class AgentMonitor:
    """
    Context manager for tracking agent execution metrics.

    Automatically tracks:
    - Execution duration
    - Success/failure status
    - Custom metrics via the metrics dict

    Example:
        >>> from flatagents import AgentMonitor
        >>> with AgentMonitor("my-agent") as monitor:
        ...     # Do agent work
        ...     monitor.metrics["tokens"] = 1500
        ...     monitor.metrics["cost"] = 0.03
        >>> # Metrics automatically emitted on exit
    """

    def __init__(self, agent_id: str, extra_attributes: Optional[Dict[str, Any]] = None):
        """
        Initialize the monitor.

        Args:
            agent_id: Identifier for this agent/operation
            extra_attributes: Additional attributes to attach to all metrics
        """
        self.agent_id = agent_id
        self.start_time = None
        self.metrics: Dict[str, Any] = {}
        self.extra_attributes = extra_attributes or {}
        self.logger = get_logger(f"flatagents.monitor.{agent_id}")

        # Get or create metric instruments
        self._meter = get_meter()
        if self._meter:
            self._duration_histogram = self._meter.create_histogram(
                "flatagents.agent.duration",
                unit="ms",
                description="Agent execution duration"
            )
            self._token_counter = self._meter.create_counter(
                "flatagents.agent.tokens",
                description="Tokens used by agent"
            )
            self._cost_counter = self._meter.create_counter(
                "flatagents.agent.cost",
                description="Estimated cost of agent execution"
            )
            self._status_counter = self._meter.create_counter(
                "flatagents.agent.executions",
                description="Agent execution count by status"
            )

    def __enter__(self):
        """Start monitoring."""
        self.start_time = time.time()
        self.logger.debug(f"Agent {self.agent_id} started")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop monitoring and emit metrics."""
        duration_ms = (time.time() - self.start_time) * 1000
        status = "success" if exc_type is None else "error"

        # Build attributes
        attributes = {
            "agent_id": self.agent_id,
            "status": status,
            **self.extra_attributes
        }

        if exc_type is not None:
            attributes["error_type"] = exc_type.__name__

        # Log completion
        self.logger.info(
            f"Agent {self.agent_id} completed in {duration_ms:.2f}ms - {status}"
        )

        # Emit metrics if enabled
        if self._meter:
            self._duration_histogram.record(duration_ms, attributes)
            self._status_counter.add(1, attributes)

            if "tokens" in self.metrics:
                self._token_counter.add(self.metrics["tokens"], attributes)

            if "cost" in self.metrics:
                self._cost_counter.add(self.metrics["cost"], attributes)

        # Don't suppress exceptions
        return False


# ─────────────────────────────────────────────────────────────────────────────
# Convenience context manager for temporary metrics
# ─────────────────────────────────────────────────────────────────────────────

@contextmanager
def track_operation(operation_name: str, **attributes):
    """
    Track duration of an operation.

    Args:
        operation_name: Name of the operation
        **attributes: Additional attributes to attach

    Example:
        >>> from flatagents.monitoring import track_operation
        >>> with track_operation("llm_call", model="gpt-4"):
        ...     response = await llm.call(messages)
    """
    meter = get_meter()
    start_time = time.time()

    try:
        yield
        status = "success"
    except Exception as e:
        status = "error"
        attributes["error_type"] = type(e).__name__
        raise
    finally:
        duration_ms = (time.time() - start_time) * 1000

        if meter:
            # Cache histogram to avoid recreating on each call
            cache_key = f"flatagents.{operation_name}.duration"
            if cache_key not in _cached_histograms:
                _cached_histograms[cache_key] = meter.create_histogram(
                    cache_key,
                    unit="ms",
                    description=f"Duration of {operation_name}"
                )
            _cached_histograms[cache_key].record(duration_ms, {**attributes, "status": status})


__all__ = [
    "setup_logging",
    "get_logger",
    "get_meter",
    "AgentMonitor",
    "track_operation",
]
```
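Taken together, a typical wiring enables metrics once at startup and wraps agent work in the two context managers. A minimal sketch under stated assumptions: the environment variables are read lazily when the first `get_meter()` call triggers `_init_metrics()`, and `run_agent_step` is a hypothetical stand-in for real agent work:

```python
import os

# Metrics are opt-in; these must be set before the first get_meter() call.
# The console exporter keeps the demo self-contained (no OTLP endpoint needed).
os.environ.setdefault("FLATAGENTS_METRICS_ENABLED", "true")
os.environ.setdefault("OTEL_METRICS_EXPORTER", "console")

from flatagents.monitoring import AgentMonitor, setup_logging, track_operation

setup_logging(level="DEBUG")


def run_agent_step() -> dict:
    # Hypothetical stand-in for real agent work.
    return {"tokens": 1500, "cost": 0.03}


with AgentMonitor("my-agent", extra_attributes={"env": "dev"}) as monitor:
    with track_operation("llm_call", model="gpt-4"):
        result = run_agent_step()
    monitor.metrics["tokens"] = result["tokens"]
    monitor.metrics["cost"] = result["cost"]
# On exit: the duration histogram, execution counter, and token/cost
# counters are recorded with agent_id/status/env attributes.
```

Note that `AgentMonitor.__exit__` returns `False`, so exceptions still propagate; they are simply recorded first with `status="error"` and an `error_type` attribute.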
flatagents/persistence.py
ADDED

@@ -0,0 +1,200 @@

```python
import json
import fcntl
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, List
from dataclasses import dataclass, asdict, field
from pathlib import Path
from datetime import datetime, timezone
import aiofiles

logger = logging.getLogger(__name__)

@dataclass
class MachineSnapshot:
    """Wire format for machine checkpoints."""
    execution_id: str
    machine_name: str
    spec_version: str
    current_state: str
    context: Dict[str, Any]
    step: int
    created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    event: Optional[str] = None  # The event that triggered this checkpoint (machine_start, etc)
    output: Optional[Dict[str, Any]] = None  # Output if captured at state_exit/machine_end
    total_api_calls: Optional[int] = None  # Cumulative API calls
    total_cost: Optional[float] = None  # Cumulative cost
    # Lineage (v0.4.0)
    parent_execution_id: Optional[str] = None  # ID of launcher machine if this was launched
    # Outbox pattern (v0.4.0)
    pending_launches: Optional[List[Dict[str, Any]]] = None  # LaunchIntent dicts awaiting completion

class PersistenceBackend(ABC):
    """Abstract storage backend for checkpoints."""

    @abstractmethod
    async def save(self, key: str, value: bytes) -> None:
        pass

    @abstractmethod
    async def load(self, key: str) -> Optional[bytes]:
        pass

    @abstractmethod
    async def delete(self, key: str) -> None:
        pass

class LocalFileBackend(PersistenceBackend):
    """File-based persistence backend."""

    def __init__(self, base_dir: str = ".checkpoints"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    async def save(self, key: str, value: bytes) -> None:
        path = self.base_dir / key
        path.parent.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(path, 'wb') as f:
            await f.write(value)

    async def load(self, key: str) -> Optional[bytes]:
        path = self.base_dir / key
        if not path.exists():
            return None
        async with aiofiles.open(path, 'rb') as f:
            return await f.read()

    async def delete(self, key: str) -> None:
        path = self.base_dir / key
        if path.exists():
            path.unlink()

class MemoryBackend(PersistenceBackend):
    """In-memory backend for ephemeral executions."""

    def __init__(self):
        self._store: Dict[str, bytes] = {}

    async def save(self, key: str, value: bytes) -> None:
        self._store[key] = value

    async def load(self, key: str) -> Optional[bytes]:
        return self._store.get(key)

    async def delete(self, key: str) -> None:
        self._store.pop(key, None)

class CheckpointManager:
    """Manages saving and loading machine snapshots."""

    def __init__(self, backend: PersistenceBackend, execution_id: str):
        self.backend = backend
        self.execution_id = execution_id

    def _snapshot_key(self, event: str, step: int) -> str:
        """Generate key for specific snapshot."""
        return f"{self.execution_id}/step_{step:06d}_{event}.json"

    def _latest_pointer_key(self) -> str:
        """Key that points to the latest snapshot."""
        return f"{self.execution_id}/latest"

    def _safe_serialize_value(self, value: Any, path: str, non_serializable: List[str]) -> Any:
        """Recursively serialize a value, converting non-JSON types to strings."""
        if isinstance(value, dict):
            result = {}
            for k, v in value.items():
                try:
                    json.dumps({k: v})
                    result[k] = v
                except (TypeError, OverflowError):
                    result[k] = self._safe_serialize_value(v, f"{path}.{k}", non_serializable)
            return result
        elif isinstance(value, list):
            result = []
            for i, item in enumerate(value):
                try:
                    json.dumps(item)
                    result.append(item)
                except (TypeError, OverflowError):
                    result.append(self._safe_serialize_value(item, f"{path}[{i}]", non_serializable))
            return result
        else:
            try:
                json.dumps(value)
                return value
            except (TypeError, OverflowError):
                original_type = type(value).__name__
                non_serializable.append(f"{path} ({original_type})")
                return str(value)

    def _safe_serialize(self, data: Dict[str, Any]) -> str:
        """Safely serialize data to JSON, handling non-serializable objects."""
        try:
            return json.dumps(data)
        except (TypeError, OverflowError):
            # Identify and warn about specific non-serializable fields
            safe_data = {}
            non_serializable_fields: List[str] = []

            for k, v in data.items():
                if isinstance(v, dict):
                    # Recursively check nested dicts
                    try:
                        json.dumps(v)
                        safe_data[k] = v
                    except (TypeError, OverflowError):
                        safe_data[k] = self._safe_serialize_value(v, k, non_serializable_fields)
                elif isinstance(v, list):
                    # Recursively check lists
                    try:
                        json.dumps(v)
                        safe_data[k] = v
                    except (TypeError, OverflowError):
                        safe_data[k] = self._safe_serialize_value(v, k, non_serializable_fields)
                else:
                    try:
                        json.dumps({k: v})
                        safe_data[k] = v
                    except (TypeError, OverflowError):
                        original_type = type(v).__name__
                        safe_data[k] = str(v)
                        non_serializable_fields.append(f"{k} ({original_type})")

            if non_serializable_fields:
                logger.warning(
                    f"Context fields not JSON serializable, converted to strings: "
                    f"{', '.join(non_serializable_fields)}. "
                    f"These values will lose type information on restore."
                )

            return json.dumps(safe_data)

    async def save_checkpoint(self, snapshot: MachineSnapshot) -> None:
        """Save a snapshot and update latest pointer."""
        data = asdict(snapshot)
        json_bytes = self._safe_serialize(data).encode('utf-8')

        # Save the immutable snapshot
        key = self._snapshot_key(snapshot.event or "unknown", snapshot.step)
        await self.backend.save(key, json_bytes)

        # Update pointer to this key
        await self.backend.save(self._latest_pointer_key(), key.encode('utf-8'))

    async def load_latest(self) -> Optional[MachineSnapshot]:
        """Load the latest snapshot."""
        # Get pointer
        ptr_bytes = await self.backend.load(self._latest_pointer_key())
        if not ptr_bytes:
            return None

        # Get snapshot
        key = ptr_bytes.decode('utf-8')
        data_bytes = await self.backend.load(key)
        if not data_bytes:
            return None

        data = json.loads(data_bytes.decode('utf-8'))
        return MachineSnapshot(**data)
```
flatagents/utils.py
ADDED
@@ -0,0 +1,46 @@

````python
"""Utility functions for flatagents."""

import re


def strip_markdown_json(content: str) -> str:
    """
    Extract JSON from potentially wrapped response content.

    LLMs sometimes wrap JSON responses in markdown code blocks like:
    ```json
    {"key": "value"}
    ```

    Or include explanatory text before/after the JSON:
    "Here is the result:
    ```json
    {"key": "value"}
    ```"

    This function extracts the JSON so json.loads() can parse it.

    Args:
        content: Raw string that may contain markdown-wrapped JSON

    Returns:
        Extracted JSON string
    """
    if not content:
        return content

    text = content.strip()

    # First, try to find JSON in a markdown code fence (anywhere in content)
    fence_pattern = r'```(?:json|JSON)?\s*\n?([\s\S]*?)\n?```'
    match = re.search(fence_pattern, text)
    if match:
        return match.group(1).strip()

    # If no fence, try to find a raw JSON object or array
    json_pattern = r'(\{[\s\S]*\}|\[[\s\S]*\])'
    match = re.search(json_pattern, text)
    if match:
        return match.group(1)

    return text
````