agent-runtime-core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime/__init__.py +110 -0
- agent_runtime/config.py +172 -0
- agent_runtime/events/__init__.py +55 -0
- agent_runtime/events/base.py +86 -0
- agent_runtime/events/memory.py +89 -0
- agent_runtime/events/redis.py +185 -0
- agent_runtime/events/sqlite.py +168 -0
- agent_runtime/interfaces.py +390 -0
- agent_runtime/llm/__init__.py +83 -0
- agent_runtime/llm/anthropic.py +237 -0
- agent_runtime/llm/litellm_client.py +175 -0
- agent_runtime/llm/openai.py +220 -0
- agent_runtime/queue/__init__.py +55 -0
- agent_runtime/queue/base.py +167 -0
- agent_runtime/queue/memory.py +184 -0
- agent_runtime/queue/redis.py +453 -0
- agent_runtime/queue/sqlite.py +420 -0
- agent_runtime/registry.py +74 -0
- agent_runtime/runner.py +403 -0
- agent_runtime/state/__init__.py +53 -0
- agent_runtime/state/base.py +69 -0
- agent_runtime/state/memory.py +51 -0
- agent_runtime/state/redis.py +109 -0
- agent_runtime/state/sqlite.py +158 -0
- agent_runtime/tracing/__init__.py +47 -0
- agent_runtime/tracing/langfuse.py +119 -0
- agent_runtime/tracing/noop.py +34 -0
- agent_runtime_core-0.1.0.dist-info/METADATA +75 -0
- agent_runtime_core-0.1.0.dist-info/RECORD +31 -0
- agent_runtime_core-0.1.0.dist-info/WHEEL +4 -0
- agent_runtime_core-0.1.0.dist-info/licenses/LICENSE +21 -0
agent_runtime/runner.py
ADDED
@@ -0,0 +1,403 @@

"""
Agent runner - executes agent runs with full lifecycle management.

The runner handles:
- Claiming runs from the queue
- Executing agent runtimes
- Emitting events
- Managing state and checkpoints
- Handling errors and retries
- Cancellation
"""

import asyncio
import traceback
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID

from agent_runtime.config import get_config
from agent_runtime.events.base import EventBus
from agent_runtime.interfaces import (
    AgentRuntime,
    EventType,
    ErrorInfo,
    Message,
    RunResult,
    ToolRegistry,
)
from agent_runtime.queue.base import RunQueue, QueuedRun
from agent_runtime.registry import get_runtime
from agent_runtime.state.base import StateStore


@dataclass
class RunnerConfig:
    """Configuration for the agent runner."""

    worker_id: str = "worker-1"
    run_timeout_seconds: int = 300
    heartbeat_interval_seconds: int = 30
    lease_ttl_seconds: int = 60
    max_retries: int = 3
    retry_backoff_base: int = 2
    retry_backoff_max: int = 300


class RunContextImpl:
    """
    Implementation of RunContext provided to agent runtimes.

    This is what agent frameworks use to interact with the runtime.
    """

    def __init__(
        self,
        run_id: UUID,
        conversation_id: Optional[UUID],
        input_messages: list[Message],
        params: dict,
        metadata: dict,
        tool_registry: ToolRegistry,
        event_bus: EventBus,
        state_store: StateStore,
        queue: RunQueue,
    ):
        self._run_id = run_id
        self._conversation_id = conversation_id
        self._input_messages = input_messages
        self._params = params
        self._metadata = metadata
        self._tool_registry = tool_registry
        self._event_bus = event_bus
        self._state_store = state_store
        self._queue = queue

    @property
    def run_id(self) -> UUID:
        return self._run_id

    @property
    def conversation_id(self) -> Optional[UUID]:
        return self._conversation_id

    @property
    def input_messages(self) -> list[Message]:
        return self._input_messages

    @property
    def params(self) -> dict:
        return self._params

    @property
    def metadata(self) -> dict:
        return self._metadata

    @property
    def tool_registry(self) -> ToolRegistry:
        return self._tool_registry

    async def emit(self, event_type: EventType | str, payload: dict) -> None:
        """Emit an event to the event bus."""
        event_type_str = event_type.value if isinstance(event_type, EventType) else event_type
        await self._event_bus.publish(self._run_id, event_type_str, payload)

    async def checkpoint(self, state: dict) -> None:
        """Save a state checkpoint."""
        await self._state_store.save_checkpoint(self._run_id, state)
        await self.emit(EventType.STATE_CHECKPOINT, {"state": state})

    async def get_state(self) -> Optional[dict]:
        """Get the last checkpointed state."""
        return await self._state_store.get_checkpoint(self._run_id)

    def cancelled(self) -> bool:
        """Check if cancellation has been requested."""
        # This is synchronous for easy checking in loops
        # We use asyncio to check the queue
        try:
            loop = asyncio.get_event_loop()
            return loop.run_until_complete(self._queue.is_cancelled(self._run_id))
        except RuntimeError:
            # No event loop running, can't check
            return False


class AgentRunner:
    """
    Runs agent executions with full lifecycle management.

    The runner:
    1. Claims runs from the queue
    2. Looks up the appropriate runtime
    3. Executes the runtime with a context
    4. Handles errors, retries, and cancellation
    5. Emits events throughout
    """

    def __init__(
        self,
        queue: RunQueue,
        event_bus: EventBus,
        state_store: StateStore,
        config: Optional[RunnerConfig] = None,
    ):
        self.queue = queue
        self.event_bus = event_bus
        self.state_store = state_store
        self.config = config or RunnerConfig()
        self._running = False
        self._current_run: Optional[UUID] = None

    async def run_once(self) -> bool:
        """
        Claim and execute a single run.

        Returns:
            True if a run was executed, False if queue was empty
        """
        # Claim a run
        queued_run = await self.queue.claim(
            worker_id=self.config.worker_id,
            lease_seconds=self.config.lease_ttl_seconds,
        )

        if queued_run is None:
            return False

        await self._execute_run(queued_run)
        return True

    async def run_loop(self, poll_interval: float = 1.0) -> None:
        """
        Run continuously, processing runs from the queue.

        Args:
            poll_interval: Seconds to wait between queue polls
        """
        self._running = True

        while self._running:
            try:
                executed = await self.run_once()
                if not executed:
                    await asyncio.sleep(poll_interval)
            except Exception as e:
                # Log error but keep running
                print(f"Error in run loop: {e}")
                await asyncio.sleep(poll_interval)

    def stop(self) -> None:
        """Stop the run loop."""
        self._running = False

    async def _execute_run(self, queued_run: QueuedRun) -> None:
        """Execute a single run."""
        run_id = queued_run.run_id
        self._current_run = run_id

        try:
            # Look up the runtime
            runtime = get_runtime(queued_run.agent_key)
            if runtime is None:
                raise ValueError(f"Unknown agent: {queued_run.agent_key}")

            # Update status
            await self.state_store.update_run_status(run_id, "running")

            # Build context
            ctx = self._build_context(queued_run)

            # Emit started event
            await ctx.emit(EventType.RUN_STARTED, {
                "agent_key": queued_run.agent_key,
                "attempt": queued_run.attempt,
            })

            # Start heartbeat task
            heartbeat_task = asyncio.create_task(
                self._heartbeat_loop(run_id)
            )

            try:
                # Execute with timeout
                result = await asyncio.wait_for(
                    runtime.run(ctx),
                    timeout=self.config.run_timeout_seconds,
                )

                # Success!
                await self._handle_success(queued_run, ctx, result)

            except asyncio.TimeoutError:
                await self._handle_timeout(queued_run, ctx, runtime)
            except asyncio.CancelledError:
                await self._handle_cancellation(queued_run, ctx, runtime)
            except Exception as e:
                await self._handle_error(queued_run, ctx, runtime, e)
            finally:
                heartbeat_task.cancel()
                try:
                    await heartbeat_task
                except asyncio.CancelledError:
                    pass

        finally:
            self._current_run = None

    def _build_context(self, queued_run: QueuedRun) -> RunContextImpl:
        """Build a RunContext for a queued run."""
        input_data = queued_run.input

        return RunContextImpl(
            run_id=queued_run.run_id,
            conversation_id=input_data.get("conversation_id"),
            input_messages=input_data.get("messages", []),
            params=input_data.get("params", {}),
            metadata=queued_run.metadata,
            tool_registry=ToolRegistry(),  # TODO: Load from config
            event_bus=self.event_bus,
            state_store=self.state_store,
            queue=self.queue,
        )

    async def _heartbeat_loop(self, run_id: UUID) -> None:
        """Send periodic heartbeats to extend the lease."""
        while True:
            await asyncio.sleep(self.config.heartbeat_interval_seconds)

            # Extend lease
            extended = await self.queue.extend_lease(
                run_id=run_id,
                worker_id=self.config.worker_id,
                lease_seconds=self.config.lease_ttl_seconds,
            )

            if not extended:
                # Lost the lease
                break

            # Emit heartbeat event
            await self.event_bus.publish(
                run_id,
                EventType.RUN_HEARTBEAT.value,
                {"timestamp": datetime.now(timezone.utc).isoformat()},
            )

    async def _handle_success(
        self,
        queued_run: QueuedRun,
        ctx: RunContextImpl,
        result: RunResult,
    ) -> None:
        """Handle successful run completion."""
        await self.state_store.update_run_status(queued_run.run_id, "succeeded")

        await ctx.emit(EventType.RUN_SUCCEEDED, {
            "final_output": result.final_output,
            "usage": result.usage,
        })

        await self.queue.release(
            run_id=queued_run.run_id,
            worker_id=self.config.worker_id,
            success=True,
            output=result.final_output,
        )

    async def _handle_timeout(
        self,
        queued_run: QueuedRun,
        ctx: RunContextImpl,
        runtime: AgentRuntime,
    ) -> None:
        """Handle run timeout."""
        await self.state_store.update_run_status(queued_run.run_id, "timed_out")

        await ctx.emit(EventType.RUN_TIMED_OUT, {
            "timeout_seconds": self.config.run_timeout_seconds,
        })

        await self.queue.release(
            run_id=queued_run.run_id,
            worker_id=self.config.worker_id,
            success=False,
            error={"type": "TimeoutError", "message": "Run timed out"},
        )

    async def _handle_cancellation(
        self,
        queued_run: QueuedRun,
        ctx: RunContextImpl,
        runtime: AgentRuntime,
    ) -> None:
        """Handle run cancellation."""
        await runtime.cancel(ctx)
        await self.state_store.update_run_status(queued_run.run_id, "cancelled")

        await ctx.emit(EventType.RUN_CANCELLED, {})

        await self.queue.release(
            run_id=queued_run.run_id,
            worker_id=self.config.worker_id,
            success=False,
            error={"type": "CancelledError", "message": "Run was cancelled"},
        )

    async def _handle_error(
        self,
        queued_run: QueuedRun,
        ctx: RunContextImpl,
        runtime: AgentRuntime,
        error: Exception,
    ) -> None:
        """Handle run error."""
        # Get error info from runtime
        error_info = await runtime.on_error(ctx, error)
        if error_info is None:
            error_info = ErrorInfo(
                type=type(error).__name__,
                message=str(error),
                stack=traceback.format_exc(),
                retriable=True,
            )

        error_dict = {
            "type": error_info.type,
            "message": error_info.message,
            "stack": error_info.stack,
            "retriable": error_info.retriable,
            "details": error_info.details,
        }

        # Check if we should retry
        if error_info.retriable and queued_run.attempt < self.config.max_retries:
            # Calculate backoff
            delay = min(
                self.config.retry_backoff_base ** queued_run.attempt,
                self.config.retry_backoff_max,
            )

            requeued = await self.queue.requeue_for_retry(
                run_id=queued_run.run_id,
                worker_id=self.config.worker_id,
                error=error_dict,
                delay_seconds=delay,
            )

            if requeued:
                await self.state_store.update_run_status(queued_run.run_id, "pending")
                return

        # No retry - mark as failed
        await self.state_store.update_run_status(queued_run.run_id, "failed")

        await ctx.emit(EventType.RUN_FAILED, error_dict)

        await self.queue.release(
            run_id=queued_run.run_id,
            worker_id=self.config.worker_id,
            success=False,
            error=error_dict,
        )
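Not part of the package: a minimal sketch of wiring AgentRunner to in-memory backends and draining the queue once. The class names InMemoryRunQueue and InMemoryEventBus are assumptions inferred from queue/memory.py and events/memory.py in the file list above; those modules are not shown in this section.

import asyncio

from agent_runtime.runner import AgentRunner, RunnerConfig
from agent_runtime.state import InMemoryStateStore
# Assumed exports -- queue/memory.py and events/memory.py are not shown in this diff section.
from agent_runtime.queue.memory import InMemoryRunQueue
from agent_runtime.events.memory import InMemoryEventBus


async def main() -> None:
    runner = AgentRunner(
        queue=InMemoryRunQueue(),
        event_bus=InMemoryEventBus(),
        state_store=InMemoryStateStore(),
        config=RunnerConfig(worker_id="worker-local", run_timeout_seconds=60),
    )
    # run_once() returns False immediately when nothing is queued.
    executed = await runner.run_once()
    print(f"executed a run: {executed}")


asyncio.run(main())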
agent_runtime/state/__init__.py
ADDED
@@ -0,0 +1,53 @@

"""
State store implementations for agent checkpoints and run state.

Provides:
- StateStore: Abstract interface
- InMemoryStateStore: For testing and simple use cases
- RedisStateStore: For production with Redis
- SQLiteStateStore: For persistent local storage
"""

from agent_runtime.state.base import StateStore
from agent_runtime.state.memory import InMemoryStateStore

__all__ = [
    "StateStore",
    "InMemoryStateStore",
    "get_state_store",
]


def get_state_store(backend: str = None, **kwargs) -> StateStore:
    """
    Factory function to get a state store.

    Args:
        backend: "memory", "redis", or "sqlite"
        **kwargs: Backend-specific configuration

    Returns:
        StateStore instance
    """
    from agent_runtime.config import get_config

    config = get_config()
    backend = backend or config.state_store_backend

    if backend == "memory":
        return InMemoryStateStore()

    elif backend == "redis":
        from agent_runtime.state.redis import RedisStateStore
        url = kwargs.get("url") or config.redis_url
        if not url:
            raise ValueError("redis_url is required for redis state store backend")
        return RedisStateStore(url=url)

    elif backend == "sqlite":
        from agent_runtime.state.sqlite import SQLiteStateStore
        path = kwargs.get("path") or config.sqlite_path or "agent_runtime.db"
        return SQLiteStateStore(path=path)

    else:
        raise ValueError(f"Unknown state store backend: {backend}")
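A short sketch of the factory above in use, assuming get_config() loads with defaults and the in-memory backend needs no further environment:

from agent_runtime.state import InMemoryStateStore, get_state_store

# The in-memory backend needs no backend-specific configuration.
store = get_state_store("memory")
assert isinstance(store, InMemoryStateStore)

# Unrecognized backends are rejected up front.
try:
    get_state_store("postgres")
except ValueError as exc:
    print(exc)  # Unknown state store backend: postgres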
agent_runtime/state/base.py
ADDED
@@ -0,0 +1,69 @@

"""
Abstract base class for state store implementations.
"""

from abc import ABC, abstractmethod
from typing import Optional
from uuid import UUID


class StateStore(ABC):
    """
    Abstract interface for state storage.

    State stores handle:
    - Run state (status, metadata)
    - Checkpoints for recovery
    """

    @abstractmethod
    async def save_checkpoint(self, run_id: UUID, state: dict) -> None:
        """
        Save a checkpoint for a run.

        Args:
            run_id: Run identifier
            state: State to checkpoint
        """
        ...

    @abstractmethod
    async def get_checkpoint(self, run_id: UUID) -> Optional[dict]:
        """
        Get the latest checkpoint for a run.

        Args:
            run_id: Run identifier

        Returns:
            Latest checkpoint state, or None if no checkpoint exists
        """
        ...

    @abstractmethod
    async def update_run_status(self, run_id: UUID, status: str) -> None:
        """
        Update the status of a run.

        Args:
            run_id: Run identifier
            status: New status
        """
        ...

    @abstractmethod
    async def get_run_status(self, run_id: UUID) -> Optional[str]:
        """
        Get the status of a run.

        Args:
            run_id: Run identifier

        Returns:
            Run status, or None if run not found
        """
        ...

    async def close(self) -> None:
        """Close any connections. Override if needed."""
        pass
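base.py leaves the status vocabulary open, so for reference these are the status strings that AgentRunner (runner.py above) writes through update_run_status, collected here as a plain mapping:

# Status values observed in runner.py, with the condition under which each is written.
RUN_STATUSES = {
    "running":   "a claimed run has started executing",
    "succeeded": "runtime.run() returned a result",
    "timed_out": "the run exceeded run_timeout_seconds",
    "cancelled": "the run was cancelled mid-execution",
    "pending":   "a retriable error was requeued for another attempt",
    "failed":    "the error was not retriable, or retries were exhausted",
}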
agent_runtime/state/memory.py
ADDED
@@ -0,0 +1,51 @@

"""
In-memory state store implementation.

Good for:
- Unit testing
- Local development
- Simple single-process scripts
"""

from typing import Optional
from uuid import UUID

from agent_runtime.state.base import StateStore


class InMemoryStateStore(StateStore):
    """
    In-memory state store implementation.

    Stores state in memory. Data is lost when the process exits.
    """

    def __init__(self):
        # run_id -> list of checkpoints (ordered by time)
        self._checkpoints: dict[UUID, list[dict]] = {}
        # run_id -> status
        self._statuses: dict[UUID, str] = {}

    async def save_checkpoint(self, run_id: UUID, state: dict) -> None:
        """Save a checkpoint for a run."""
        if run_id not in self._checkpoints:
            self._checkpoints[run_id] = []
        self._checkpoints[run_id].append(state)

    async def get_checkpoint(self, run_id: UUID) -> Optional[dict]:
        """Get the latest checkpoint for a run."""
        checkpoints = self._checkpoints.get(run_id, [])
        return checkpoints[-1] if checkpoints else None

    async def update_run_status(self, run_id: UUID, status: str) -> None:
        """Update the status of a run."""
        self._statuses[run_id] = status

    async def get_run_status(self, run_id: UUID) -> Optional[str]:
        """Get the status of a run."""
        return self._statuses.get(run_id)

    def clear(self) -> None:
        """Clear all state. Useful for testing."""
        self._checkpoints.clear()
        self._statuses.clear()
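A runnable sketch of the checkpoint/resume pattern the store enables, using InMemoryStateStore directly; the step-counting state shape is illustrative, not something the package defines:

import asyncio
from uuid import uuid4

from agent_runtime.state import InMemoryStateStore


async def process(store: InMemoryStateStore, run_id, total_steps: int = 5) -> None:
    # Resume from the last checkpoint, or start from scratch.
    state = await store.get_checkpoint(run_id) or {"completed_steps": 0}
    for step in range(state["completed_steps"], total_steps):
        state["completed_steps"] = step + 1
        await store.save_checkpoint(run_id, dict(state))
    await store.update_run_status(run_id, "succeeded")


async def main() -> None:
    store = InMemoryStateStore()
    run_id = uuid4()
    # Simulate a first attempt that stopped after two steps...
    await store.save_checkpoint(run_id, {"completed_steps": 2})
    # ...and a retry that picks up at step 2 instead of step 0.
    await process(store, run_id)
    print(await store.get_checkpoint(run_id))  # {'completed_steps': 5}
    print(await store.get_run_status(run_id))  # succeeded


asyncio.run(main())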
agent_runtime/state/redis.py
ADDED
@@ -0,0 +1,109 @@

"""
Redis state store implementation.

Good for:
- Production deployments
- Multi-process/distributed setups
- Automatic TTL-based cleanup
"""

import json
from datetime import datetime
from typing import Optional
from uuid import UUID

from agent_runtime.state.base import StateStore, Checkpoint


class RedisStateStore(StateStore):
    """
    Redis-backed state store.

    Stores checkpoints in Redis with optional TTL.
    Uses sorted sets for efficient retrieval by sequence number.
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        prefix: str = "agent_runtime:state:",
        ttl_seconds: int = 3600 * 24,  # 24 hours default
    ):
        self.url = url
        self.prefix = prefix
        self.ttl_seconds = ttl_seconds
        self._client = None

    async def _get_client(self):
        """Get or create Redis client."""
        if self._client is None:
            try:
                import redis.asyncio as redis
            except ImportError:
                raise ImportError(
                    "redis package is required for RedisStateStore. "
                    "Install with: pip install agent_runtime[redis]"
                )
            self._client = redis.from_url(self.url)
        return self._client

    def _key(self, run_id: UUID) -> str:
        """Get Redis key for a run's checkpoints."""
        return f"{self.prefix}{run_id}"

    async def save_checkpoint(self, checkpoint: Checkpoint) -> None:
        """Save a checkpoint."""
        client = await self._get_client()
        key = self._key(checkpoint.run_id)

        # Serialize checkpoint
        data = json.dumps(checkpoint.to_dict())

        # Add to sorted set with seq as score
        await client.zadd(key, {data: checkpoint.seq})

        # Set TTL on the key
        await client.expire(key, self.ttl_seconds)

    async def get_latest_checkpoint(self, run_id: UUID) -> Optional[Checkpoint]:
        """Get the latest checkpoint for a run."""
        client = await self._get_client()
        key = self._key(run_id)

        # Get highest scored item (latest seq)
        results = await client.zrevrange(key, 0, 0)
        if not results:
            return None

        data = json.loads(results[0])
        return Checkpoint.from_dict(data)

    async def get_checkpoints(self, run_id: UUID) -> list[Checkpoint]:
        """Get all checkpoints for a run."""
        client = await self._get_client()
        key = self._key(run_id)

        # Get all items ordered by seq
        results = await client.zrange(key, 0, -1)

        return [Checkpoint.from_dict(json.loads(r)) for r in results]

    async def get_next_seq(self, run_id: UUID) -> int:
        """Get the next sequence number for a run."""
        latest = await self.get_latest_checkpoint(run_id)
        return (latest.seq + 1) if latest else 0

    async def delete_checkpoints(self, run_id: UUID) -> int:
        """Delete all checkpoints for a run."""
        client = await self._get_client()
        key = self._key(run_id)

        count = await client.zcard(key)
        await client.delete(key)
        return count

    async def close(self) -> None:
        """Close Redis connection."""
        if self._client:
            await self._client.close()
            self._client = None
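A standalone sketch of the sorted-set layout the class above relies on, written against plain redis.asyncio rather than the package's classes; it assumes redis-py is installed and a Redis server is reachable at the given URL:

import asyncio
import json

import redis.asyncio as redis


async def main() -> None:
    client = redis.from_url("redis://localhost:6379")
    key = "agent_runtime:state:example-run"

    # Each checkpoint is a JSON member scored by its sequence number.
    for seq in range(3):
        member = json.dumps({"seq": seq, "state": {"completed_steps": seq}})
        await client.zadd(key, {member: seq})
    await client.expire(key, 3600 * 24)

    # The latest checkpoint is the highest-scored member.
    latest = await client.zrevrange(key, 0, 0)
    print(json.loads(latest[0]))  # {'seq': 2, 'state': {'completed_steps': 2}}

    await client.delete(key)
    await client.close()


asyncio.run(main())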