codetether 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. a2a_server/__init__.py +29 -0
  2. a2a_server/a2a_agent_card.py +365 -0
  3. a2a_server/a2a_errors.py +1133 -0
  4. a2a_server/a2a_executor.py +926 -0
  5. a2a_server/a2a_router.py +1033 -0
  6. a2a_server/a2a_types.py +344 -0
  7. a2a_server/agent_card.py +408 -0
  8. a2a_server/agents_server.py +271 -0
  9. a2a_server/auth_api.py +349 -0
  10. a2a_server/billing_api.py +638 -0
  11. a2a_server/billing_service.py +712 -0
  12. a2a_server/billing_webhooks.py +501 -0
  13. a2a_server/config.py +96 -0
  14. a2a_server/database.py +2165 -0
  15. a2a_server/email_inbound.py +398 -0
  16. a2a_server/email_notifications.py +486 -0
  17. a2a_server/enhanced_agents.py +919 -0
  18. a2a_server/enhanced_server.py +160 -0
  19. a2a_server/hosted_worker.py +1049 -0
  20. a2a_server/integrated_agents_server.py +347 -0
  21. a2a_server/keycloak_auth.py +750 -0
  22. a2a_server/livekit_bridge.py +439 -0
  23. a2a_server/marketing_tools.py +1364 -0
  24. a2a_server/mcp_client.py +196 -0
  25. a2a_server/mcp_http_server.py +2256 -0
  26. a2a_server/mcp_server.py +191 -0
  27. a2a_server/message_broker.py +725 -0
  28. a2a_server/mock_mcp.py +273 -0
  29. a2a_server/models.py +494 -0
  30. a2a_server/monitor_api.py +5904 -0
  31. a2a_server/opencode_bridge.py +1594 -0
  32. a2a_server/redis_task_manager.py +518 -0
  33. a2a_server/server.py +726 -0
  34. a2a_server/task_manager.py +668 -0
  35. a2a_server/task_queue.py +742 -0
  36. a2a_server/tenant_api.py +333 -0
  37. a2a_server/tenant_middleware.py +219 -0
  38. a2a_server/tenant_service.py +760 -0
  39. a2a_server/user_auth.py +721 -0
  40. a2a_server/vault_client.py +576 -0
  41. a2a_server/worker_sse.py +873 -0
  42. agent_worker/__init__.py +8 -0
  43. agent_worker/worker.py +4877 -0
  44. codetether/__init__.py +10 -0
  45. codetether/__main__.py +4 -0
  46. codetether/cli.py +112 -0
  47. codetether/worker_cli.py +57 -0
  48. codetether-1.2.2.dist-info/METADATA +570 -0
  49. codetether-1.2.2.dist-info/RECORD +66 -0
  50. codetether-1.2.2.dist-info/WHEEL +5 -0
  51. codetether-1.2.2.dist-info/entry_points.txt +4 -0
  52. codetether-1.2.2.dist-info/licenses/LICENSE +202 -0
  53. codetether-1.2.2.dist-info/top_level.txt +5 -0
  54. codetether_voice_agent/__init__.py +6 -0
  55. codetether_voice_agent/agent.py +445 -0
  56. codetether_voice_agent/codetether_mcp.py +345 -0
  57. codetether_voice_agent/config.py +16 -0
  58. codetether_voice_agent/functiongemma_caller.py +380 -0
  59. codetether_voice_agent/session_playback.py +247 -0
  60. codetether_voice_agent/tools/__init__.py +21 -0
  61. codetether_voice_agent/tools/definitions.py +135 -0
  62. codetether_voice_agent/tools/handlers.py +380 -0
  63. run_server.py +314 -0
  64. ui/monitor-tailwind.html +1790 -0
  65. ui/monitor.html +1775 -0
  66. ui/monitor.js +2662 -0
