codetether 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_server/__init__.py +29 -0
- a2a_server/a2a_agent_card.py +365 -0
- a2a_server/a2a_errors.py +1133 -0
- a2a_server/a2a_executor.py +926 -0
- a2a_server/a2a_router.py +1033 -0
- a2a_server/a2a_types.py +344 -0
- a2a_server/agent_card.py +408 -0
- a2a_server/agents_server.py +271 -0
- a2a_server/auth_api.py +349 -0
- a2a_server/billing_api.py +638 -0
- a2a_server/billing_service.py +712 -0
- a2a_server/billing_webhooks.py +501 -0
- a2a_server/config.py +96 -0
- a2a_server/database.py +2165 -0
- a2a_server/email_inbound.py +398 -0
- a2a_server/email_notifications.py +486 -0
- a2a_server/enhanced_agents.py +919 -0
- a2a_server/enhanced_server.py +160 -0
- a2a_server/hosted_worker.py +1049 -0
- a2a_server/integrated_agents_server.py +347 -0
- a2a_server/keycloak_auth.py +750 -0
- a2a_server/livekit_bridge.py +439 -0
- a2a_server/marketing_tools.py +1364 -0
- a2a_server/mcp_client.py +196 -0
- a2a_server/mcp_http_server.py +2256 -0
- a2a_server/mcp_server.py +191 -0
- a2a_server/message_broker.py +725 -0
- a2a_server/mock_mcp.py +273 -0
- a2a_server/models.py +494 -0
- a2a_server/monitor_api.py +5904 -0
- a2a_server/opencode_bridge.py +1594 -0
- a2a_server/redis_task_manager.py +518 -0
- a2a_server/server.py +726 -0
- a2a_server/task_manager.py +668 -0
- a2a_server/task_queue.py +742 -0
- a2a_server/tenant_api.py +333 -0
- a2a_server/tenant_middleware.py +219 -0
- a2a_server/tenant_service.py +760 -0
- a2a_server/user_auth.py +721 -0
- a2a_server/vault_client.py +576 -0
- a2a_server/worker_sse.py +873 -0
- agent_worker/__init__.py +8 -0
- agent_worker/worker.py +4877 -0
- codetether/__init__.py +10 -0
- codetether/__main__.py +4 -0
- codetether/cli.py +112 -0
- codetether/worker_cli.py +57 -0
- codetether-1.2.2.dist-info/METADATA +570 -0
- codetether-1.2.2.dist-info/RECORD +66 -0
- codetether-1.2.2.dist-info/WHEEL +5 -0
- codetether-1.2.2.dist-info/entry_points.txt +4 -0
- codetether-1.2.2.dist-info/licenses/LICENSE +202 -0
- codetether-1.2.2.dist-info/top_level.txt +5 -0
- codetether_voice_agent/__init__.py +6 -0
- codetether_voice_agent/agent.py +445 -0
- codetether_voice_agent/codetether_mcp.py +345 -0
- codetether_voice_agent/config.py +16 -0
- codetether_voice_agent/functiongemma_caller.py +380 -0
- codetether_voice_agent/session_playback.py +247 -0
- codetether_voice_agent/tools/__init__.py +21 -0
- codetether_voice_agent/tools/definitions.py +135 -0
- codetether_voice_agent/tools/handlers.py +380 -0
- run_server.py +314 -0
- ui/monitor-tailwind.html +1790 -0
- ui/monitor.html +1775 -0
- ui/monitor.js +2662 -0
|
@@ -0,0 +1,668 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task management for A2A protocol.
|
|
3
|
+
|
|
4
|
+
Handles the lifecycle of tasks including creation, updates, and state management.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
import inspect
import json
import logging
import uuid
from asyncio import Lock
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional

