codetether 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_server/__init__.py +29 -0
- a2a_server/a2a_agent_card.py +365 -0
- a2a_server/a2a_errors.py +1133 -0
- a2a_server/a2a_executor.py +926 -0
- a2a_server/a2a_router.py +1033 -0
- a2a_server/a2a_types.py +344 -0
- a2a_server/agent_card.py +408 -0
- a2a_server/agents_server.py +271 -0
- a2a_server/auth_api.py +349 -0
- a2a_server/billing_api.py +638 -0
- a2a_server/billing_service.py +712 -0
- a2a_server/billing_webhooks.py +501 -0
- a2a_server/config.py +96 -0
- a2a_server/database.py +2165 -0
- a2a_server/email_inbound.py +398 -0
- a2a_server/email_notifications.py +486 -0
- a2a_server/enhanced_agents.py +919 -0
- a2a_server/enhanced_server.py +160 -0
- a2a_server/hosted_worker.py +1049 -0
- a2a_server/integrated_agents_server.py +347 -0
- a2a_server/keycloak_auth.py +750 -0
- a2a_server/livekit_bridge.py +439 -0
- a2a_server/marketing_tools.py +1364 -0
- a2a_server/mcp_client.py +196 -0
- a2a_server/mcp_http_server.py +2256 -0
- a2a_server/mcp_server.py +191 -0
- a2a_server/message_broker.py +725 -0
- a2a_server/mock_mcp.py +273 -0
- a2a_server/models.py +494 -0
- a2a_server/monitor_api.py +5904 -0
- a2a_server/opencode_bridge.py +1594 -0
- a2a_server/redis_task_manager.py +518 -0
- a2a_server/server.py +726 -0
- a2a_server/task_manager.py +668 -0
- a2a_server/task_queue.py +742 -0
- a2a_server/tenant_api.py +333 -0
- a2a_server/tenant_middleware.py +219 -0
- a2a_server/tenant_service.py +760 -0
- a2a_server/user_auth.py +721 -0
- a2a_server/vault_client.py +576 -0
- a2a_server/worker_sse.py +873 -0
- agent_worker/__init__.py +8 -0
- agent_worker/worker.py +4877 -0
- codetether/__init__.py +10 -0
- codetether/__main__.py +4 -0
- codetether/cli.py +112 -0
- codetether/worker_cli.py +57 -0
- codetether-1.2.2.dist-info/METADATA +570 -0
- codetether-1.2.2.dist-info/RECORD +66 -0
- codetether-1.2.2.dist-info/WHEEL +5 -0
- codetether-1.2.2.dist-info/entry_points.txt +4 -0
- codetether-1.2.2.dist-info/licenses/LICENSE +202 -0
- codetether-1.2.2.dist-info/top_level.txt +5 -0
- codetether_voice_agent/__init__.py +6 -0
- codetether_voice_agent/agent.py +445 -0
- codetether_voice_agent/codetether_mcp.py +345 -0
- codetether_voice_agent/config.py +16 -0
- codetether_voice_agent/functiongemma_caller.py +380 -0
- codetether_voice_agent/session_playback.py +247 -0
- codetether_voice_agent/tools/__init__.py +21 -0
- codetether_voice_agent/tools/definitions.py +135 -0
- codetether_voice_agent/tools/handlers.py +380 -0
- run_server.py +314 -0
- ui/monitor-tailwind.html +1790 -0
- ui/monitor.html +1775 -0
- ui/monitor.js +2662 -0
a2a_server/worker_sse.py
ADDED
|
@@ -0,0 +1,873 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Worker SSE Task Stream - Push-based task distribution for A2A workers.
|
|
3
|
+
|
|
4
|
+
This module implements "reverse-polling" via Server-Sent Events (SSE) where workers
|
|
5
|
+
connect outbound to the server and receive task notifications pushed to them.
|
|
6
|
+
|
|
7
|
+
Design:
|
|
8
|
+
- Workers connect to GET /v1/worker/tasks/stream with their agent_name
|
|
9
|
+
- Server maintains a registry of connected workers
|
|
10
|
+
- When a task is created, server pushes it to an available connected worker
|
|
11
|
+
- Task claiming ensures only one worker gets each task
|
|
12
|
+
- Heartbeat/keepalive every 30 seconds
|
|
13
|
+
|
|
14
|
+
Security:
|
|
15
|
+
- Workers identify themselves via agent_name (query param or header)
|
|
16
|
+
- Optional Bearer token authentication via A2A_AUTH_TOKENS env var
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import asyncio
|
|
20
|
+
import json
|
|
21
|
+
import logging
|
|
22
|
+
import os
|
|
23
|
+
import uuid
|
|
24
|
+
from dataclasses import dataclass, field
|
|
25
|
+
from datetime import datetime, timezone
|
|
26
|
+
from typing import Any, Callable, Dict, List, Optional, Set
|
|
27
|
+
from functools import lru_cache
|
|
28
|
+
|
|
29
|
+
from fastapi import APIRouter, HTTPException, Request, Query, Header
|
|
30
|
+
from fastapi.responses import StreamingResponse, JSONResponse
|
|
31
|
+
from pydantic import BaseModel
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
# Router for worker SSE endpoints
|
|
36
|
+
worker_sse_router = APIRouter(prefix='/v1/worker', tags=['worker-sse'])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class ConnectedWorker:
    """Represents a worker connected via SSE.

    One instance exists per live SSE connection. The server pushes event
    dicts into ``queue``; the streaming response generator drains them.
    """

    worker_id: str
    agent_name: str
    # Per-connection event queue drained by the SSE generator.
    queue: asyncio.Queue
    connected_at: datetime
    last_heartbeat: datetime
    # Capability tags advertised by the worker (X-Capabilities header).
    capabilities: List[str] = field(default_factory=list)
    # Codebase IDs this worker accepts tasks for (X-Codebases header).
    codebases: Set[str] = field(default_factory=set)
    is_busy: bool = False  # True when worker is processing a task
    current_task_id: Optional[str] = None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class WorkerRegistry:
    """
    Registry of workers connected via SSE for push-based task distribution.

    Thread-safe via asyncio locks. Supports:
    - Worker registration/deregistration on SSE connect/disconnect
    - Task routing to available workers
    - Atomic task claiming to prevent double-assignment
    - Heartbeat tracking for connection health
    """

    def __init__(self):
        # worker_id -> ConnectedWorker for every live SSE connection.
        self._workers: Dict[str, ConnectedWorker] = {}
        # Guards all mutable registry state; every coroutine below takes it.
        self._lock = asyncio.Lock()
        # Track claimed tasks: task_id -> worker_id
        self._claimed_tasks: Dict[str, str] = {}
        # Callbacks for task creation events
        self._task_listeners: List[Callable] = []

    async def register_worker(
        self,
        worker_id: str,
        agent_name: str,
        queue: asyncio.Queue,
        capabilities: Optional[List[str]] = None,
        codebases: Optional[Set[str]] = None,
    ) -> ConnectedWorker:
        """Register a new SSE-connected worker.

        Re-registering an existing worker_id replaces the previous entry
        (last connection wins).
        """
        async with self._lock:
            now = datetime.now(timezone.utc)
            worker = ConnectedWorker(
                worker_id=worker_id,
                agent_name=agent_name,
                queue=queue,
                connected_at=now,
                last_heartbeat=now,
                capabilities=capabilities or [],
                codebases=codebases or set(),
            )
            self._workers[worker_id] = worker
            logger.info(
                f"Worker '{agent_name}' (id={worker_id}) connected via SSE. "
                f'Total connected: {len(self._workers)}'
            )
            return worker

    async def unregister_worker(
        self, worker_id: str
    ) -> Optional[ConnectedWorker]:
        """Unregister a disconnected worker.

        Any tasks still claimed by this worker are un-claimed so another
        worker can pick them up. Returns the removed worker, or None if
        the worker_id was unknown.
        """
        async with self._lock:
            worker = self._workers.pop(worker_id, None)
            if worker:
                # Release any claimed tasks back to pending
                tasks_to_release = [
                    tid
                    for tid, wid in self._claimed_tasks.items()
                    if wid == worker_id
                ]
                for tid in tasks_to_release:
                    del self._claimed_tasks[tid]
                    logger.warning(
                        f'Task {tid} released due to worker disconnect (worker={worker_id})'
                    )

                logger.info(
                    f"Worker '{worker.agent_name}' (id={worker_id}) disconnected. "
                    f'Total connected: {len(self._workers)}'
                )
            return worker

    async def update_heartbeat(self, worker_id: str) -> bool:
        """Update the last heartbeat time for a worker.

        Returns False if the worker is not (or no longer) registered.
        """
        async with self._lock:
            worker = self._workers.get(worker_id)
            if worker:
                worker.last_heartbeat = datetime.now(timezone.utc)
                return True
            return False

    async def update_worker_codebases(
        self, worker_id: str, codebases: Set[str]
    ) -> bool:
        """Replace the set of codebases a worker can handle.

        Returns False if the worker is not registered.
        """
        async with self._lock:
            worker = self._workers.get(worker_id)
            if worker:
                worker.codebases = codebases
                return True
            return False

    async def claim_task(self, task_id: str, worker_id: str) -> bool:
        """
        Atomically claim a task for a worker.

        Returns True if claim succeeded, False if task was already claimed.
        Claiming is idempotent for the same worker: re-claiming a task the
        worker already holds returns True.
        """
        async with self._lock:
            if task_id in self._claimed_tasks:
                existing_worker = self._claimed_tasks[task_id]
                if existing_worker == worker_id:
                    return True  # Already claimed by this worker
                return False  # Claimed by another worker

            # Unknown workers (not connected via SSE) cannot claim.
            worker = self._workers.get(worker_id)
            if not worker:
                return False

            self._claimed_tasks[task_id] = worker_id
            worker.is_busy = True
            worker.current_task_id = task_id
            logger.info(f'Task {task_id} claimed by worker {worker_id}')
            return True

    async def release_task(self, task_id: str, worker_id: str) -> bool:
        """Release a task claim (on completion or failure).

        Only the worker that holds the claim may release it; otherwise
        returns False.
        """
        async with self._lock:
            if self._claimed_tasks.get(task_id) == worker_id:
                del self._claimed_tasks[task_id]
                worker = self._workers.get(worker_id)
                if worker:
                    worker.is_busy = False
                    worker.current_task_id = None
                logger.info(f'Task {task_id} released by worker {worker_id}')
                return True
            return False

    async def get_available_workers(
        self,
        codebase_id: Optional[str] = None,
        required_capabilities: Optional[List[str]] = None,
        target_agent_name: Optional[str] = None,
    ) -> List[ConnectedWorker]:
        """
        Get workers available to accept a new task.

        Filters by:
        - Not currently busy
        - Handles the specified codebase (workers must explicitly register codebases)
        - Optionally: has required capabilities
        - Optionally: matches target_agent_name (for agent-targeted routing)

        IMPORTANT: Workers with no registered codebases will ONLY receive
        'global' or '__pending__' tasks. This prevents cross-server task leakage
        where a worker picks up tasks for codebases it doesn't have access to.

        Agent Targeting:
        - If target_agent_name is set, ONLY notify workers with that agent_name
        - This reduces noise/wakeups for targeted tasks
        - Claim-time filtering is the real enforcement; this is for efficiency
        """
        async with self._lock:
            available = []
            for worker in self._workers.values():
                if worker.is_busy:
                    continue

                # Agent targeting filter (notify-time filtering for efficiency)
                # If task is targeted at a specific agent, only notify that agent
                if target_agent_name:
                    if worker.agent_name != target_agent_name:
                        continue

                # Check codebase filter
                if codebase_id:
                    # Special codebase IDs that any worker can handle
                    if codebase_id in ('global', '__pending__'):
                        pass  # Any worker can handle these
                    elif codebase_id in worker.codebases:
                        pass  # Worker explicitly registered this codebase
                    else:
                        # Task is for a specific codebase the worker doesn't have
                        # Skip this worker even if it has no codebases registered
                        # (empty codebases does NOT mean "can handle anything")
                        continue

                # Check capabilities filter: worker must have ALL required caps.
                if required_capabilities:
                    if not all(
                        cap in worker.capabilities
                        for cap in required_capabilities
                    ):
                        continue

                available.append(worker)

            return available

    async def get_worker(self, worker_id: str) -> Optional[ConnectedWorker]:
        """Get a specific worker by ID, or None if not connected."""
        async with self._lock:
            return self._workers.get(worker_id)

    async def list_workers(self) -> List[Dict[str, Any]]:
        """List all connected workers as JSON-serializable dicts."""
        async with self._lock:
            return [
                {
                    'worker_id': w.worker_id,
                    'agent_name': w.agent_name,
                    'connected_at': w.connected_at.isoformat(),
                    'last_heartbeat': w.last_heartbeat.isoformat(),
                    'is_busy': w.is_busy,
                    'current_task_id': w.current_task_id,
                    'capabilities': w.capabilities,
                    'codebases': list(w.codebases),
                }
                for w in self._workers.values()
            ]

    async def push_task_to_worker(
        self,
        worker_id: str,
        task: Dict[str, Any],
    ) -> bool:
        """
        Push a task notification to a specific worker.

        Returns True if the message was queued successfully. The queue is
        unbounded, so put() does not block the lock.
        """
        async with self._lock:
            worker = self._workers.get(worker_id)
            if not worker:
                return False

            try:
                event = {
                    'event': 'task_available',
                    'data': task,
                    'timestamp': datetime.now(timezone.utc).isoformat(),
                }
                await worker.queue.put(event)
                return True
            except Exception as e:
                logger.error(f'Failed to push task to worker {worker_id}: {e}')
                return False

    async def broadcast_task(
        self,
        task: Dict[str, Any],
        codebase_id: Optional[str] = None,
        target_agent_name: Optional[str] = None,
        required_capabilities: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Broadcast a task to all available workers that can handle it.

        For targeted tasks (target_agent_name set), only notifies the specific agent.
        This is notify-time filtering for efficiency; claim-time is the real enforcement.

        Returns list of worker_ids that received the notification.
        """
        # NOTE: selection and push take the lock separately; a worker may
        # become busy in between. Claim-time checks resolve such races.
        available = await self.get_available_workers(
            codebase_id=codebase_id,
            target_agent_name=target_agent_name,
            required_capabilities=required_capabilities,
        )
        notified = []

        for worker in available:
            if await self.push_task_to_worker(worker.worker_id, task):
                notified.append(worker.worker_id)

        routing_info = ''
        if target_agent_name:
            routing_info = f' (targeted at {target_agent_name})'

        logger.info(
            f'Task {task.get("id", "unknown")} broadcast to {len(notified)} workers{routing_info}'
        )
        return notified

    def add_task_listener(self, callback: Callable) -> None:
        """Add a callback to be notified when tasks are created."""
        self._task_listeners.append(callback)

    def remove_task_listener(self, callback: Callable) -> None:
        """Remove a task listener callback (no-op if not registered)."""
        if callback in self._task_listeners:
            self._task_listeners.remove(callback)

    async def notify_task_created(self, task: Dict[str, Any]) -> None:
        """Notify all listeners that a new task was created.

        Listener exceptions are logged and swallowed so one bad listener
        cannot block the others.
        """
        for callback in self._task_listeners:
            try:
                # Support both sync and async listener callbacks.
                if asyncio.iscoroutinefunction(callback):
                    await callback(task)
                else:
                    callback(task)
            except Exception as e:
                logger.error(f'Error in task listener: {e}')
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
# Global worker registry singleton (lazily created on first access)
_worker_registry: Optional[WorkerRegistry] = None


def get_worker_registry() -> WorkerRegistry:
    """Return the process-wide WorkerRegistry, creating it on first use."""
    global _worker_registry
    registry = _worker_registry
    if registry is None:
        registry = _worker_registry = WorkerRegistry()
    return registry
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
@lru_cache(maxsize=1)
def _get_auth_tokens_set() -> set:
    """Parse A2A_AUTH_TOKENS into the set of accepted token values.

    The variable is a comma-separated list of entries; each entry is either
    a bare token or a ``name:token`` pair, in which case only the token part
    after the first colon is kept. The result is cached for the lifetime of
    the process (the environment is read once).
    """
    configured = os.environ.get('A2A_AUTH_TOKENS')
    if not configured:
        return set()

    tokens: set = set()
    for entry in (part.strip() for part in configured.split(',')):
        if not entry:
            continue
        # "name:token" form -> keep only the secret after the first colon;
        # otherwise the whole entry is the token.
        value = entry.split(':', 1)[1].strip() if ':' in entry else entry
        if value:
            tokens.add(value)
    return tokens
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def _verify_auth(request: Request) -> Optional[str]:
    """
    Verify Bearer token if authentication is configured.

    Returns the token if valid, raises HTTPException if invalid,
    returns None if no auth is configured.

    Raises:
        HTTPException(401): Authorization header missing or not Bearer.
        HTTPException(403): token does not match any configured token.
    """
    import hmac  # local import: constant-time comparison for the fix below

    tokens = _get_auth_tokens_set()
    if not tokens:
        return None  # Auth not configured, allow all

    auth = (
        request.headers.get('authorization')
        or request.headers.get('Authorization')
        or ''
    )
    if not auth.startswith('Bearer '):
        raise HTTPException(status_code=401, detail='Missing Bearer token')

    token = auth.removeprefix('Bearer ').strip()
    # FIX: compare tokens in constant time (hmac.compare_digest) instead of
    # plain equality, so string comparison timing cannot leak token prefixes.
    if not token or not any(
        hmac.compare_digest(token, candidate) for candidate in tokens
    ):
        raise HTTPException(status_code=403, detail='Invalid token')

    return token
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
class TaskClaimRequest(BaseModel):
    """Body for POST /v1/worker/tasks/claim: identifies the task to claim."""

    # ID of the task the worker wants to claim atomically.
    task_id: str
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class TaskReleaseRequest(BaseModel):
    """Body for POST /v1/worker/tasks/release: reports a finished task."""

    # ID of the task whose claim is being released.
    task_id: str
    status: str = 'completed'  # completed, failed, cancelled
    # Optional result payload (on success).
    result: Optional[str] = None
    # Optional error detail (on failure).
    error: Optional[str] = None
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class CodebaseUpdateRequest(BaseModel):
    """Body for PUT /v1/worker/codebases: replaces the worker's codebase list."""

    # Full replacement list of codebase IDs this worker can handle.
    codebases: List[str]
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
@worker_sse_router.get('/tasks/stream')
async def worker_task_stream(
    request: Request,
    agent_name: Optional[str] = Query(None, description='Worker agent name'),
    worker_id: Optional[str] = Query(
        None, description='Worker ID (optional, generated if not provided)'
    ),
    x_agent_name: Optional[str] = Header(None, alias='X-Agent-Name'),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
    x_capabilities: Optional[str] = Header(None, alias='X-Capabilities'),
    x_codebases: Optional[str] = Header(None, alias='X-Codebases'),
):
    """
    SSE endpoint for workers to receive task notifications.

    Workers connect to this endpoint and receive:
    - `task_available` events when new tasks are created
    - `heartbeat` events every 30 seconds
    - `task_claimed` confirmation when a task is successfully claimed

    Headers:
    - Authorization: Bearer <token> (required if A2A_AUTH_TOKENS is set)
    - X-Agent-Name: Worker's agent name (alternative to query param)
    - X-Worker-ID: Stable worker ID (optional)
    - X-Capabilities: Comma-separated list of capabilities
    - X-Codebases: Comma-separated list of codebase IDs this worker handles

    Query params:
    - agent_name: Worker's agent name
    - worker_id: Stable worker ID (optional)

    Events sent to worker:
    ```
    event: connected
    data: {"worker_id": "...", "message": "Connected to task stream"}

    event: task_available
    data: {"id": "...", "title": "...", "codebase_id": "...", ...}

    event: heartbeat
    data: {"timestamp": "..."}
    ```
    """
    # Verify authentication (raises 401/403 on failure)
    _verify_auth(request)

    # Resolve agent_name from query param or header
    resolved_agent_name = agent_name or x_agent_name
    if not resolved_agent_name:
        raise HTTPException(
            status_code=400,
            detail='agent_name is required (query param or X-Agent-Name header)',
        )

    # Resolve worker_id (generate a short random one if not provided)
    resolved_worker_id = worker_id or x_worker_id or str(uuid.uuid4())[:12]

    # Parse capabilities and codebases from headers
    capabilities = []
    if x_capabilities:
        capabilities = [
            c.strip() for c in x_capabilities.split(',') if c.strip()
        ]

    codebases = set()
    if x_codebases:
        codebases = {c.strip() for c in x_codebases.split(',') if c.strip()}

    registry = get_worker_registry()

    async def event_generator():
        """Generate SSE events for the connected worker."""
        queue: asyncio.Queue = asyncio.Queue()
        worker = None

        try:
            # Register this worker so tasks can be pushed to its queue
            worker = await registry.register_worker(
                worker_id=resolved_worker_id,
                agent_name=resolved_agent_name,
                queue=queue,
                capabilities=capabilities,
                codebases=codebases,
            )

            # Send connection confirmation
            connect_event = {
                'event': 'connected',
                'worker_id': resolved_worker_id,
                'agent_name': resolved_agent_name,
                'message': 'Connected to task stream',
                'timestamp': datetime.now(timezone.utc).isoformat(),
            }
            yield f'event: connected\ndata: {json.dumps(connect_event)}\n\n'

            # Send any pending tasks to the newly connected worker
            # (best-effort: failures here must not kill the stream)
            try:
                from .monitor_api import get_opencode_bridge
                from .opencode_bridge import AgentTaskStatus

                bridge = get_opencode_bridge()
                if bridge:
                    # Get all pending tasks
                    pending_tasks = await bridge.list_tasks(
                        status=AgentTaskStatus.PENDING
                    )

                    # Filter tasks that this worker can handle (based on codebases)
                    worker_codebases = codebases or set()
                    sent_count = 0

                    for task in pending_tasks:
                        task_codebase = task.codebase_id
                        # Worker can handle task if:
                        # 1. Task has no specific codebase (global/__pending__)
                        # 2. Worker has the task's codebase in their list
                        # NOTE: Workers with no codebases can ONLY handle global/__pending__ tasks
                        # This prevents cross-server task leakage
                        can_handle = (
                            task_codebase in ('__pending__', 'global')
                            or task_codebase in worker_codebases
                        )

                        if can_handle:
                            task_data = {
                                'id': task.id,
                                'codebase_id': task.codebase_id,
                                'title': task.title,
                                'prompt': task.prompt,
                                'agent_type': task.agent_type,
                                'priority': task.priority,
                                'metadata': task.metadata,
                                'model': task.model,
                                'created_at': task.created_at.isoformat()
                                if task.created_at
                                else None,
                            }
                            yield f'event: task_available\ndata: {json.dumps(task_data)}\n\n'
                            sent_count += 1

                    if sent_count > 0:
                        logger.info(
                            f'Sent {sent_count} pending tasks to worker {resolved_worker_id} on connect'
                        )
            except Exception as e:
                logger.warning(f'Failed to send pending tasks on connect: {e}')

            # Main event loop
            heartbeat_interval = 30  # seconds
            # FIX: asyncio.get_event_loop() inside a coroutine is deprecated
            # since Python 3.10 (warns on 3.12+); use the running loop's
            # monotonic clock, hoisted once outside the loop.
            loop = asyncio.get_running_loop()
            last_heartbeat = loop.time()

            while True:
                # Check if client disconnected
                if await request.is_disconnected():
                    logger.info(
                        f'Worker {resolved_worker_id} client disconnected'
                    )
                    break

                try:
                    # Wait for events with timeout so heartbeats stay on time
                    current_time = loop.time()
                    timeout = max(
                        0.1,
                        heartbeat_interval - (current_time - last_heartbeat),
                    )

                    try:
                        event = await asyncio.wait_for(
                            queue.get(), timeout=timeout
                        )

                        # Format and send the event
                        event_type = event.get('event', 'message')
                        event_data = event.get('data', event)

                        yield f'event: {event_type}\ndata: {json.dumps(event_data)}\n\n'

                    except asyncio.TimeoutError:
                        pass  # No event, check if heartbeat needed

                    # Send heartbeat if interval elapsed
                    current_time = loop.time()
                    if current_time - last_heartbeat >= heartbeat_interval:
                        heartbeat_data = {
                            'timestamp': datetime.now(timezone.utc).isoformat(),
                            'worker_id': resolved_worker_id,
                        }
                        yield f'event: heartbeat\ndata: {json.dumps(heartbeat_data)}\n\n'
                        last_heartbeat = current_time
                        await registry.update_heartbeat(resolved_worker_id)

                except asyncio.CancelledError:
                    logger.info(f'Worker {resolved_worker_id} stream cancelled')
                    break
                except Exception as e:
                    logger.error(
                        f'Error in worker stream {resolved_worker_id}: {e}'
                    )
                    break

        finally:
            # Unregister worker on disconnect
            if worker:
                await registry.unregister_worker(resolved_worker_id)

    return StreamingResponse(
        event_generator(),
        media_type='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            # Disable nginx proxy buffering so events flush immediately
            'X-Accel-Buffering': 'no',
        },
    )
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
@worker_sse_router.post('/tasks/claim')
async def claim_task(
    request: Request,
    claim: TaskClaimRequest,
    worker_id: Optional[str] = Query(None),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
):
    """
    Claim a task for processing.

    Workers call this after receiving a task_available event; claiming is
    atomic so only one worker ever processes a given task.

    Returns 200 on success; 409 if another worker already holds the claim.
    """
    _verify_auth(request)

    wid = worker_id or x_worker_id
    if not wid:
        raise HTTPException(
            status_code=400,
            detail='worker_id is required (query param or X-Worker-ID header)',
        )

    claimed = await get_worker_registry().claim_task(claim.task_id, wid)
    if not claimed:
        raise HTTPException(
            status_code=409,
            detail=f'Task {claim.task_id} already claimed by another worker',
        )

    return {
        'success': True,
        'task_id': claim.task_id,
        'worker_id': wid,
        'message': 'Task claimed successfully',
    }
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
@worker_sse_router.post('/tasks/release')
async def release_task(
    request: Request,
    release: TaskReleaseRequest,
    worker_id: Optional[str] = Query(None),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
):
    """
    Release a task after completion or failure.

    Only the worker holding the claim may release it; the reported status
    is echoed back in the response. 404 if this worker holds no such claim.
    """
    _verify_auth(request)

    wid = worker_id or x_worker_id
    if not wid:
        raise HTTPException(
            status_code=400,
            detail='worker_id is required (query param or X-Worker-ID header)',
        )

    released = await get_worker_registry().release_task(release.task_id, wid)
    if not released:
        raise HTTPException(
            status_code=404,
            detail=f'Task {release.task_id} not claimed by worker {wid}',
        )

    return {
        'success': True,
        'task_id': release.task_id,
        'worker_id': wid,
        'status': release.status,
        'message': 'Task released successfully',
    }
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
@worker_sse_router.put('/codebases')
async def update_worker_codebases(
    request: Request,
    update: CodebaseUpdateRequest,
    worker_id: Optional[str] = Query(None),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
):
    """
    Replace the set of codebases the calling worker can handle.

    Workers call this after registering new codebases so the server's
    routing table stays current. 404 if the worker has no live SSE
    connection.
    """
    _verify_auth(request)

    wid = worker_id or x_worker_id
    if not wid:
        raise HTTPException(
            status_code=400,
            detail='worker_id is required (query param or X-Worker-ID header)',
        )

    updated = await get_worker_registry().update_worker_codebases(
        wid, set(update.codebases)
    )
    if not updated:
        raise HTTPException(
            status_code=404,
            detail=f'Worker {wid} not found (not connected via SSE?)',
        )

    return {
        'success': True,
        'worker_id': wid,
        'codebases': update.codebases,
        'message': 'Codebases updated successfully',
    }
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
@worker_sse_router.get('/connected')
async def list_connected_workers(request: Request):
    """Report every worker currently holding an SSE connection.

    The response carries the per-worker records from the registry
    (capabilities, codebases, status), plus a count and a UTC timestamp
    marking when the snapshot was taken.
    """
    _verify_auth(request)

    connected = await get_worker_registry().list_workers()

    return {
        'workers': connected,
        'count': len(connected),
        'timestamp': datetime.now(timezone.utc).isoformat(),
    }