@@ -0,0 +1,873 @@
1
+ """
2
+ Worker SSE Task Stream - Push-based task distribution for A2A workers.
3
+
4
+ This module implements "reverse-polling" via Server-Sent Events (SSE) where workers
5
+ connect outbound to the server and receive task notifications pushed to them.
6
+
7
+ Design:
8
+ - Workers connect to GET /v1/worker/tasks/stream with their agent_name
9
+ - Server maintains a registry of connected workers
10
+ - When a task is created, server pushes it to an available connected worker
11
+ - Task claiming ensures only one worker gets each task
12
+ - Heartbeat/keepalive every 30 seconds
13
+
14
+ Security:
15
+ - Workers identify themselves via agent_name (query param or header)
16
+ - Optional Bearer token authentication via A2A_AUTH_TOKENS env var
17
+ """
18
+
19
+ import asyncio
20
+ import json
21
+ import logging
22
+ import os
23
+ import uuid
24
+ from dataclasses import dataclass, field
25
+ from datetime import datetime, timezone
26
+ from typing import Any, Callable, Dict, List, Optional, Set
27
+ from functools import lru_cache
28
+
29
+ from fastapi import APIRouter, HTTPException, Request, Query, Header
30
+ from fastapi.responses import StreamingResponse, JSONResponse
31
+ from pydantic import BaseModel
32
+
33
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Every endpoint declared on this router is served under /v1/worker and
# grouped under the 'worker-sse' tag.
worker_sse_router = APIRouter(
    prefix='/v1/worker',
    tags=['worker-sse'],
)
37
+
38
+
39
@dataclass
class ConnectedWorker:
    """Per-connection state the registry keeps for one SSE-attached worker."""

    # Identifier the worker registered under; the registry uses it as its key.
    worker_id: str
    # Agent name the worker announced; used for agent-targeted routing.
    agent_name: str
    # Outbound event queue drained by the worker's SSE stream.
    queue: asyncio.Queue
    # When the connection was registered.
    connected_at: datetime
    # When the last heartbeat was recorded for this connection.
    last_heartbeat: datetime
    # Capability names the worker advertised.
    capabilities: List[str] = field(default_factory=list)
    # Codebase IDs the worker explicitly registered.
    codebases: Set[str] = field(default_factory=set)
    # True while the worker is processing a task.
    is_busy: bool = False
    # ID of the task currently being processed, if any.
    current_task_id: Optional[str] = None