from .models import Task, TaskStatus, TaskStatusUpdateEvent, Message
|
|
16
|
+
|
|
17
|
+
# Module-level logger; used for claim/release events and message-decoding warnings.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TaskManager:
    """Manage the lifecycle and state of A2A tasks.

    This base implementation keeps tasks in an in-process dict guarded by an
    ``asyncio.Lock``. Callers may register per-task update handlers; every
    status change builds a ``TaskStatusUpdateEvent`` and delivers it to all
    handlers registered for that task.
    """

    def __init__(self):
        # task_id -> Task. Guarded by _task_lock.
        self._tasks: Dict[str, Task] = {}
        self._task_lock = Lock()
        # task_id -> handlers. A handler may be a plain callable or an async
        # callable (anything returning an awaitable). Guarded by _handler_lock.
        self._update_handlers: Dict[
            str, List[Callable[[TaskStatusUpdateEvent], Any]]
        ] = {}
        self._handler_lock = Lock()

    async def create_task(
        self,
        title: Optional[str] = None,
        description: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> Task:
        """Create a new task in PENDING status and store it.

        Args:
            title: Optional task title.
            description: Optional task description.
            task_id: ID to use for the task. A random UUID4 string is
                generated when omitted. A stored task with the same ID is
                overwritten.

        Returns:
            The newly created task.
        """
        if task_id is None:
            task_id = str(uuid.uuid4())

        # NOTE(review): naive UTC timestamps (datetime.utcnow() is deprecated
        # since Python 3.12). Kept naive in case callers compare against other
        # naive datetimes — confirm before switching to aware UTC.
        now = datetime.utcnow()
        task = Task(
            id=task_id,
            status=TaskStatus.PENDING,
            created_at=now,
            updated_at=now,
            title=title,
            description=description,
        )

        async with self._task_lock:
            self._tasks[task_id] = task

        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or None if it is unknown."""
        async with self._task_lock:
            return self._tasks.get(task_id)

    async def update_task_status(
        self,
        task_id: str,
        status: TaskStatus,
        message: Optional[Message] = None,
        progress: Optional[float] = None,
        final: bool = False,
    ) -> Optional[Task]:
        """Set a task's status (and optionally progress) and notify handlers.

        Args:
            task_id: ID of the task to update.
            status: New status.
            message: Optional message attached to the emitted update event.
                NOTE(review): unlike the PostgreSQL-backed subclass, the
                message is not stored on the task here; it is only delivered
                through the event.
            progress: New progress value; left unchanged when None.
            final: Passed through as the event's ``final`` flag.

        Returns:
            The updated task, or None if ``task_id`` is unknown (in which case
            no handlers are notified).
        """
        async with self._task_lock:
            task = self._tasks.get(task_id)
            if task is None:
                return None

            task.status = status
            task.updated_at = datetime.utcnow()
            if progress is not None:
                task.progress = progress

            event = TaskStatusUpdateEvent(
                task=task, message=message, final=final
            )

        # Notify outside the task lock so handlers can call back into the
        # manager without deadlocking on the non-reentrant asyncio.Lock.
        await self._notify_handlers(task_id, event)

        return task

    async def cancel_task(self, task_id: str) -> Optional[Task]:
        """Move a task to CANCELLED and emit a final update event.

        Returns:
            The cancelled task, or None if ``task_id`` is unknown.
        """
        return await self.update_task_status(
            task_id, TaskStatus.CANCELLED, final=True
        )

    async def delete_task(self, task_id: str) -> bool:
        """Remove a task from storage.

        Handlers registered for the task are left in place.

        Returns:
            True if the task existed and was removed, False otherwise.
        """
        async with self._task_lock:
            return self._tasks.pop(task_id, None) is not None

    async def list_tasks(
        self, status: Optional[TaskStatus] = None
    ) -> List[Task]:
        """Return all stored tasks, optionally only those with ``status``.

        The list is a snapshot taken under the lock, in dict (first-stored)
        order.
        """
        async with self._task_lock:
            tasks = list(self._tasks.values())

        if status is not None:
            tasks = [task for task in tasks if task.status == status]

        return tasks

    async def claim_task(self, task_id: str, worker_id: str) -> Optional[Task]:
        """
        Atomically claim a task for a worker.

        Under the task lock, a PENDING task is moved to WORKING and
        ``worker_id``/``claimed_at`` are recorded. Handlers are notified after
        the lock is released.

        Args:
            task_id: The ID of the task to claim
            worker_id: The ID of the worker claiming the task

        Returns:
            The claimed Task if successful, None if the task doesn't exist,
            is not in pending status, or was already claimed by another worker.
        """
        async with self._task_lock:
            task = self._tasks.get(task_id)
            if task is None:
                return None

            # Only allow claiming pending tasks
            if task.status != TaskStatus.PENDING:
                logger.debug(
                    f'Task {task_id} cannot be claimed: status is {task.status.value}'
                )
                return None

            now = datetime.utcnow()
            task.status = TaskStatus.WORKING
            task.worker_id = worker_id
            task.claimed_at = now
            task.updated_at = now

            event = TaskStatusUpdateEvent(task=task, message=None, final=False)

        # Notify handlers outside the lock
        await self._notify_handlers(task_id, event)
        logger.info(f'Task {task_id} claimed by worker {worker_id}')

        return task

    async def release_task(
        self, task_id: str, worker_id: str
    ) -> Optional[Task]:
        """
        Release a claimed task back to pending status.

        Used when a worker fails, disconnects, or gives up on a task. Only
        the worker that claimed the task can release it, and only while the
        task is WORKING.

        Args:
            task_id: The ID of the task to release
            worker_id: The ID of the worker releasing the task

        Returns:
            The released Task if successful; None if the task doesn't exist,
            the worker_id doesn't match the claiming worker, or the task is
            not in working status.
        """
        async with self._task_lock:
            task = self._tasks.get(task_id)
            if task is None:
                return None

            # Only allow the claiming worker to release the task
            if task.worker_id != worker_id:
                logger.warning(
                    f'Worker {worker_id} attempted to release task {task_id} '
                    f'but it is owned by {task.worker_id}'
                )
                return None

            # Only release tasks that are currently being worked on
            if task.status != TaskStatus.WORKING:
                logger.debug(
                    f'Task {task_id} cannot be released: status is {task.status.value}'
                )
                return None

            task.status = TaskStatus.PENDING
            task.worker_id = None
            task.claimed_at = None
            task.updated_at = datetime.utcnow()

            event = TaskStatusUpdateEvent(task=task, message=None, final=False)

        # Notify handlers outside the lock
        await self._notify_handlers(task_id, event)
        logger.info(f'Task {task_id} released by worker {worker_id}')

        return task

    async def register_update_handler(
        self, task_id: str, handler: Callable[[TaskStatusUpdateEvent], Any]
    ) -> None:
        """Register ``handler`` to receive update events for ``task_id``.

        ``handler`` may be a plain callable or one returning an awaitable; the
        same handler may be registered more than once.
        """
        async with self._handler_lock:
            self._update_handlers.setdefault(task_id, []).append(handler)

    async def unregister_update_handler(
        self, task_id: str, handler: Callable[[TaskStatusUpdateEvent], Any]
    ) -> None:
        """Remove one registration of ``handler`` for ``task_id``.

        Silently does nothing if the handler is not registered. The task's
        entry is dropped once its last handler is removed.
        """
        async with self._handler_lock:
            handlers = self._update_handlers.get(task_id)
            if not handlers:
                return
            try:
                handlers.remove(handler)
            except ValueError:
                return  # Handler wasn't registered
            if not handlers:
                del self._update_handlers[task_id]

    async def _notify_handlers(
        self, task_id: str, event: TaskStatusUpdateEvent
    ) -> None:
        """Deliver ``event`` to every handler registered for ``task_id``.

        Handlers run concurrently; each is wrapped by ``_safe_call_handler``
        so one failing handler does not affect the others.
        """
        # Copy under the lock so handlers may (un)register during delivery.
        async with self._handler_lock:
            handlers = list(self._update_handlers.get(task_id, ()))

        if handlers:
            await asyncio.gather(
                *(self._safe_call_handler(handler, event) for handler in handlers),
                return_exceptions=True,
            )

    async def _safe_call_handler(
        self,
        handler: Callable[[TaskStatusUpdateEvent], Any],
        event: TaskStatusUpdateEvent,
    ) -> None:
        """Invoke ``handler`` with ``event``, logging (not raising) any error.

        The handler is called first and its result awaited if it is awaitable.
        This covers coroutine functions, async callable objects, and sync
        callables that return a coroutine/future.
        """
        try:
            result = handler(event)
            if inspect.isawaitable(result):
                await result
        except Exception as e:
            # Log with traceback but don't let it break other handlers.
            logger.exception('Error in task update handler: %s', e)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
class InMemoryTaskManager(TaskManager):
    """In-memory implementation of TaskManager.

    Adds no behavior of its own: the base class already stores tasks in a
    process-local dict, so this subclass simply names that storage mode
    explicitly.
    """
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class PersistentTaskManager(TaskManager):
    """Task manager with PostgreSQL-backed persistent storage.

    ``storage_path`` is passed to ``asyncpg.create_pool`` as the connection
    target. The pool and the ``a2a_tasks`` schema are created lazily on first
    use. Update handlers are still kept in memory by the base class.
    """

    # Column list shared by every SELECT/RETURNING so rows always carry the
    # fields _row_to_task reads.
    _TASK_COLUMNS = (
        'id, status, title, description, created_at, updated_at, '
        'progress, messages, worker_id, claimed_at'
    )

    def __init__(self, storage_path: str):
        super().__init__()
        self.storage_path = storage_path
        self._pool = None
        self._pool_lock = Lock()
        self._initialized = False

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        The columns are TIMESTAMPTZ. asyncpg interprets a naive datetime bound
        to such a column in the host's local timezone, and it returns aware
        datetimes when reading, so aware UTC keeps writes correct and
        consistent with reads.
        """
        return datetime.now(timezone.utc)

    async def _get_pool(self):
        """Return the shared asyncpg pool, creating it (and the schema) once.

        Raises:
            ImportError: if asyncpg is not installed.
        """
        if self._pool is not None:
            return self._pool

        async with self._pool_lock:
            # Re-check: another coroutine may have created it while we waited.
            if self._pool is not None:
                return self._pool

            try:
                import asyncpg
            except ImportError as exc:
                raise ImportError(
                    'asyncpg is required for PersistentTaskManager. Install with: pip install asyncpg'
                ) from exc

            pool = await asyncpg.create_pool(
                self.storage_path,
                min_size=1,
                max_size=10,
                command_timeout=30,
            )

            if not self._initialized:
                try:
                    await self._init_schema(pool)
                except Exception:
                    # Don't publish a pool whose schema init failed; otherwise
                    # the fast path above would skip init on every later call.
                    await pool.close()
                    raise
                self._initialized = True

            self._pool = pool
            return self._pool

    async def _init_schema(self, pool) -> None:
        """Create/migrate the ``a2a_tasks`` table and its indexes."""
        async with pool.acquire() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS a2a_tasks (
                    id TEXT PRIMARY KEY,
                    status TEXT NOT NULL,
                    title TEXT,
                    description TEXT,
                    created_at TIMESTAMPTZ NOT NULL,
                    updated_at TIMESTAMPTZ NOT NULL,
                    progress REAL,
                    messages JSONB DEFAULT '[]'::jsonb,
                    worker_id TEXT,
                    claimed_at TIMESTAMPTZ
                )
            """)

            # Add columns if they don't exist (for existing databases).
            # This must run BEFORE idx_a2a_tasks_worker_id is created: on an
            # older table the column is missing, index creation would fail,
            # and this migration would never be reached.
            await conn.execute("""
                DO $$
                BEGIN
                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns
                                   WHERE table_name = 'a2a_tasks' AND column_name = 'worker_id') THEN
                        ALTER TABLE a2a_tasks ADD COLUMN worker_id TEXT;
                    END IF;
                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns
                                   WHERE table_name = 'a2a_tasks' AND column_name = 'claimed_at') THEN
                        ALTER TABLE a2a_tasks ADD COLUMN claimed_at TIMESTAMPTZ;
                    END IF;
                END $$;
            """)

            await conn.execute(
                'CREATE INDEX IF NOT EXISTS idx_a2a_tasks_status ON a2a_tasks(status)'
            )
            await conn.execute(
                'CREATE INDEX IF NOT EXISTS idx_a2a_tasks_updated_at ON a2a_tasks(updated_at)'
            )
            await conn.execute(
                'CREATE INDEX IF NOT EXISTS idx_a2a_tasks_worker_id ON a2a_tasks(worker_id)'
            )

    def _deserialize_messages(self, value: Any) -> Optional[List[Message]]:
        """Turn the ``messages`` column into a list of Message, or None.

        Accepts a JSON string or an already-decoded list. Undecodable JSON and
        unparseable items are logged and skipped; an empty result is None.
        """
        if not value:
            return None
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError as exc:
                logger.warning('Failed to decode task messages: %s', exc)
                return None

        messages: List[Message] = []
        if isinstance(value, list):
            for item in value:
                try:
                    messages.append(Message.model_validate(item))
                except Exception as exc:
                    logger.warning('Failed to parse task message: %s', exc)
        return messages or None

    def _row_to_task(self, row) -> Task:
        """Build a Task from a row selected with ``_TASK_COLUMNS``."""
        return Task(
            id=row['id'],
            status=TaskStatus(row['status']),
            title=row['title'],
            description=row['description'],
            created_at=row['created_at'],
            updated_at=row['updated_at'],
            progress=row['progress'],
            messages=self._deserialize_messages(row['messages']),
            worker_id=row.get('worker_id'),
            claimed_at=row.get('claimed_at'),
        )

    async def create_task(
        self,
        title: Optional[str] = None,
        description: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> Task:
        """Create a new PENDING task and store it in PostgreSQL.

        A random UUID4 string is used when ``task_id`` is omitted. Inserting
        an existing ``task_id`` violates the primary key and raises.
        """
        if task_id is None:
            task_id = str(uuid.uuid4())

        now = self._utcnow()
        task = Task(
            id=task_id,
            status=TaskStatus.PENDING,
            created_at=now,
            updated_at=now,
            title=title,
            description=description,
        )

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                f"""
                INSERT INTO a2a_tasks ({self._TASK_COLUMNS})
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
                """,
                task.id,
                task.status.value,
                task.title,
                task.description,
                task.created_at,
                task.updated_at,
                task.progress,
                json.dumps([]),
                None,  # worker_id
                None,  # claimed_at
            )

        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Return the stored task with ``task_id``, or None if absent."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                f'SELECT {self._TASK_COLUMNS} FROM a2a_tasks WHERE id = $1',
                task_id,
            )
        return None if row is None else self._row_to_task(row)

    async def update_task_status(
        self,
        task_id: str,
        status: TaskStatus,
        message: Optional[Message] = None,
        progress: Optional[float] = None,
        final: bool = False,
    ) -> Optional[Task]:
        """Update a task's status in PostgreSQL and notify handlers.

        ``progress`` is written only when given. ``message`` (when given) is
        appended to the task's JSONB ``messages`` array and also attached to
        the emitted event.

        Returns:
            The updated task, or None if no row matched ``task_id`` (handlers
            are not notified in that case).
        """
        updates = ['status = $2', 'updated_at = $3']
        params: List[Any] = [task_id, status.value, self._utcnow()]

        # Each optional SET clause references the placeholder of the value
        # just appended, so ``$N`` always equals len(params).
        if progress is not None:
            params.append(progress)
            updates.append(f'progress = ${len(params)}')

        if message is not None:
            params.append(json.dumps([message.model_dump(mode='json')]))
            updates.append(
                f"messages = COALESCE(messages, '[]'::jsonb) || ${len(params)}::jsonb"
            )

        query = f"""
            UPDATE a2a_tasks
            SET {', '.join(updates)}
            WHERE id = $1
            RETURNING {self._TASK_COLUMNS}
        """

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(query, *params)

        if row is None:
            return None

        task = self._row_to_task(row)
        event = TaskStatusUpdateEvent(task=task, message=message, final=final)
        await self._notify_handlers(task_id, event)

        return task

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task row; return True if a row was removed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                'DELETE FROM a2a_tasks WHERE id = $1', task_id
            )
        # ``execute`` returns the command status tag, e.g. 'DELETE 1'.
        return 'DELETE 1' in result

    async def list_tasks(
        self, status: Optional[TaskStatus] = None
    ) -> List[Task]:
        """List tasks newest-first, optionally filtered by ``status``."""
        query = f'SELECT {self._TASK_COLUMNS} FROM a2a_tasks'
        args: List[Any] = []
        if status is not None:
            query += ' WHERE status = $1'
            args.append(status.value)
        query += ' ORDER BY created_at DESC'

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *args)

        return [self._row_to_task(row) for row in rows]

    async def claim_task(self, task_id: str, worker_id: str) -> Optional[Task]:
        """
        Atomically claim a task for a worker using database transactions.

        Uses SELECT FOR UPDATE to lock the row and ensure atomic claim.
        Handlers are notified only after the transaction has committed.

        Args:
            task_id: The ID of the task to claim
            worker_id: The ID of the worker claiming the task

        Returns:
            The claimed Task if successful, None if the task doesn't exist,
            is not in pending status, or was already claimed by another worker.
        """
        pool = await self._get_pool()
        now = self._utcnow()

        async with pool.acquire() as conn:
            # The row lock serializes concurrent claimers of the same task.
            async with conn.transaction():
                row = await conn.fetchrow(
                    f"""
                    SELECT {self._TASK_COLUMNS}
                    FROM a2a_tasks
                    WHERE id = $1
                    FOR UPDATE
                    """,
                    task_id,
                )

                if row is None:
                    return None

                if row['status'] != TaskStatus.PENDING.value:
                    logger.debug(
                        f'Task {task_id} cannot be claimed: status is {row["status"]}'
                    )
                    return None

                updated_row = await conn.fetchrow(
                    f"""
                    UPDATE a2a_tasks
                    SET status = $2, worker_id = $3, claimed_at = $4, updated_at = $4
                    WHERE id = $1
                    RETURNING {self._TASK_COLUMNS}
                    """,
                    task_id,
                    TaskStatus.WORKING.value,
                    worker_id,
                    now,
                )

        # Notify after COMMIT and after returning the connection: handlers then
        # see committed state and don't hold the row lock or a pool slot.
        task = self._row_to_task(updated_row)
        event = TaskStatusUpdateEvent(task=task, message=None, final=False)
        await self._notify_handlers(task_id, event)
        logger.info(f'Task {task_id} claimed by worker {worker_id}')

        return task

    async def release_task(
        self, task_id: str, worker_id: str
    ) -> Optional[Task]:
        """
        Release a claimed task back to pending status using database transactions.

        Uses SELECT FOR UPDATE to lock the row and ensure atomic release.
        Handlers are notified only after the transaction has committed.

        Args:
            task_id: The ID of the task to release
            worker_id: The ID of the worker releasing the task

        Returns:
            The released Task if successful; None if the task doesn't exist,
            the worker_id doesn't match the claiming worker, or the task is
            not in working status.
        """
        pool = await self._get_pool()
        now = self._utcnow()

        async with pool.acquire() as conn:
            async with conn.transaction():
                row = await conn.fetchrow(
                    f"""
                    SELECT {self._TASK_COLUMNS}
                    FROM a2a_tasks
                    WHERE id = $1
                    FOR UPDATE
                    """,
                    task_id,
                )

                if row is None:
                    return None

                # Check if worker owns this task
                if row['worker_id'] != worker_id:
                    logger.warning(
                        f'Worker {worker_id} attempted to release task {task_id} '
                        f'but it is owned by {row["worker_id"]}'
                    )
                    return None

                # Check if task is in working status
                if row['status'] != TaskStatus.WORKING.value:
                    logger.debug(
                        f'Task {task_id} cannot be released: status is {row["status"]}'
                    )
                    return None

                updated_row = await conn.fetchrow(
                    f"""
                    UPDATE a2a_tasks
                    SET status = $2, worker_id = NULL, claimed_at = NULL, updated_at = $3
                    WHERE id = $1
                    RETURNING {self._TASK_COLUMNS}
                    """,
                    task_id,
                    TaskStatus.PENDING.value,
                    now,
                )

        # Notify after COMMIT and after returning the connection (see claim_task).
        task = self._row_to_task(updated_row)
        event = TaskStatusUpdateEvent(task=task, message=None, final=False)
        await self._notify_handlers(task_id, event)
        logger.info(f'Task {task_id} released by worker {worker_id}')

        return task

    async def cleanup(self) -> None:
        """Close the connection pool, if one was created."""
        # Detach first so concurrent callers don't pick up a closing pool.
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()
|