agent-runtime-core 0.1.0__py3-none-any.whl

@@ -0,0 +1,453 @@
+ """
+ Redis Streams-backed queue with consumer groups.
+
+ Good for:
+ - Production deployments
+ - Multi-process/distributed setups
+ - High throughput
+ """
+
+ import json
+ from datetime import datetime, timedelta, timezone
+ from typing import Optional
+ from uuid import UUID
+
+ from agent_runtime.queue.base import RunQueue, QueuedRun, RunStatus
+
+
+ class RedisQueue(RunQueue):
+     """
+     Redis Streams-backed queue implementation.
+
+     Uses consumer groups for distributed processing.
+     Run state is stored in Redis hashes.
+     """
+
+     STREAM_KEY = "agent_runtime:queue"
+     RUNS_KEY = "agent_runtime:runs"
+     GROUP_NAME = "agent_workers"
+
+     def __init__(
+         self,
+         url: str = "redis://localhost:6379",
+         lease_ttl_seconds: int = 30,
+         stream_key: Optional[str] = None,
+         runs_key: Optional[str] = None,
+         group_name: Optional[str] = None,
+     ):
+         self.url = url
+         self.lease_ttl_seconds = lease_ttl_seconds
+         self.stream_key = stream_key or self.STREAM_KEY
+         self.runs_key = runs_key or self.RUNS_KEY
+         self.group_name = group_name or self.GROUP_NAME
+         self._client = None
+
+     async def _get_client(self):
+         """Get or create Redis client."""
+         if self._client is None:
+             try:
+                 import redis.asyncio as redis
+             except ImportError:
+                 raise ImportError(
+                     "redis package is required for RedisQueue. "
+                     "Install with: pip install agent_runtime[redis]"
+                 )
+             self._client = redis.from_url(self.url)
+             # Ensure consumer group exists
+             try:
+                 await self._client.xgroup_create(
+                     self.stream_key, self.group_name, id="0", mkstream=True
+                 )
+             except Exception as e:
+                 if "BUSYGROUP" not in str(e):
+                     raise
+         return self._client
+
+     def _run_key(self, run_id: UUID) -> str:
+         """Get Redis key for a run."""
+         return f"{self.runs_key}:{run_id}"
+
+     async def enqueue(
+         self,
+         run_id: UUID,
+         agent_key: str,
+         input: dict,
+         metadata: Optional[dict] = None,
+         max_attempts: int = 3,
+     ) -> QueuedRun:
+         """Add a new run to the queue."""
+         client = await self._get_client()
+         now = datetime.now(timezone.utc)
+
+         run_data = {
+             "run_id": str(run_id),
+             "agent_key": agent_key,
+             "input": json.dumps(input),
+             "metadata": json.dumps(metadata or {}),
+             "max_attempts": str(max_attempts),
+             "attempt": "1",
+             "status": RunStatus.QUEUED.value,
+             "lease_owner": "",
+             "lease_expires_at": "",
+             "cancel_requested_at": "",
+             "created_at": now.isoformat(),
+             "started_at": "",
+             "finished_at": "",
+             "output": "",
+             "error": "",
+         }
+
+         # Store run data
+         await client.hset(self._run_key(run_id), mapping=run_data)
+
+         # Add to stream
+         await client.xadd(
+             self.stream_key,
+             {"run_id": str(run_id), "agent_key": agent_key},
+         )
+
+         return QueuedRun(
+             run_id=run_id,
+             agent_key=agent_key,
+             attempt=1,
+             lease_expires_at=now,
+             input=input,
+             metadata=metadata or {},
+             max_attempts=max_attempts,
+             status=RunStatus.QUEUED,
+         )
+
+     async def claim(
+         self,
+         worker_id: str,
+         agent_keys: Optional[list[str]] = None,
+         batch_size: int = 1,
+     ) -> list[QueuedRun]:
+         """Claim runs from the stream using consumer groups."""
+         client = await self._get_client()
+         now = datetime.now(timezone.utc)
+         lease_expires = now + timedelta(seconds=self.lease_ttl_seconds)
+
+         # Read from consumer group
+         messages = await client.xreadgroup(
+             self.group_name,
+             worker_id,
+             {self.stream_key: ">"},
+             count=batch_size,
+             block=1000,  # 1 second block
+         )
+
+         if not messages:
+             return []
+
+         claimed = []
+         for stream_name, stream_messages in messages:
+             for msg_id, data in stream_messages:
+                 run_id_str = data.get(b"run_id", data.get("run_id"))
+                 if isinstance(run_id_str, bytes):
+                     run_id_str = run_id_str.decode()
+                 run_id = UUID(run_id_str)
+
+                 agent_key = data.get(b"agent_key", data.get("agent_key"))
+                 if isinstance(agent_key, bytes):
+                     agent_key = agent_key.decode()
+
+                 # Filter by agent_keys if specified
+                 if agent_keys and agent_key not in agent_keys:
+                     await client.xack(self.stream_key, self.group_name, msg_id)
+                     continue
+
+                 # Get and update run data
+                 run_key = self._run_key(run_id)
+                 run_data = await client.hgetall(run_key)
+
+                 if not run_data:
+                     await client.xack(self.stream_key, self.group_name, msg_id)
+                     continue
+
+                 # Decode bytes if needed
+                 run_data = {
+                     (k.decode() if isinstance(k, bytes) else k):
+                     (v.decode() if isinstance(v, bytes) else v)
+                     for k, v in run_data.items()
+                 }
+
+                 status = run_data.get("status", "")
+                 if status not in [RunStatus.QUEUED.value, RunStatus.RUNNING.value]:
+                     await client.xack(self.stream_key, self.group_name, msg_id)
+                     continue
+
+                 # Check if already claimed
+                 if status == RunStatus.RUNNING.value:
+                     existing_expires = run_data.get("lease_expires_at", "")
+                     if existing_expires:
+                         expires_dt = datetime.fromisoformat(existing_expires)
+                         if expires_dt > now:
+                             await client.xack(self.stream_key, self.group_name, msg_id)
+                             continue
+
+                 # Claim the run
+                 updates = {
+                     "status": RunStatus.RUNNING.value,
+                     "lease_owner": worker_id,
+                     "lease_expires_at": lease_expires.isoformat(),
+                 }
+                 if not run_data.get("started_at"):
+                     updates["started_at"] = now.isoformat()
+
+                 await client.hset(run_key, mapping=updates)
+                 await client.xack(self.stream_key, self.group_name, msg_id)
+
+                 claimed.append(QueuedRun(
+                     run_id=run_id,
+                     agent_key=agent_key,
+                     attempt=int(run_data.get("attempt", 1)),
+                     lease_expires_at=lease_expires,
+                     input=json.loads(run_data.get("input", "{}")),
+                     metadata=json.loads(run_data.get("metadata", "{}")),
+                     max_attempts=int(run_data.get("max_attempts", 3)),
+                     status=RunStatus.RUNNING,
+                 ))
+
+         return claimed
+
+     async def extend_lease(self, run_id: UUID, worker_id: str, seconds: int) -> bool:
+         """Extend lease in Redis."""
+         client = await self._get_client()
+         run_key = self._run_key(run_id)
+
+         run_data = await client.hgetall(run_key)
+         if not run_data:
+             return False
+
+         # Decode
+         run_data = {
+             (k.decode() if isinstance(k, bytes) else k):
+             (v.decode() if isinstance(v, bytes) else v)
+             for k, v in run_data.items()
+         }
+
+         if run_data.get("lease_owner") != worker_id:
+             return False
+         if run_data.get("status") != RunStatus.RUNNING.value:
+             return False
+
+         new_expires = datetime.now(timezone.utc) + timedelta(seconds=seconds)
+         await client.hset(run_key, "lease_expires_at", new_expires.isoformat())
+         return True
+
+     async def release(
+         self,
+         run_id: UUID,
+         worker_id: str,
+         success: bool,
+         output: Optional[dict] = None,
+         error: Optional[dict] = None,
+     ) -> None:
+         """Release run after completion."""
+         client = await self._get_client()
+         run_key = self._run_key(run_id)
+         now = datetime.now(timezone.utc)
+
+         updates = {
+             "status": RunStatus.SUCCEEDED.value if success else RunStatus.FAILED.value,
+             "finished_at": now.isoformat(),
+             "lease_owner": "",
+             "lease_expires_at": "",
+         }
+         if output:
+             updates["output"] = json.dumps(output)
+         if error:
+             updates["error"] = json.dumps(error)
+
+         await client.hset(run_key, mapping=updates)
+
+     async def requeue_for_retry(
+         self,
+         run_id: UUID,
+         worker_id: str,
+         error: dict,
+         delay_seconds: int = 0,
+     ) -> bool:
+         """Requeue for retry."""
+         client = await self._get_client()
+         run_key = self._run_key(run_id)
+
+         run_data = await client.hgetall(run_key)
+         if not run_data:
+             return False
+
+         run_data = {
+             (k.decode() if isinstance(k, bytes) else k):
+             (v.decode() if isinstance(v, bytes) else v)
+             for k, v in run_data.items()
+         }
+
+         if run_data.get("lease_owner") != worker_id:
+             return False
+
+         attempt = int(run_data.get("attempt", 1))
+         max_attempts = int(run_data.get("max_attempts", 3))
+
+         if attempt >= max_attempts:
+             await client.hset(run_key, mapping={
+                 "status": RunStatus.FAILED.value,
+                 "error": json.dumps(error),
+                 "finished_at": datetime.now(timezone.utc).isoformat(),
+                 "lease_owner": "",
+                 "lease_expires_at": "",
+             })
+             return False
+
+         await client.hset(run_key, mapping={
+             "status": RunStatus.QUEUED.value,
+             "attempt": str(attempt + 1),
+             "error": json.dumps(error),
+             "lease_owner": "",
+             "lease_expires_at": "",
+         })
+
+         # Re-add to stream (delay_seconds is accepted but not applied in this simplified implementation)
+         agent_key = run_data.get("agent_key", "")
+         await client.xadd(
+             self.stream_key,
+             {"run_id": str(run_id), "agent_key": agent_key},
+         )
+         return True
+
+     async def cancel(self, run_id: UUID) -> bool:
+         """Mark run for cancellation."""
+         client = await self._get_client()
+         run_key = self._run_key(run_id)
+
+         run_data = await client.hgetall(run_key)
+         if not run_data:
+             return False
+
+         run_data = {
+             (k.decode() if isinstance(k, bytes) else k):
+             (v.decode() if isinstance(v, bytes) else v)
+             for k, v in run_data.items()
+         }
+
+         status = run_data.get("status", "")
+         if status not in [RunStatus.QUEUED.value, RunStatus.RUNNING.value]:
+             return False
+
+         await client.hset(
+             run_key,
+             "cancel_requested_at",
+             datetime.now(timezone.utc).isoformat()
+         )
+         return True
+
+     async def is_cancelled(self, run_id: UUID) -> bool:
+         """Check if cancellation was requested."""
+         client = await self._get_client()
+         run_key = self._run_key(run_id)
+
+         cancel_at = await client.hget(run_key, "cancel_requested_at")
+         if isinstance(cancel_at, bytes):
+             cancel_at = cancel_at.decode()
+         return bool(cancel_at)
+
+     async def recover_expired_leases(self) -> int:
+         """Recover runs with expired leases."""
+         client = await self._get_client()
+         now = datetime.now(timezone.utc)
+
+         # Scan for all run keys
+         count = 0
+         async for key in client.scan_iter(f"{self.runs_key}:*"):
+             run_data = await client.hgetall(key)
+             if not run_data:
+                 continue
+
+             run_data = {
+                 (k.decode() if isinstance(k, bytes) else k):
+                 (v.decode() if isinstance(v, bytes) else v)
+                 for k, v in run_data.items()
+             }
+
+             if run_data.get("status") != RunStatus.RUNNING.value:
+                 continue
+
+             lease_expires = run_data.get("lease_expires_at", "")
+             if not lease_expires:
+                 continue
+
+             expires_dt = datetime.fromisoformat(lease_expires)
+             if expires_dt > now:
+                 continue
+
+             # Lease expired
+             attempt = int(run_data.get("attempt", 1))
+             max_attempts = int(run_data.get("max_attempts", 3))
+             run_id = run_data.get("run_id", "")
+             agent_key = run_data.get("agent_key", "")
+
+             if attempt >= max_attempts:
+                 await client.hset(key, mapping={
+                     "status": RunStatus.TIMED_OUT.value,
+                     "finished_at": now.isoformat(),
+                     "error": json.dumps({
+                         "type": "LeaseExpired",
+                         "message": "Worker lease expired without completion",
+                         "retriable": False,
+                     }),
+                     "lease_owner": "",
+                     "lease_expires_at": "",
+                 })
+             else:
+                 await client.hset(key, mapping={
+                     "status": RunStatus.QUEUED.value,
+                     "attempt": str(attempt + 1),
+                     "lease_owner": "",
+                     "lease_expires_at": "",
+                 })
+                 # Re-add to stream
+                 await client.xadd(
+                     self.stream_key,
+                     {"run_id": run_id, "agent_key": agent_key},
+                 )
+
+             count += 1
+
+         return count
+
+     async def get_run(self, run_id: UUID) -> Optional[QueuedRun]:
+         """Get a run by ID."""
+         client = await self._get_client()
+         run_key = self._run_key(run_id)
+
+         run_data = await client.hgetall(run_key)
+         if not run_data:
+             return None
+
+         run_data = {
+             (k.decode() if isinstance(k, bytes) else k):
+             (v.decode() if isinstance(v, bytes) else v)
+             for k, v in run_data.items()
+         }
+
+         lease_expires = run_data.get("lease_expires_at", "")
+
+         return QueuedRun(
+             run_id=UUID(run_data["run_id"]),
+             agent_key=run_data.get("agent_key", ""),
+             attempt=int(run_data.get("attempt", 1)),
+             lease_expires_at=(
+                 datetime.fromisoformat(lease_expires) if lease_expires
+                 else datetime.now(timezone.utc)
+             ),
+             input=json.loads(run_data.get("input", "{}")),
+             metadata=json.loads(run_data.get("metadata", "{}")),
+             max_attempts=int(run_data.get("max_attempts", 3)),
+             status=RunStatus(run_data.get("status", "queued")),
+         )
+
+     async def close(self) -> None:
+         """Close Redis connection."""
+         if self._client:
+             await self._client.close()
+             self._client = None
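
Usage sketch (not part of the package file above): a minimal producer/worker round trip against RedisQueue. The import path agent_runtime.queue.redis, the agent key, and the worker id are assumptions for illustration; only the constructor and method signatures come from the file shown in the diff.

import asyncio
from uuid import uuid4

# Assumed module path; the diff does not show where this file lives in the
# wheel, so agent_runtime.queue.redis is a guess modeled on the package's own
# import of agent_runtime.queue.base.
from agent_runtime.queue.redis import RedisQueue


async def main() -> None:
    queue = RedisQueue(url="redis://localhost:6379", lease_ttl_seconds=30)

    # Producer: enqueue a run for a hypothetical agent key.
    run_id = uuid4()
    await queue.enqueue(
        run_id=run_id,
        agent_key="example-agent",
        input={"prompt": "hello"},
        max_attempts=3,
    )

    # Worker: claim, process, then release on success or requeue on failure.
    for run in await queue.claim(worker_id="worker-1", batch_size=1):
        try:
            # ... execute the agent here, heartbeating with extend_lease() ...
            await queue.release(run.run_id, "worker-1", success=True, output={"ok": True})
        except Exception as exc:
            # Returns False once max_attempts is exhausted (run is marked FAILED).
            await queue.requeue_for_retry(
                run.run_id,
                "worker-1",
                error={"type": type(exc).__name__, "message": str(exc)},
            )

    await queue.close()


if __name__ == "__main__":
    asyncio.run(main())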
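
Leases expire after lease_ttl_seconds (30 seconds by default), so a long-running worker would heartbeat with extend_lease(), and some process should call recover_expired_leases() periodically so runs from crashed workers get requeued or timed out. A sketch under the same assumed import path; the function names and intervals below are illustrative, not part of the package.

import asyncio
from uuid import UUID

from agent_runtime.queue.redis import RedisQueue  # assumed import path


async def heartbeat(queue: RedisQueue, run_id: UUID, worker_id: str, every: float = 10.0) -> None:
    # Keep the lease alive while the job runs; the loop stops once the run is
    # no longer RUNNING or this worker no longer owns the lease.
    while await queue.extend_lease(run_id, worker_id, seconds=30):
        await asyncio.sleep(every)


async def lease_janitor(queue: RedisQueue, interval: float = 15.0) -> None:
    # Periodically sweep for expired leases so stalled runs are requeued
    # (or marked TIMED_OUT once max_attempts is reached).
    while True:
        recovered = await queue.recover_expired_leases()
        if recovered:
            print(f"recovered {recovered} expired lease(s)")
        await asyncio.sleep(interval)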