52
+
53
+
54
class WorkerRegistry:
    """
    Registry of workers connected via SSE for push-based task distribution.

    State is kept in plain dicts guarded by a single ``asyncio.Lock``. That
    makes the registry safe to share between coroutines running on one event
    loop; it is NOT safe to call from multiple OS threads.

    Supports:
    - Worker registration/deregistration on SSE connect/disconnect
    - Task routing to available workers
    - Atomic task claiming to prevent double-assignment
    - Heartbeat tracking for connection health
    """

    def __init__(self):
        # worker_id -> connection state
        self._workers: Dict[str, 'ConnectedWorker'] = {}
        self._lock = asyncio.Lock()
        # Track claimed tasks: task_id -> worker_id
        self._claimed_tasks: Dict[str, str] = {}
        # Callbacks for task creation events
        self._task_listeners: List[Callable] = []

    async def register_worker(
        self,
        worker_id: str,
        agent_name: str,
        queue: asyncio.Queue,
        capabilities: Optional[List[str]] = None,
        codebases: Optional[Set[str]] = None,
    ) -> 'ConnectedWorker':
        """
        Register a new SSE-connected worker.

        An existing entry with the same ``worker_id`` is replaced.

        Returns the ConnectedWorker record that was stored.
        """
        async with self._lock:
            now = datetime.now(timezone.utc)
            worker = ConnectedWorker(
                worker_id=worker_id,
                agent_name=agent_name,
                queue=queue,
                connected_at=now,
                last_heartbeat=now,
                capabilities=capabilities or [],
                codebases=codebases or set(),
            )
            self._workers[worker_id] = worker
            logger.info(
                f"Worker '{agent_name}' (id={worker_id}) connected via SSE. "
                f'Total connected: {len(self._workers)}'
            )
            return worker

    async def unregister_worker(
        self, worker_id: str
    ) -> Optional['ConnectedWorker']:
        """
        Unregister a disconnected worker and drop every claim it held.

        Returns the removed record, or None if the worker was not registered.
        """
        async with self._lock:
            worker = self._workers.pop(worker_id, None)
            if worker is None:
                return None

            # Collect first, then delete: the dict must not change size
            # while it is being iterated.
            orphaned = [
                tid
                for tid, owner in self._claimed_tasks.items()
                if owner == worker_id
            ]
            for tid in orphaned:
                del self._claimed_tasks[tid]
                logger.warning(
                    f'Task {tid} released due to worker disconnect (worker={worker_id})'
                )

            logger.info(
                f"Worker '{worker.agent_name}' (id={worker_id}) disconnected. "
                f'Total connected: {len(self._workers)}'
            )
            return worker

    async def update_heartbeat(self, worker_id: str) -> bool:
        """Stamp the worker's last heartbeat with the current UTC time.

        Returns False if the worker is not registered.
        """
        async with self._lock:
            worker = self._workers.get(worker_id)
            if worker is None:
                return False
            worker.last_heartbeat = datetime.now(timezone.utc)
            return True

    async def update_worker_codebases(
        self, worker_id: str, codebases: Set[str]
    ) -> bool:
        """Replace the set of codebases a worker can handle.

        Returns False if the worker is not registered.
        """
        async with self._lock:
            worker = self._workers.get(worker_id)
            if worker is None:
                return False
            worker.codebases = codebases
            return True

    async def claim_task(self, task_id: str, worker_id: str) -> bool:
        """
        Atomically claim a task for a worker.

        Returns True if the claim succeeded (or the same worker already holds
        it). Returns False if another worker holds the claim or if
        ``worker_id`` is not registered.
        """
        async with self._lock:
            holder = self._claimed_tasks.get(task_id)
            if holder is not None:
                # Re-claiming one's own task is idempotent.
                return holder == worker_id

            worker = self._workers.get(worker_id)
            if worker is None:
                return False

            self._claimed_tasks[task_id] = worker_id
            worker.is_busy = True
            worker.current_task_id = task_id
            logger.info(f'Task {task_id} claimed by worker {worker_id}')
            return True

    async def release_task(self, task_id: str, worker_id: str) -> bool:
        """
        Release a task claim (on completion or failure).

        The worker is only marked idle once it holds no further claims;
        claim_task() does not stop a busy worker from claiming more tasks, so
        releasing one of several must not make it look available.

        Returns False if ``task_id`` is not currently claimed by ``worker_id``.
        """
        async with self._lock:
            if self._claimed_tasks.get(task_id) != worker_id:
                return False

            del self._claimed_tasks[task_id]
            worker = self._workers.get(worker_id)
            if worker is not None:
                # Dicts preserve insertion order, so the last entry is the
                # most recent remaining claim.
                still_held = [
                    tid
                    for tid, owner in self._claimed_tasks.items()
                    if owner == worker_id
                ]
                worker.is_busy = bool(still_held)
                if worker.current_task_id == task_id:
                    worker.current_task_id = (
                        still_held[-1] if still_held else None
                    )
            logger.info(f'Task {task_id} released by worker {worker_id}')
            return True

    async def get_available_workers(
        self,
        codebase_id: Optional[str] = None,
        required_capabilities: Optional[List[str]] = None,
        target_agent_name: Optional[str] = None,
    ) -> List['ConnectedWorker']:
        """
        Get workers available to accept a new task.

        Filters by:
        - Not currently busy
        - Handles the specified codebase (workers must explicitly register codebases)
        - Optionally: has required capabilities
        - Optionally: matches target_agent_name (for agent-targeted routing)

        IMPORTANT: Workers with no registered codebases will ONLY receive
        'global' or '__pending__' tasks. This prevents cross-server task leakage
        where a worker picks up tasks for codebases it doesn't have access to.

        Agent Targeting:
        - If target_agent_name is set, ONLY notify workers with that agent_name
        - This reduces noise/wakeups for targeted tasks
        - This is notify-time filtering only; claim_task() in this class does
          not re-check agent name, codebase or capabilities
        """
        async with self._lock:
            available = []
            for worker in self._workers.values():
                if worker.is_busy:
                    continue

                # Agent targeting: a targeted task is only offered to workers
                # running under that agent name.
                if target_agent_name and worker.agent_name != target_agent_name:
                    continue

                # Codebase filter: 'global' and '__pending__' are open to
                # everyone; anything else needs an explicit registration.
                # An empty codebase set does NOT mean "can handle anything".
                if (
                    codebase_id
                    and codebase_id not in ('global', '__pending__')
                    and codebase_id not in worker.codebases
                ):
                    continue

                # Capability filter: every required capability must be present.
                if required_capabilities and not all(
                    cap in worker.capabilities
                    for cap in required_capabilities
                ):
                    continue

                available.append(worker)

            return available

    async def get_worker(self, worker_id: str) -> Optional['ConnectedWorker']:
        """Get a specific worker by ID, or None if it is not registered."""
        async with self._lock:
            return self._workers.get(worker_id)

    async def list_workers(self) -> List[Dict[str, Any]]:
        """List all connected workers as JSON-friendly dicts."""
        async with self._lock:
            return [
                {
                    'worker_id': w.worker_id,
                    'agent_name': w.agent_name,
                    'connected_at': w.connected_at.isoformat(),
                    'last_heartbeat': w.last_heartbeat.isoformat(),
                    'is_busy': w.is_busy,
                    'current_task_id': w.current_task_id,
                    'capabilities': w.capabilities,
                    'codebases': list(w.codebases),
                }
                for w in self._workers.values()
            ]

    async def push_task_to_worker(
        self,
        worker_id: str,
        task: Dict[str, Any],
    ) -> bool:
        """
        Push a task notification to a specific worker.

        Returns True if the message was queued successfully, False if the
        worker is unknown or queueing failed.
        """
        # Only the lookup needs the lock. Awaiting queue.put() while holding
        # it would stall every other registry operation if a bounded queue
        # were ever full.
        async with self._lock:
            worker = self._workers.get(worker_id)
        if worker is None:
            return False

        event = {
            'event': 'task_available',
            'data': task,
            'timestamp': datetime.now(timezone.utc).isoformat(),
        }
        try:
            await worker.queue.put(event)
        except Exception as e:
            logger.error(f'Failed to push task to worker {worker_id}: {e}')
            return False
        return True

    async def broadcast_task(
        self,
        task: Dict[str, Any],
        codebase_id: Optional[str] = None,
        target_agent_name: Optional[str] = None,
        required_capabilities: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Broadcast a task to all available workers that can handle it.

        For targeted tasks (target_agent_name set), only notifies the specific
        agent. Filtering happens via get_available_workers().

        Returns list of worker_ids that received the notification.
        """
        available = await self.get_available_workers(
            codebase_id=codebase_id,
            target_agent_name=target_agent_name,
            required_capabilities=required_capabilities,
        )
        notified = []

        for worker in available:
            if await self.push_task_to_worker(worker.worker_id, task):
                notified.append(worker.worker_id)

        routing_info = ''
        if target_agent_name:
            routing_info = f' (targeted at {target_agent_name})'

        logger.info(
            f'Task {task.get("id", "unknown")} broadcast to {len(notified)} workers{routing_info}'
        )
        return notified

    def add_task_listener(self, callback: Callable) -> None:
        """Add a callback to be notified when tasks are created."""
        self._task_listeners.append(callback)

    def remove_task_listener(self, callback: Callable) -> None:
        """Remove a task listener callback (no-op if it is not registered)."""
        if callback in self._task_listeners:
            self._task_listeners.remove(callback)

    async def notify_task_created(self, task: Dict[str, Any]) -> None:
        """
        Notify all listeners that a new task was created.

        Listeners may be sync or async. A failing listener is logged and does
        not prevent the remaining listeners from running.
        """
        # Iterate over a snapshot: a listener may add/remove listeners
        # (including itself) while being called, which would otherwise make
        # the loop skip entries.
        for callback in list(self._task_listeners):
            try:
                outcome = callback(task)
                # Await whatever awaitable came back, so partials and callable
                # objects wrapping coroutines are handled like `async def`.
                if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
                    await outcome
            except Exception as e:
                logger.error(f'Error in task listener: {e}', exc_info=True)
345
+
346
+
347
# Process-wide registry instance; stays None until first requested.
_worker_registry: Optional[WorkerRegistry] = None


def get_worker_registry() -> WorkerRegistry:
    """Return the shared WorkerRegistry, building it on the first call."""
    global _worker_registry
    registry = _worker_registry
    if registry is None:
        registry = _worker_registry = WorkerRegistry()
    return registry
357
+
358
+
359
@lru_cache(maxsize=1)
def _get_auth_tokens_set() -> set:
    """
    Parse A2A_AUTH_TOKENS into the set of accepted token values.

    The variable is a comma-separated list whose entries are either
    ``name:token`` (only the part after the first colon is kept) or a bare
    ``token``. Blank entries and entries with an empty token are ignored.
    The result is cached, so the environment is read only on the first call.
    """
    configured = os.environ.get('A2A_AUTH_TOKENS')
    if not configured:
        return set()

    accepted: set = set()
    for entry in (chunk.strip() for chunk in configured.split(',')):
        if not entry:
            continue
        _name, colon, value = entry.partition(':')
        if not colon:
            # Single token without name prefix
            accepted.add(entry)
            continue
        value = value.strip()
        if value:
            accepted.add(value)
    return accepted
379
+
380
+
381
def _verify_auth(request: Request) -> Optional[str]:
    """
    Verify the Bearer token if authentication is configured.

    Args:
        request: Incoming request whose Authorization header is inspected.

    Returns:
        The presented token if it is valid, or None when no tokens are
        configured (authentication disabled, every request is allowed).

    Raises:
        HTTPException: 401 if the header is missing or is not a Bearer
            credential; 403 if the token is empty or not recognised.
    """
    # Function-scope import: this is the only place the module needs hmac.
    import hmac

    tokens = _get_auth_tokens_set()
    if not tokens:
        return None  # Auth not configured, allow all

    auth = (
        request.headers.get('authorization')
        or request.headers.get('Authorization')
        or ''
    )
    # Auth scheme names are case-insensitive (RFC 7235), so accept
    # 'bearer', 'BEARER', etc. The separating space is still required.
    scheme, space, credentials = auth.partition(' ')
    if not space or scheme.lower() != 'bearer':
        raise HTTPException(
            status_code=401,
            detail='Missing Bearer token',
            # A 401 response should tell the client which scheme to use.
            headers={'WWW-Authenticate': 'Bearer'},
        )

    token = credentials.strip()

    # Compare against every configured token in constant time and without
    # short-circuiting, so response timing does not hint at token contents.
    # Encoding to bytes lets compare_digest accept non-ASCII input.
    presented = token.encode('utf-8')
    matched = False
    for known in tokens:
        if hmac.compare_digest(presented, known.encode('utf-8')):
            matched = True

    if not token or not matched:
        raise HTTPException(status_code=403, detail='Invalid token')

    return token
405
+
406
+
407
class TaskClaimRequest(BaseModel):
    """Body of POST /tasks/claim: names the task a worker wants to claim."""

    # ID of the task being claimed.
    task_id: str
411
+
412
+
413
class TaskReleaseRequest(BaseModel):
    """Body of POST /tasks/release: names the task and reports its outcome."""

    # ID of the task whose claim is being released.
    task_id: str
    # Outcome label: completed, failed, cancelled (not validated here).
    status: str = 'completed'
    # Optional result text reported by the worker.
    result: Optional[str] = None
    # Optional error text reported by the worker.
    error: Optional[str] = None
420
+
421
+
422
class CodebaseUpdateRequest(BaseModel):
    """Body of PUT /codebases: the full new codebase list for a worker."""

    # Codebase IDs the worker can handle; replaces the previous set.
    codebases: List[str]
426
+
427
+
428
def _format_sse(event_type: str, payload: Any) -> str:
    """Render one Server-Sent Events frame with a JSON-encoded data line."""
    return f'event: {event_type}\ndata: {json.dumps(payload)}\n\n'


def _split_csv_header(value: Optional[str]) -> List[str]:
    """Split a comma-separated header value, dropping blanks and padding."""
    if not value:
        return []
    return [item.strip() for item in value.split(',') if item.strip()]


@worker_sse_router.get('/tasks/stream')
async def worker_task_stream(
    request: Request,
    agent_name: Optional[str] = Query(None, description='Worker agent name'),
    worker_id: Optional[str] = Query(
        None, description='Worker ID (optional, generated if not provided)'
    ),
    x_agent_name: Optional[str] = Header(None, alias='X-Agent-Name'),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
    x_capabilities: Optional[str] = Header(None, alias='X-Capabilities'),
    x_codebases: Optional[str] = Header(None, alias='X-Codebases'),
):
    """
    SSE endpoint for workers to receive task notifications.

    Workers connect to this endpoint and receive:
    - a `connected` event once registered
    - `task_available` events for tasks already pending at connect time and
      for tasks pushed afterwards
    - `heartbeat` events every 30 seconds
    - any other event placed on the worker's queue, forwarded under the
      event name it carries

    Headers:
    - Authorization: Bearer <token> (required if A2A_AUTH_TOKENS is set)
    - X-Agent-Name: Worker's agent name (alternative to query param)
    - X-Worker-ID: Stable worker ID (optional)
    - X-Capabilities: Comma-separated list of capabilities
    - X-Codebases: Comma-separated list of codebase IDs this worker handles

    Query params:
    - agent_name: Worker's agent name
    - worker_id: Stable worker ID (optional)

    Events sent to worker:
    ```
    event: connected
    data: {"worker_id": "...", "message": "Connected to task stream"}

    event: task_available
    data: {"id": "...", "title": "...", "codebase_id": "...", ...}

    event: heartbeat
    data: {"timestamp": "..."}
    ```

    Raises:
        HTTPException: 400 if no agent name was supplied; 401/403 from
            authentication.
    """
    # Verify authentication
    _verify_auth(request)

    # Resolve agent_name from query param or header
    resolved_agent_name = agent_name or x_agent_name
    if not resolved_agent_name:
        raise HTTPException(
            status_code=400,
            detail='agent_name is required (query param or X-Agent-Name header)',
        )

    # Resolve worker_id (generate a short random one if not provided)
    resolved_worker_id = worker_id or x_worker_id or str(uuid.uuid4())[:12]

    # Parse capabilities and codebases from headers
    capabilities = _split_csv_header(x_capabilities)
    codebases = set(_split_csv_header(x_codebases))

    registry = get_worker_registry()

    async def event_generator():
        """Generate SSE events for the connected worker."""
        queue: asyncio.Queue = asyncio.Queue()
        worker = None

        try:
            # Register this worker
            worker = await registry.register_worker(
                worker_id=resolved_worker_id,
                agent_name=resolved_agent_name,
                queue=queue,
                capabilities=capabilities,
                codebases=codebases,
            )

            # Send connection confirmation
            yield _format_sse(
                'connected',
                {
                    'event': 'connected',
                    'worker_id': resolved_worker_id,
                    'agent_name': resolved_agent_name,
                    'message': 'Connected to task stream',
                    'timestamp': datetime.now(timezone.utc).isoformat(),
                },
            )

            # Replay tasks that were already pending before this worker
            # connected. Best effort: a failure here must not kill the stream.
            # NOTE(review): only the codebase is checked here; agent targeting
            # and required capabilities are not applied to replayed tasks -
            # confirm whether that is intended.
            try:
                from .monitor_api import get_opencode_bridge
                from .opencode_bridge import AgentTaskStatus

                bridge = get_opencode_bridge()
                if bridge:
                    pending_tasks = await bridge.list_tasks(
                        status=AgentTaskStatus.PENDING
                    )
                    sent_count = 0

                    for task in pending_tasks:
                        task_codebase = task.codebase_id
                        # Worker can handle task if:
                        # 1. Task has no specific codebase (global/__pending__)
                        # 2. Worker has the task's codebase in their list
                        # Workers with no codebases can ONLY handle
                        # global/__pending__ tasks (prevents cross-server
                        # task leakage).
                        can_handle = (
                            task_codebase in ('__pending__', 'global')
                            or task_codebase in codebases
                        )
                        if not can_handle:
                            continue

                        task_data = {
                            'id': task.id,
                            'codebase_id': task.codebase_id,
                            'title': task.title,
                            'prompt': task.prompt,
                            'agent_type': task.agent_type,
                            'priority': task.priority,
                            'metadata': task.metadata,
                            'model': task.model,
                            'created_at': task.created_at.isoformat()
                            if task.created_at
                            else None,
                        }
                        # Serialize before yielding so one unserializable
                        # task is skipped instead of aborting the replay.
                        try:
                            frame = _format_sse('task_available', task_data)
                        except (TypeError, ValueError) as e:
                            logger.warning(
                                f'Skipping pending task {task.id}: not JSON serializable ({e})'
                            )
                            continue
                        yield frame
                        sent_count += 1

                    if sent_count > 0:
                        logger.info(
                            f'Sent {sent_count} pending tasks to worker {resolved_worker_id} on connect'
                        )
            except Exception as e:
                logger.warning(f'Failed to send pending tasks on connect: {e}')

            # Main event loop
            heartbeat_interval = 30  # seconds
            loop = asyncio.get_running_loop()
            last_heartbeat = loop.time()

            while True:
                # Check if client disconnected
                if await request.is_disconnected():
                    logger.info(
                        f'Worker {resolved_worker_id} client disconnected'
                    )
                    break

                try:
                    # Wait for a queued event, but never longer than the
                    # time left until the next heartbeat is due.
                    timeout = max(
                        0.1,
                        heartbeat_interval - (loop.time() - last_heartbeat),
                    )

                    try:
                        event = await asyncio.wait_for(
                            queue.get(), timeout=timeout
                        )
                    except asyncio.TimeoutError:
                        pass  # No event, check if heartbeat needed
                    else:
                        # Events without a 'data' key are sent whole.
                        yield _format_sse(
                            event.get('event', 'message'),
                            event.get('data', event),
                        )

                    # Send heartbeat if interval elapsed
                    current_time = loop.time()
                    if current_time - last_heartbeat >= heartbeat_interval:
                        yield _format_sse(
                            'heartbeat',
                            {
                                'timestamp': datetime.now(
                                    timezone.utc
                                ).isoformat(),
                                'worker_id': resolved_worker_id,
                            },
                        )
                        last_heartbeat = current_time
                        await registry.update_heartbeat(resolved_worker_id)

                except asyncio.CancelledError:
                    logger.info(f'Worker {resolved_worker_id} stream cancelled')
                    # Propagate cancellation; the finally block below still
                    # performs the cleanup.
                    raise
                except Exception as e:
                    logger.error(
                        f'Error in worker stream {resolved_worker_id}: {e}'
                    )
                    break

        finally:
            # Unregister on disconnect, but only if the registry entry is
            # still THIS connection. If the worker reconnected under the
            # same ID, the entry now belongs to the newer stream and must
            # not be removed by the old one.
            if worker is not None:
                current = await registry.get_worker(resolved_worker_id)
                if current is worker:
                    await registry.unregister_worker(resolved_worker_id)

    return StreamingResponse(
        event_generator(),
        media_type='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
        },
    )
643
+
644
+
645
@worker_sse_router.post('/tasks/claim')
async def claim_task(
    request: Request,
    claim: TaskClaimRequest,
    worker_id: Optional[str] = Query(None),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
):
    """
    Claim a task for processing.

    Workers call this endpoint after receiving a task_available event
    to atomically claim the task. This prevents multiple workers from
    processing the same task.

    Returns 200 if the claim succeeded.

    Raises:
        HTTPException: 400 if no worker ID was supplied; 404 if the worker is
            not connected via SSE; 409 if the task is already claimed by
            another worker; 401/403 from authentication.
    """
    _verify_auth(request)

    resolved_worker_id = worker_id or x_worker_id
    if not resolved_worker_id:
        raise HTTPException(
            status_code=400,
            detail='worker_id is required (query param or X-Worker-ID header)',
        )

    registry = get_worker_registry()

    success = await registry.claim_task(claim.task_id, resolved_worker_id)
    if success:
        return {
            'success': True,
            'task_id': claim.task_id,
            'worker_id': resolved_worker_id,
            'message': 'Task claimed successfully',
        }

    # The registry also refuses claims from workers it does not know.
    # Report that case as such instead of blaming "another worker".
    if await registry.get_worker(resolved_worker_id) is None:
        raise HTTPException(
            status_code=404,
            detail=f'Worker {resolved_worker_id} not found (not connected via SSE?)',
        )

    raise HTTPException(
        status_code=409,
        detail=f'Task {claim.task_id} already claimed by another worker',
    )
686
+
687
+
688
@worker_sse_router.post('/tasks/release')
async def release_task(
    request: Request,
    release: TaskReleaseRequest,
    worker_id: Optional[str] = Query(None),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
):
    """
    Drop a worker's claim on a task once it has finished with it.

    The worker ID comes from the query string or the X-Worker-ID header.
    The reported status is echoed back in the response; this endpoint only
    releases the claim in the registry.

    Responds 400 without a worker ID and 404 when the task is not claimed
    by that worker.
    """
    _verify_auth(request)

    caller_id = worker_id or x_worker_id
    if not caller_id:
        raise HTTPException(
            status_code=400,
            detail='worker_id is required (query param or X-Worker-ID header)',
        )

    released = await get_worker_registry().release_task(
        release.task_id, caller_id
    )
    if not released:
        raise HTTPException(
            status_code=404,
            detail=f'Task {release.task_id} not claimed by worker {caller_id}',
        )

    return {
        'success': True,
        'task_id': release.task_id,
        'worker_id': caller_id,
        'status': release.status,
        'message': 'Task released successfully',
    }
727
+
728
+
729
@worker_sse_router.put('/codebases')
async def update_worker_codebases(
    request: Request,
    update: CodebaseUpdateRequest,
    worker_id: Optional[str] = Query(None),
    x_worker_id: Optional[str] = Header(None, alias='X-Worker-ID'),
):
    """
    Replace the set of codebases recorded for a connected worker.

    The worker ID comes from the query string or the X-Worker-ID header.
    The submitted list is stored as a set and echoed back unchanged.

    Responds 400 without a worker ID and 404 when the worker is not
    registered.
    """
    _verify_auth(request)

    caller_id = worker_id or x_worker_id
    if not caller_id:
        raise HTTPException(
            status_code=400,
            detail='worker_id is required (query param or X-Worker-ID header)',
        )

    new_codebases = set(update.codebases)
    updated = await get_worker_registry().update_worker_codebases(
        caller_id, new_codebases
    )
    if not updated:
        raise HTTPException(
            status_code=404,
            detail=f'Worker {caller_id} not found (not connected via SSE?)',
        )

    return {
        'success': True,
        'worker_id': caller_id,
        'codebases': update.codebases,
        'message': 'Codebases updated successfully',
    }
769
+
770
+
771
@worker_sse_router.get('/connected')
async def list_connected_workers(request: Request):
    """
    Report every worker currently registered with the SSE registry.

    The response carries the worker summaries, how many there are, and the
    UTC time at which the snapshot was taken.
    """
    _verify_auth(request)

    snapshot = await get_worker_registry().list_workers()
    taken_at = datetime.now(timezone.utc).isoformat()

    return {
        'workers': snapshot,
        'count': len(snapshot),
        'timestamp': taken_at,
    }
789
+
790
+
791
@worker_sse_router.get('/connected/{worker_id}')
async def get_connected_worker(
    request: Request,
    worker_id: str,
):
    """Return the registry's record for one connected worker, or 404."""
    _verify_auth(request)

    record = await get_worker_registry().get_worker(worker_id)
    if record is None:
        raise HTTPException(
            status_code=404,
            detail=f'Worker {worker_id} not found or not connected',
        )

    # Timestamps become ISO strings and the codebase set becomes a list so
    # that the payload is JSON-friendly.
    connected_at = record.connected_at.isoformat()
    last_heartbeat = record.last_heartbeat.isoformat()
    return {
        'worker_id': record.worker_id,
        'agent_name': record.agent_name,
        'connected_at': connected_at,
        'last_heartbeat': last_heartbeat,
        'is_busy': record.is_busy,
        'current_task_id': record.current_task_id,
        'capabilities': record.capabilities,
        'codebases': list(record.codebases),
    }
818
+
819
+
820
+ # ============================================================================
821
+ # Integration helpers for task creation
822
+ # ============================================================================
823
+
824
+
825
async def notify_workers_of_new_task(task: Dict[str, Any]) -> List[str]:
    """
    Notify connected workers of a new task.

    This function should be called when a new task is created to push
    it to available workers via SSE.

    Routing inputs read from ``task``:
    - codebase_id: only workers that can handle this codebase are notified
    - target_agent_name: if set, only workers with that agent name
    - required_capabilities: a list of capability names, or a JSON-encoded
      string of one; only workers having all of them are notified. A value
      that cannot be interpreted as a list of capabilities is ignored
      (no capability filtering), matching the handling of invalid JSON.

    Returns list of worker_ids that received the notification.
    """
    registry = get_worker_registry()
    codebase_id = task.get('codebase_id')
    target_agent_name = task.get('target_agent_name')
    required_capabilities = task.get('required_capabilities')

    # Parse required_capabilities if it's a JSON string
    if isinstance(required_capabilities, str):
        try:
            required_capabilities = json.loads(required_capabilities)
        except (json.JSONDecodeError, TypeError):
            required_capabilities = None

    # Normalise the shape: downstream code iterates this value, so a bare
    # string would be matched character by character and a number would
    # raise TypeError.
    if isinstance(required_capabilities, str):
        # A JSON string such as '"python"' names a single capability.
        required_capabilities = [required_capabilities]
    elif isinstance(required_capabilities, (tuple, set, frozenset)):
        required_capabilities = list(required_capabilities)
    elif required_capabilities is not None and not isinstance(
        required_capabilities, list
    ):
        logger.warning(
            f'Ignoring required_capabilities of unsupported type '
            f'{type(required_capabilities).__name__} for task {task.get("id", "unknown")}'
        )
        required_capabilities = None

    return await registry.broadcast_task(
        task,
        codebase_id=codebase_id,
        target_agent_name=target_agent_name,
        required_capabilities=required_capabilities,
    )
858
+
859
+
860
def setup_task_creation_hook(opencode_bridge) -> None:
    """
    Set up a hook to notify workers when tasks are created.

    This should be called during server initialization to connect
    the task queue to the SSE push system. Calling it more than once is
    safe: the listener is registered exactly once.

    Args:
        opencode_bridge: Currently unused; kept so existing callers keep
            working.
    """
    registry = get_worker_registry()

    # Register the module-level coroutine itself rather than a fresh closure.
    # Its identity is stable, so remove-then-add keeps a single registration
    # even if setup runs repeatedly (which would otherwise notify workers
    # once per call for every new task).
    registry.remove_task_listener(notify_workers_of_new_task)
    registry.add_task_listener(notify_workers_of_new_task)
    logger.info('Task creation hook installed for SSE worker notifications')