|
|
789
|
+
|
|
790
|
+
|
|
791
|
+
@worker_sse_router.get('/connected/{worker_id}')
async def get_connected_worker(
    request: Request,
    worker_id: str,
):
    """Return status details for a single SSE-connected worker.

    Raises:
        HTTPException: 404 when the registry has no entry for
            ``worker_id``.
    """
    _verify_auth(request)

    info = await get_worker_registry().get_worker(worker_id)
    if not info:
        raise HTTPException(
            status_code=404,
            detail=f'Worker {worker_id} not found or not connected',
        )

    # Timestamps are serialized to ISO-8601; codebases is a set, so it
    # is converted to a list for JSON.
    return {
        'worker_id': info.worker_id,
        'agent_name': info.agent_name,
        'connected_at': info.connected_at.isoformat(),
        'last_heartbeat': info.last_heartbeat.isoformat(),
        'is_busy': info.is_busy,
        'current_task_id': info.current_task_id,
        'capabilities': info.capabilities,
        'codebases': list(info.codebases),
    }
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
# ============================================================================
|
|
821
|
+
# Integration helpers for task creation
|
|
822
|
+
# ============================================================================
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
async def notify_workers_of_new_task(task: Dict[str, Any]) -> List[str]:
    """Push a newly created task to eligible SSE-connected workers.

    Intended to be called right after task creation. Routing rules:

    - ``target_agent_name`` on the task restricts delivery to that agent.
    - ``required_capabilities`` restricts delivery to workers advertising
      those capabilities.

    Returns:
        The worker_ids that received the notification.
    """
    caps = task.get('required_capabilities')
    if isinstance(caps, str):
        # Some task records store capabilities as a JSON string;
        # decode best-effort and fall back to no capability filter.
        import json

        try:
            caps = json.loads(caps)
        except (json.JSONDecodeError, TypeError):
            caps = None

    return await get_worker_registry().broadcast_task(
        task,
        codebase_id=task.get('codebase_id'),
        target_agent_name=task.get('target_agent_name'),
        required_capabilities=caps,
    )
|
|
858
|
+
|
|
859
|
+
|
|
860
|
+
def setup_task_creation_hook(opencode_bridge) -> None:
    """Wire task-creation events into the SSE push system.

    Call once during server initialization; afterwards every created
    task is forwarded to connected workers through
    ``notify_workers_of_new_task``.

    NOTE(review): ``opencode_bridge`` is accepted for call-site
    compatibility but is not referenced in this function — confirm
    whether it is still needed.
    """
    registry = get_worker_registry()

    async def _forward(task: Dict[str, Any]):
        await notify_workers_of_new_task(task)

    registry.add_task_listener(_forward)
    logger.info('Task creation hook installed for SSE worker notifications')
|