codetether-1.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. a2a_server/__init__.py +29 -0
  2. a2a_server/a2a_agent_card.py +365 -0
  3. a2a_server/a2a_errors.py +1133 -0
  4. a2a_server/a2a_executor.py +926 -0
  5. a2a_server/a2a_router.py +1033 -0
  6. a2a_server/a2a_types.py +344 -0
  7. a2a_server/agent_card.py +408 -0
  8. a2a_server/agents_server.py +271 -0
  9. a2a_server/auth_api.py +349 -0
  10. a2a_server/billing_api.py +638 -0
  11. a2a_server/billing_service.py +712 -0
  12. a2a_server/billing_webhooks.py +501 -0
  13. a2a_server/config.py +96 -0
  14. a2a_server/database.py +2165 -0
  15. a2a_server/email_inbound.py +398 -0
  16. a2a_server/email_notifications.py +486 -0
  17. a2a_server/enhanced_agents.py +919 -0
  18. a2a_server/enhanced_server.py +160 -0
  19. a2a_server/hosted_worker.py +1049 -0
  20. a2a_server/integrated_agents_server.py +347 -0
  21. a2a_server/keycloak_auth.py +750 -0
  22. a2a_server/livekit_bridge.py +439 -0
  23. a2a_server/marketing_tools.py +1364 -0
  24. a2a_server/mcp_client.py +196 -0
  25. a2a_server/mcp_http_server.py +2256 -0
  26. a2a_server/mcp_server.py +191 -0
  27. a2a_server/message_broker.py +725 -0
  28. a2a_server/mock_mcp.py +273 -0
  29. a2a_server/models.py +494 -0
  30. a2a_server/monitor_api.py +5904 -0
  31. a2a_server/opencode_bridge.py +1594 -0
  32. a2a_server/redis_task_manager.py +518 -0
  33. a2a_server/server.py +726 -0
  34. a2a_server/task_manager.py +668 -0
  35. a2a_server/task_queue.py +742 -0
  36. a2a_server/tenant_api.py +333 -0
  37. a2a_server/tenant_middleware.py +219 -0
  38. a2a_server/tenant_service.py +760 -0
  39. a2a_server/user_auth.py +721 -0
  40. a2a_server/vault_client.py +576 -0
  41. a2a_server/worker_sse.py +873 -0
  42. agent_worker/__init__.py +8 -0
  43. agent_worker/worker.py +4877 -0
  44. codetether/__init__.py +10 -0
  45. codetether/__main__.py +4 -0
  46. codetether/cli.py +112 -0
  47. codetether/worker_cli.py +57 -0
  48. codetether-1.2.2.dist-info/METADATA +570 -0
  49. codetether-1.2.2.dist-info/RECORD +66 -0
  50. codetether-1.2.2.dist-info/WHEEL +5 -0
  51. codetether-1.2.2.dist-info/entry_points.txt +4 -0
  52. codetether-1.2.2.dist-info/licenses/LICENSE +202 -0
  53. codetether-1.2.2.dist-info/top_level.txt +5 -0
  54. codetether_voice_agent/__init__.py +6 -0
  55. codetether_voice_agent/agent.py +445 -0
  56. codetether_voice_agent/codetether_mcp.py +345 -0
  57. codetether_voice_agent/config.py +16 -0
  58. codetether_voice_agent/functiongemma_caller.py +380 -0
  59. codetether_voice_agent/session_playback.py +247 -0
  60. codetether_voice_agent/tools/__init__.py +21 -0
  61. codetether_voice_agent/tools/definitions.py +135 -0
  62. codetether_voice_agent/tools/handlers.py +380 -0
  63. run_server.py +314 -0
  64. ui/monitor-tailwind.html +1790 -0
  65. ui/monitor.html +1775 -0
  66. ui/monitor.js +2662 -0
a2a_server/hosted_worker.py (new file)
@@ -0,0 +1,1049 @@
+ """
+ Hosted Worker - Managed task execution for mid-market users.
+
+ This module provides:
+ - Worker pool that claims and executes tasks from the queue
+ - Lease-based job locking with heartbeat renewal
+ - Automatic expired lease reclamation
+ - Graceful shutdown support
+
+ Usage:
+     python -m a2a_server.hosted_worker --workers 2 --db-url postgresql://...
+
+ The worker process runs N concurrent workers that poll the task_runs queue.
+ """
+
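The queue mechanics described above rely on SQL helper functions (claim_next_task_run, renew_task_run_lease, complete_task_run, reclaim_expired_task_runs) that live in the database schema rather than in this file. A minimal sketch of one claim/renew/complete cycle against a database that already has those helpers installed; the connection string and worker id are placeholders:

import asyncio
import asyncpg

async def one_cycle(db_url: str, worker_id: str = 'sketch-worker-0') -> None:
    # Assumes the claim_next_task_run / renew_task_run_lease / complete_task_run
    # helpers from the CodeTether schema are already installed in the database.
    conn = await asyncpg.connect(db_url)
    try:
        row = await conn.fetchrow(
            'SELECT * FROM claim_next_task_run($1, $2, $3, $4::jsonb)',
            worker_id, 600, None, '[]',
        )
        if row and row['run_id']:
            # Keep the lease alive while working, then record the outcome.
            await conn.fetchval(
                'SELECT renew_task_run_lease($1, $2, $3)',
                row['run_id'], worker_id, 600,
            )
            await conn.fetchval(
                'SELECT complete_task_run($1, $2, $3, $4, $5, $6)',
                row['run_id'], worker_id, 'completed', 'done', None, None,
            )
    finally:
        await conn.close()

# asyncio.run(one_cycle('postgresql://user:pass@localhost:5432/a2a_server'))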
+ import argparse
+ import asyncio
+ import json
+ import logging
+ import os
+ import signal
+ import socket
+ import sys
+ import uuid
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List, Optional
+
+ import asyncpg
+ import httpx
+
+ from .task_queue import TaskRunStatus
+
+ logger = logging.getLogger(__name__)
+
+ # Configuration defaults
+ DEFAULT_WORKERS = 2
+ DEFAULT_POLL_INTERVAL = 2.0  # seconds
+ DEFAULT_LEASE_DURATION = 600  # 10 minutes
+ DEFAULT_HEARTBEAT_INTERVAL = 60  # 1 minute
+
+
+ class HostedWorker:
+     """
+     A single worker that claims and executes tasks from the queue.
+
+     Multiple HostedWorker instances run concurrently in the same process.
+     """
+
+     def __init__(
+         self,
+         worker_id: str,
+         db_pool: asyncpg.Pool,
+         api_base_url: str,
+         poll_interval: float = DEFAULT_POLL_INTERVAL,
+         lease_duration: int = DEFAULT_LEASE_DURATION,
+         heartbeat_interval: int = DEFAULT_HEARTBEAT_INTERVAL,
+         agent_name: Optional[str] = None,
+         capabilities: Optional[List[str]] = None,
+     ):
+         self.worker_id = worker_id
+         self._pool = db_pool
+         self._api_base_url = api_base_url.rstrip('/')
+         self._poll_interval = poll_interval
+         self._lease_duration = lease_duration
+         self._heartbeat_interval = heartbeat_interval
+         # Agent identity for targeted routing
+         self.agent_name = agent_name
+         self.capabilities = capabilities or []
+
+         self._running = False
+         self._current_run_id: Optional[str] = None
+         self._current_task_id: Optional[str] = None
+         self._heartbeat_task: Optional[asyncio.Task] = None
+         self._http_client: Optional[httpx.AsyncClient] = None
+
+         # Stats
+         self.tasks_completed = 0
+         self.tasks_failed = 0
+         self.total_runtime_seconds = 0
+
+     async def start(self) -> None:
+         """Start the worker loop."""
+         self._running = True
+         self._http_client = httpx.AsyncClient(
+             timeout=300.0
+         )  # 5 min timeout for task execution
+
+         logger.info(f'Worker {self.worker_id} starting')
+
+         try:
+             while self._running:
+                 try:
+                     # Try to claim next job
+                     claimed = await self._claim_next_job()
+
+                     if claimed:
+                         # Execute the task
+                         await self._execute_current_task()
+                     else:
+                         # No work available, wait before polling again
+                         await asyncio.sleep(self._poll_interval)
+
+                 except asyncio.CancelledError:
+                     logger.info(f'Worker {self.worker_id} cancelled')
+                     break
+                 except Exception as e:
+                     logger.error(
+                         f'Worker {self.worker_id} error: {e}', exc_info=True
+                     )
+                     await asyncio.sleep(self._poll_interval)
+         finally:
+             await self._cleanup()
+
+     async def stop(self) -> None:
+         """Gracefully stop the worker."""
+         logger.info(f'Worker {self.worker_id} stopping')
+         self._running = False
+
+         # Stop heartbeat if running
+         if self._heartbeat_task:
+             self._heartbeat_task.cancel()
+             try:
+                 await self._heartbeat_task
+             except asyncio.CancelledError:
+                 pass
+
+     async def _claim_next_job(self) -> bool:
+         """
+         Attempt to claim the next available job from the queue.
+
+         Uses the claim_next_task_run() SQL function for atomic claiming
+         with concurrency limit enforcement and agent-targeted routing.
+
+         The agent_name and capabilities are passed to filter tasks:
+         - Tasks with target_agent_name set will only be claimed by matching workers
+         - Tasks with required_capabilities will only be claimed by workers with ALL required caps
+
+         Returns True if a job was claimed.
+         """
+         import json
+
+         # Convert capabilities list to JSONB format for SQL
+         capabilities_json = (
+             json.dumps(self.capabilities) if self.capabilities else '[]'
+         )
+
+         async with self._pool.acquire() as conn:
+             row = await conn.fetchrow(
+                 'SELECT * FROM claim_next_task_run($1, $2, $3, $4::jsonb)',
+                 self.worker_id,
+                 self._lease_duration,
+                 self.agent_name,  # Pass agent_name for targeted routing
+                 capabilities_json,  # Pass capabilities for capability-based routing
+             )
+
+             if row and row['run_id']:
+                 self._current_run_id = row['run_id']
+                 self._current_task_id = row['task_id']
+                 target_info = (
+                     f', targeted_at={row.get("target_agent_name")}'
+                     if row.get('target_agent_name')
+                     else ''
+                 )
+                 logger.info(
+                     f'Worker {self.worker_id} (agent={self.agent_name}) claimed run {self._current_run_id} '
+                     f'(task={self._current_task_id}, priority={row["priority"]}{target_info})'
+                 )
+                 return True
+
+         return False
+
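claim_next_task_run() itself is defined in the database schema and is not part of this diff. As a rough illustration of the kind of atomic claim it performs, a hand-rolled asyncpg version could use FOR UPDATE SKIP LOCKED so that concurrent workers never grab the same run; the table and column names below are assumptions, not the package's actual schema:

import asyncpg

async def claim_one(pool: asyncpg.Pool, worker_id: str, lease_seconds: int):
    # Illustrative only; the real claiming logic lives in claim_next_task_run().
    async with pool.acquire() as conn:
        async with conn.transaction():
            return await conn.fetchrow(
                """
                UPDATE task_runs SET
                    status = 'running',
                    worker_id = $1,
                    lease_expires_at = NOW() + make_interval(secs => $2::int)
                WHERE id = (
                    SELECT id FROM task_runs
                    WHERE status = 'queued'
                    ORDER BY priority DESC, created_at
                    FOR UPDATE SKIP LOCKED
                    LIMIT 1
                )
                RETURNING id AS run_id, task_id, priority
                """,
                worker_id,
                lease_seconds,
            )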
+     async def _execute_current_task(self) -> None:
+         """Execute the currently claimed task."""
+         if not self._current_run_id or not self._current_task_id:
+             return
+
+         run_id = self._current_run_id
+         task_id = self._current_task_id
+         started_at = datetime.now(timezone.utc)
+
+         # Start heartbeat to keep lease alive
+         self._heartbeat_task = asyncio.create_task(self._heartbeat_loop(run_id))
+
+         try:
+             logger.info(f'Worker {self.worker_id} executing task {task_id}')
+
+             # Get task details from API
+             task_data = await self._get_task_details(task_id)
+             if not task_data:
+                 raise Exception(f'Task {task_id} not found')
+
+             # Execute the task via API (this triggers the actual agent work)
+             result = await self._run_task(task_id, task_data)
+
+             # Mark completed
+             runtime = int(
+                 (datetime.now(timezone.utc) - started_at).total_seconds()
+             )
+             await self._complete_run(
+                 run_id,
+                 status='completed',
+                 result_summary=result.get('summary', 'Task completed'),
+                 result_full=result,
+             )
+
+             self.tasks_completed += 1
+             self.total_runtime_seconds += runtime
+             logger.info(
+                 f'Worker {self.worker_id} completed task {task_id} in {runtime}s'
+             )
+
+         except Exception as e:
+             logger.error(
+                 f'Worker {self.worker_id} failed task {task_id}: {e}',
+                 exc_info=True,
+             )
+
+             await self._complete_run(
+                 run_id,
+                 status='failed',
+                 error=str(e),
+             )
+
+             self.tasks_failed += 1
+
+         finally:
+             # Stop heartbeat
+             if self._heartbeat_task:
+                 self._heartbeat_task.cancel()
+                 try:
+                     await self._heartbeat_task
+                 except asyncio.CancelledError:
+                     pass
+                 self._heartbeat_task = None
+
+             self._current_run_id = None
+             self._current_task_id = None
+
+     async def _get_task_details(self, task_id: str) -> Optional[Dict[str, Any]]:
+         """Get task details from the API."""
+         if not self._http_client:
+             return None
+         try:
+             response = await self._http_client.post(
+                 f'{self._api_base_url}/mcp/v1/rpc',
+                 json={
+                     'jsonrpc': '2.0',
+                     'method': 'tools/call',
+                     'params': {
+                         'name': 'get_task',
+                         'arguments': {'task_id': task_id},
+                     },
+                     'id': str(uuid.uuid4()),
+                 },
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             if 'error' in data:
+                 logger.error(
+                     f'API error getting task {task_id}: {data["error"]}'
+                 )
+                 return None
+
+             result = data.get('result', {})
+             # Handle MCP response format
+             if isinstance(result, dict) and 'content' in result:
+                 for content in result['content']:
+                     if content.get('type') == 'text':
+                         return json.loads(content['text'])
+             return result
+
+         except Exception as e:
+             logger.error(f'Failed to get task {task_id}: {e}')
+             return None
+
+     async def _run_task(
+         self, task_id: str, task_data: Dict[str, Any]
+     ) -> Dict[str, Any]:
+         """
+         Execute the task by invoking the appropriate agent.
+
+         For now, this calls the existing worker SSE task execution path.
+         In the future, this could run agents directly.
+         """
+         # Get the task prompt
+         prompt = task_data.get('description') or task_data.get('prompt', '')
+         codebase_id = task_data.get('codebase_id', 'global')
+         agent_type = task_data.get('agent_type', 'build')
+         model = task_data.get('model')
+
+         logger.info(
+             f'Running task {task_id}: agent={agent_type}, codebase={codebase_id}'
+         )
+
+         # For global/pending tasks, we can execute directly via API
+         # This triggers the existing execution path which may use SSE workers
+         # or we can implement direct execution here
+
+         # Option 1: Use the continue_task endpoint to trigger execution
+         # This works with the existing worker infrastructure
+         if not self._http_client:
+             raise Exception('HTTP client not initialized')
+         try:
+             response = await self._http_client.post(
+                 f'{self._api_base_url}/mcp/v1/rpc',
+                 json={
+                     'jsonrpc': '2.0',
+                     'method': 'tools/call',
+                     'params': {
+                         'name': 'continue_task',
+                         'arguments': {
+                             'task_id': task_id,
+                             'input': prompt,  # Pass the prompt as continuation
+                         },
+                     },
+                     'id': str(uuid.uuid4()),
+                 },
+                 timeout=self._lease_duration,  # Don't timeout before lease
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             if 'error' in data:
+                 raise Exception(f'Task execution error: {data["error"]}')
+
+             result = data.get('result', {})
+
+             # Extract result from MCP format
+             if isinstance(result, dict) and 'content' in result:
+                 for content in result['content']:
+                     if content.get('type') == 'text':
+                         try:
+                             return json.loads(content['text'])
+                         except json.JSONDecodeError:
+                             return {
+                                 'summary': content['text'],
+                                 'raw': content['text'],
+                             }
+
+             return {'summary': 'Task completed', 'result': result}
+
+         except httpx.TimeoutException:
+             # Task took too long - it may still be running
+             # Check task status
+             task_details = await self._get_task_details(task_id)
+             if task_details and task_details.get('status') == 'completed':
+                 return {
+                     'summary': 'Task completed (timeout during response)',
+                     'result': task_details,
+                 }
+             raise Exception('Task execution timed out')
+
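The two RPC helpers above unwrap results that follow the MCP tools/call content convention. A hypothetical response of the shape the parsing code expects (all values are invented for illustration):

# Hypothetical JSON-RPC response body; only the shape matters here.
example_response = {
    'jsonrpc': '2.0',
    'id': 'b9a1...',  # echoes the request id
    'result': {
        'content': [
            {
                'type': 'text',
                'text': '{"summary": "Build finished", "status": "completed"}',
            }
        ]
    },
}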
+     async def _heartbeat_loop(self, run_id: str) -> None:
+         """Periodically renew the lease on the current job."""
+         try:
+             while True:
+                 await asyncio.sleep(self._heartbeat_interval)
+
+                 async with self._pool.acquire() as conn:
+                     renewed = await conn.fetchval(
+                         'SELECT renew_task_run_lease($1, $2, $3)',
+                         run_id,
+                         self.worker_id,
+                         self._lease_duration,
+                     )
+
+                 if not renewed:
+                     logger.warning(
+                         f'Failed to renew lease for run {run_id}'
+                     )
+                     break
+
+                 logger.debug(f'Renewed lease for run {run_id}')
+
+         except asyncio.CancelledError:
+             pass
+
+     async def _complete_run(
+         self,
+         run_id: str,
+         status: str,
+         result_summary: Optional[str] = None,
+         result_full: Optional[Dict[str, Any]] = None,
+         error: Optional[str] = None,
+     ) -> None:
+         """Mark a task run as completed or failed."""
+         async with self._pool.acquire() as conn:
+             await conn.fetchval(
+                 'SELECT complete_task_run($1, $2, $3, $4, $5, $6)',
+                 run_id,
+                 self.worker_id,
+                 status,
+                 result_summary,
+                 json.dumps(result_full) if result_full else None,
+                 error,
+             )
+
+         logger.info(f'Run {run_id} marked as {status}')
+
+         # TODO: Send notification (email/webhook) if configured
+         await self._send_completion_notification(run_id, status)
+
+     async def _send_completion_notification(
+         self, run_id: str, status: str
+     ) -> None:
+         """
+         Send completion notification (email/webhook) with retry-safe 3-state flow.
+
+         Flow:
+         1. Atomically claim notification for send (increments attempts)
+         2. Try to send
+         3. On success: mark_notification_sent()
+         4. On failure: mark_notification_failed() with backoff for retry
+
+         This prevents both duplicate sends AND permanent silence.
+         """
+         # Get notification settings and task details for this run
+         async with self._pool.acquire() as conn:
+             row = await conn.fetchrow(
+                 """
+                 SELECT tr.notify_email, tr.notify_webhook_url,
+                        tr.notification_status, tr.webhook_status,
+                        tr.result_summary, tr.result_full, tr.task_id,
+                        tr.runtime_seconds, tr.last_error,
+                        t.title, t.prompt
+                 FROM task_runs tr
+                 LEFT JOIN tasks t ON tr.task_id = t.id
+                 WHERE tr.id = $1
+                 """,
+                 run_id,
+             )
+
+         if not row:
+             return
+
+         # Send email notification with 3-state tracking
+         if row['notify_email'] and row['notification_status'] != 'sent':
+             await self._send_email_notification(run_id, row, status)
+
+         # Send webhook notification with 3-state tracking
+         if row['notify_webhook_url'] and row['webhook_status'] != 'sent':
+             await self._send_webhook_notification(run_id, row, status)
+
+     async def _send_email_notification(
+         self, run_id: str, row: dict, status: str
+     ) -> None:
+         """Send email notification with atomic claim and retry support."""
+         # Atomically claim the notification (prevents double-send)
+         async with self._pool.acquire() as conn:
+             claimed = await conn.fetchval(
+                 'SELECT claim_notification_for_send($1, $2)',
+                 run_id,
+                 3,  # max_attempts
+             )
+
+         if not claimed:
+             logger.debug(
+                 f'Notification already claimed or sent for run {run_id}'
+             )
+             return
+
+         # Now try to send the email
+         try:
+             from .email_notifications import send_task_completion_email
+
+             # Extract result from JSON if stored
+             result_text = row['result_summary']
+             if row['result_full']:
+                 try:
+                     result_full = (
+                         json.loads(row['result_full'])
+                         if isinstance(row['result_full'], str)
+                         else row['result_full']
+                     )
+                     result_text = (
+                         result_full.get('summary')
+                         or result_full.get('result')
+                         or result_text
+                     )
+                 except (json.JSONDecodeError, TypeError):
+                     pass
+
+             email_sent = await send_task_completion_email(
+                 to_email=row['notify_email'],
+                 task_id=row['task_id'],
+                 title=row['title'] or 'Task',
+                 status=status,
+                 result=result_text,
+                 error=row['last_error'] if status == 'failed' else None,
+                 runtime_seconds=row['runtime_seconds'],
+                 worker_name=self.worker_id,
+             )
+
+             if email_sent:
+                 # Mark as sent - success!
+                 async with self._pool.acquire() as conn:
+                     await conn.execute(
+                         'SELECT mark_notification_sent($1)', run_id
+                     )
+                 logger.info(
+                     f'Completion email sent to {row["notify_email"]} for run {run_id}'
+                 )
+             else:
+                 # Mark as failed with retry backoff
+                 async with self._pool.acquire() as conn:
+                     await conn.execute(
+                         'SELECT mark_notification_failed($1, $2, $3)',
+                         run_id,
+                         'SendGrid returned failure (check API key/config)',
+                         3,
+                     )
+                 logger.warning(
+                     f'Failed to send completion email for run {run_id}, will retry'
+                 )
+
+         except ImportError as e:
+             async with self._pool.acquire() as conn:
+                 await conn.execute(
+                     'SELECT mark_notification_failed($1, $2, $3)',
+                     run_id,
+                     f'email_notifications module not available: {e}',
+                     3,
+                 )
+             logger.warning('email_notifications module not available')
+         except Exception as e:
+             # Mark as failed with retry backoff
+             async with self._pool.acquire() as conn:
+                 await conn.execute(
+                     'SELECT mark_notification_failed($1, $2, $3)',
+                     run_id,
+                     str(e)[:500],  # Truncate error message
+                     3,
+                 )
+             logger.error(f'Error sending completion email: {e}')
+
+     async def _send_webhook_notification(
+         self, run_id: str, row: dict, status: str
+     ) -> None:
+         """Send webhook notification with atomic claim and retry support."""
+         # Atomically claim the webhook notification
+         async with self._pool.acquire() as conn:
+             claimed = await conn.fetchval(
+                 'SELECT claim_webhook_for_send($1, $2)',
+                 run_id,
+                 3,  # max_attempts
+             )
+
+         if not claimed:
+             logger.debug(
+                 f'Webhook already claimed or sent for run {run_id}'
+             )
+             return
+
+         # Now try to call the webhook
+         try:
+             await self._call_webhook(
+                 url=row['notify_webhook_url'],
+                 run_id=run_id,
+                 task_id=row['task_id'],
+                 status=status,
+                 result=row['result_summary'],
+                 error=row['last_error'],
+             )
+
+             # Mark as sent - success!
+             async with self._pool.acquire() as conn:
+                 await conn.execute('SELECT mark_webhook_sent($1)', run_id)
+             logger.info(f'Webhook called successfully for run {run_id}')
+
+         except Exception as e:
+             # Mark as failed with retry backoff
+             async with self._pool.acquire() as conn:
+                 await conn.execute(
+                     'SELECT mark_webhook_failed($1, $2, $3)',
+                     run_id,
+                     str(e)[:500],
+                     3,
+                 )
+             logger.error(f'Error calling webhook: {e}')
+
+     async def _call_webhook(
+         self,
+         url: str,
+         run_id: str,
+         task_id: str,
+         status: str,
+         result: Optional[str] = None,
+         error: Optional[str] = None,
+     ) -> None:
+         """Call webhook URL with task completion data."""
+         if not self._http_client:
+             return
+
+         payload = {
+             'event': 'task_completed',
+             'run_id': run_id,
+             'task_id': task_id,
+             'status': status,
+             'result': result,
+             'error': error,
+             'timestamp': datetime.now(timezone.utc).isoformat(),
+         }
+
+         try:
+             response = await self._http_client.post(
+                 url,
+                 json=payload,
+                 timeout=10.0,
+             )
+             if response.status_code < 300:
+                 logger.info(f'Webhook called successfully for run {run_id}')
+             else:
+                 logger.warning(
+                     f'Webhook returned {response.status_code} for run {run_id}'
+                 )
+         except Exception as e:
+             logger.error(f'Webhook call failed for run {run_id}: {e}')
+
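Any HTTP endpoint that accepts a JSON POST can act as the notify_webhook_url target for the payload built in _call_webhook(). A minimal receiver sketch; aiohttp is used purely for illustration and is not a dependency of this package:

from aiohttp import web

async def handle_codetether_hook(request: web.Request) -> web.Response:
    event = await request.json()
    # Fields mirror the payload dict built in _call_webhook():
    # event, run_id, task_id, status, result, error, timestamp
    print(f"run {event['run_id']} finished with status {event['status']}")
    return web.Response(status=204)  # any 2xx counts as success to the sender

app = web.Application()
app.add_routes([web.post('/hooks/codetether', handle_codetether_hook)])
# web.run_app(app, port=8080)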
+     async def _cleanup(self) -> None:
+         """Cleanup resources."""
+         if self._http_client:
+             await self._http_client.aclose()
+             self._http_client = None
+
+
+ class HostedWorkerPool:
+     """
+     Manages a pool of hosted workers.
+
+     Handles:
+     - Starting N workers
+     - Periodic expired lease reclamation
+     - Graceful shutdown
+     - Health reporting
+     """
+
+     def __init__(
+         self,
+         db_url: str,
+         api_base_url: str,
+         num_workers: int = DEFAULT_WORKERS,
+         poll_interval: float = DEFAULT_POLL_INTERVAL,
+         lease_duration: int = DEFAULT_LEASE_DURATION,
+         agent_name: Optional[str] = None,
+         capabilities: Optional[List[str]] = None,
+     ):
+         self._db_url = db_url
+         self._api_base_url = api_base_url
+         self._num_workers = num_workers
+         self._poll_interval = poll_interval
+         self._lease_duration = lease_duration
+         # Agent identity for targeted routing (shared by all workers in pool)
+         self._agent_name = agent_name
+         self._capabilities = capabilities or []
+
+         self._pool: Optional[asyncpg.Pool] = None
+         self._workers: List[HostedWorker] = []
+         self._worker_tasks: List[asyncio.Task] = []
+         self._reclaim_task: Optional[asyncio.Task] = None
+         self._running = False
+
+         # Pool identification
+         self._pool_id = (
+             f'{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}'
+         )
+
+     async def start(self) -> None:
+         """Start the worker pool."""
+         logger.info(
+             f'Starting hosted worker pool {self._pool_id} with {self._num_workers} workers'
+         )
+
+         # Connect to database
+         self._pool = await asyncpg.create_pool(
+             self._db_url,
+             min_size=self._num_workers + 1,  # +1 for reclaim task
+             max_size=self._num_workers * 2,
+         )
+
+         # Register this pool
+         await self._register_pool()
+
+         self._running = True
+
+         # Start workers
+         for i in range(self._num_workers):
+             worker_id = f'{self._pool_id}-worker-{i}'
+             worker = HostedWorker(
+                 worker_id=worker_id,
+                 db_pool=self._pool,
+                 api_base_url=self._api_base_url,
+                 poll_interval=self._poll_interval,
+                 lease_duration=self._lease_duration,
+                 agent_name=self._agent_name,
+                 capabilities=self._capabilities,
+             )
+             self._workers.append(worker)
+             task = asyncio.create_task(worker.start())
+             self._worker_tasks.append(task)
+
+         # Start expired lease reclamation
+         self._reclaim_task = asyncio.create_task(self._reclaim_loop())
+
+         logger.info(f'Worker pool {self._pool_id} started')
+
+     async def stop(self) -> None:
+         """Gracefully stop the worker pool."""
+         logger.info(f'Stopping worker pool {self._pool_id}')
+         self._running = False
+
+         # Stop all workers
+         for worker in self._workers:
+             await worker.stop()
+
+         # Cancel worker tasks
+         for task in self._worker_tasks:
+             task.cancel()
+
+         # Wait for workers to finish
+         if self._worker_tasks:
+             await asyncio.gather(*self._worker_tasks, return_exceptions=True)
+
+         # Stop reclaim task
+         if self._reclaim_task:
+             self._reclaim_task.cancel()
+             try:
+                 await self._reclaim_task
+             except asyncio.CancelledError:
+                 pass
+
+         # Unregister pool
+         await self._unregister_pool()
+
+         # Close database pool
+         if self._pool:
+             await self._pool.close()
+
+         logger.info(f'Worker pool {self._pool_id} stopped')
+
+     async def _register_pool(self) -> None:
+         """Register this worker pool in the database."""
+         if not self._pool:
+             return
+         async with self._pool.acquire() as conn:
+             await conn.execute(
+                 """
+                 INSERT INTO hosted_workers (id, hostname, process_id, max_concurrent_tasks, started_at)
+                 VALUES ($1, $2, $3, $4, NOW())
+                 ON CONFLICT (id) DO UPDATE SET
+                     status = 'active',
+                     last_heartbeat = NOW(),
+                     stopped_at = NULL
+                 """,
+                 self._pool_id,
+                 socket.gethostname(),
+                 os.getpid(),
+                 self._num_workers,
+             )
+
+     async def _unregister_pool(self) -> None:
+         """Mark this worker pool as stopped."""
+         if not self._pool:
+             return
+
+         try:
+             async with self._pool.acquire() as conn:
+                 # Get stats from workers
+                 total_completed = sum(w.tasks_completed for w in self._workers)
+                 total_failed = sum(w.tasks_failed for w in self._workers)
+                 total_runtime = sum(
+                     w.total_runtime_seconds for w in self._workers
+                 )
+
+                 await conn.execute(
+                     """
+                     UPDATE hosted_workers SET
+                         status = 'stopped',
+                         stopped_at = NOW(),
+                         tasks_completed = $2,
+                         tasks_failed = $3,
+                         total_runtime_seconds = $4
+                     WHERE id = $1
+                     """,
+                     self._pool_id,
+                     total_completed,
+                     total_failed,
+                     total_runtime,
+                 )
+         except Exception as e:
+             logger.error(f'Failed to unregister pool: {e}')
+
+     async def _reclaim_loop(self) -> None:
+         """Periodically reclaim expired leases and retry failed notifications."""
+         if not self._pool:
+             return
+         try:
+             while self._running:
+                 await asyncio.sleep(60)  # Check every 60 seconds
+
+                 async with self._pool.acquire() as conn:
+                     reclaimed = await conn.fetchval(
+                         'SELECT reclaim_expired_task_runs()'
+                     )
+                     if reclaimed and reclaimed > 0:
+                         logger.info(f'Reclaimed {reclaimed} expired task runs')
+
+                     # Fail tasks that exceeded their routing deadline
+                     deadline_failed = await conn.fetchval(
+                         'SELECT fail_deadline_exceeded_tasks()'
+                     )
+                     if deadline_failed and deadline_failed > 0:
+                         logger.info(
+                             f'Failed {deadline_failed} tasks that exceeded deadline'
+                         )
+
+                     # Update pool heartbeat
+                     current_tasks = sum(
+                         1 for w in self._workers if w._current_run_id
+                     )
+                     await conn.execute(
+                         """
+                         UPDATE hosted_workers SET
+                             last_heartbeat = NOW(),
+                             current_tasks = $2
+                         WHERE id = $1
+                         """,
+                         self._pool_id,
+                         current_tasks,
+                     )
+
+                 # Retry failed notifications
+                 await self._retry_failed_notifications()
+
+         except asyncio.CancelledError:
+             pass
+         except Exception as e:
+             logger.error(f'Reclaim loop error: {e}', exc_info=True)
+
+     async def _retry_failed_notifications(self) -> None:
+         """
+         Process failed notifications that are ready for retry.
+
+         This runs periodically in the pool's reclaim loop to ensure
+         no notifications are permanently lost due to transient failures.
+         """
+         if not self._pool:
+             return
+
+         try:
+             async with self._pool.acquire() as conn:
+                 # Get notifications ready for retry
+                 rows = await conn.fetch(
+                     'SELECT * FROM get_pending_notification_retries($1)',
+                     10,  # Process up to 10 at a time
+                 )
+
+             if not rows:
+                 return
+
+             logger.info(f'Processing {len(rows)} notification retries')
+
+             # Use one of our workers to send the notifications
+             # (reuse their HTTP client and notification logic)
+             if self._workers:
+                 worker = self._workers[0]
+
+                 for row in rows:
+                     run_id = row['run_id']
+
+                     # Get full task details for notification
+                     async with self._pool.acquire() as conn:
+                         full_row = await conn.fetchrow(
+                             """
+                             SELECT tr.notify_email, tr.notify_webhook_url,
+                                    tr.notification_status, tr.webhook_status,
+                                    tr.result_summary, tr.result_full, tr.task_id,
+                                    tr.runtime_seconds, tr.last_error, tr.status,
+                                    t.title, t.prompt
+                             FROM task_runs tr
+                             LEFT JOIN tasks t ON tr.task_id = t.id
+                             WHERE tr.id = $1
+                             """,
+                             run_id,
+                         )
+
+                     if not full_row:
+                         continue
+
+                     task_status = full_row['status'] or 'completed'
+
+                     # Retry email if needed
+                     if (
+                         row['notification_status'] == 'failed'
+                         and full_row['notify_email']
+                     ):
+                         logger.info(
+                             f'Retrying email notification for run {run_id} '
+                             f'(attempt {row["notification_attempts"] + 1})'
+                         )
+                         await worker._send_email_notification(
+                             run_id, dict(full_row), task_status
+                         )
+
+                     # Retry webhook if needed
+                     if (
+                         row['webhook_status'] == 'failed'
+                         and full_row['notify_webhook_url']
+                     ):
+                         logger.info(
+                             f'Retrying webhook notification for run {run_id} '
+                             f'(attempt {row["webhook_attempts"] + 1})'
+                         )
+                         await worker._send_webhook_notification(
+                             run_id, dict(full_row), task_status
+                         )
+
+         except Exception as e:
+             logger.error(f'Error retrying notifications: {e}', exc_info=True)
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get pool statistics."""
+         return {
+             'pool_id': self._pool_id,
+             'num_workers': self._num_workers,
+             'running': self._running,
+             'workers': [
+                 {
+                     'id': w.worker_id,
+                     'current_run': w._current_run_id,
+                     'tasks_completed': w.tasks_completed,
+                     'tasks_failed': w.tasks_failed,
+                     'total_runtime': w.total_runtime_seconds,
+                 }
+                 for w in self._workers
+             ],
+             'totals': {
+                 'completed': sum(w.tasks_completed for w in self._workers),
+                 'failed': sum(w.tasks_failed for w in self._workers),
+                 'runtime': sum(w.total_runtime_seconds for w in self._workers),
+             },
+         }
+
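Besides the CLI entry point defined below, HostedWorkerPool can be embedded in an existing asyncio application. A minimal sketch, with placeholder connection details:

import asyncio
from a2a_server.hosted_worker import HostedWorkerPool

async def run_pool() -> None:
    pool = HostedWorkerPool(
        db_url='postgresql://user:pass@localhost:5432/a2a_server',  # placeholder
        api_base_url='http://localhost:9001',
        num_workers=2,
        agent_name='docs-agent',          # optional: targeted routing
        capabilities=['python', 'docs'],  # optional: capability routing
    )
    await pool.start()
    try:
        await asyncio.sleep(3600)  # or until the host application shuts down
    finally:
        await pool.stop()
        print(pool.get_stats()['totals'])

# asyncio.run(run_pool())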
+
+ async def main():
+     """Main entry point for hosted worker process."""
+     parser = argparse.ArgumentParser(
+         description='CodeTether Hosted Worker Pool'
+     )
+     parser.add_argument(
+         '--workers',
+         '-w',
+         type=int,
+         default=DEFAULT_WORKERS,
+         help=f'Number of concurrent workers (default: {DEFAULT_WORKERS})',
+     )
+     parser.add_argument(
+         '--db-url',
+         default=os.environ.get(
+             'DATABASE_URL',
+             'postgresql://postgres:spike2@192.168.50.70:5432/a2a_server',
+         ),
+         help='PostgreSQL connection URL',
+     )
+     parser.add_argument(
+         '--api-url',
+         default=os.environ.get('API_BASE_URL', 'http://localhost:9001'),
+         help='CodeTether API base URL',
+     )
+     parser.add_argument(
+         '--poll-interval',
+         type=float,
+         default=DEFAULT_POLL_INTERVAL,
+         help=f'Poll interval in seconds (default: {DEFAULT_POLL_INTERVAL})',
+     )
+     parser.add_argument(
+         '--lease-duration',
+         type=int,
+         default=DEFAULT_LEASE_DURATION,
+         help=f'Lease duration in seconds (default: {DEFAULT_LEASE_DURATION})',
+     )
+     parser.add_argument(
+         '--log-level',
+         default='INFO',
+         choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
+         help='Log level (default: INFO)',
+     )
+     parser.add_argument(
+         '--agent-name',
+         default=os.environ.get('AGENT_NAME'),
+         help='Agent name for targeted task routing (env: AGENT_NAME)',
+     )
+     parser.add_argument(
+         '--capabilities',
+         default=os.environ.get('AGENT_CAPABILITIES', ''),
+         help='Comma-separated list of capabilities (env: AGENT_CAPABILITIES)',
+     )
+
+     args = parser.parse_args()
+
+     # Parse capabilities
+     capabilities = []
+     if args.capabilities:
+         capabilities = [
+             c.strip() for c in args.capabilities.split(',') if c.strip()
+         ]
+
+     # Configure logging
+     logging.basicConfig(
+         level=getattr(logging, args.log_level),
+         format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+     )
+
+     # Create worker pool
+     pool = HostedWorkerPool(
+         db_url=args.db_url,
+         api_base_url=args.api_url,
+         num_workers=args.workers,
+         poll_interval=args.poll_interval,
+         lease_duration=args.lease_duration,
+         agent_name=args.agent_name,
+         capabilities=capabilities,
+     )
+
+     # Handle shutdown signals
+     loop = asyncio.get_event_loop()
+
+     def signal_handler():
+         logger.info('Received shutdown signal')
+         asyncio.create_task(pool.stop())
+
+     for sig in (signal.SIGINT, signal.SIGTERM):
+         loop.add_signal_handler(sig, signal_handler)
+
+     # Start pool
+     await pool.start()
+
+     # Wait for shutdown
+     try:
+         while pool._running:
+             await asyncio.sleep(1)
+     except asyncio.CancelledError:
+         pass
+
+     logger.info('Hosted worker pool shutdown complete')
+
+
+ if __name__ == '__main__':
+     asyncio.run(main())
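
Putting the routing flags together, a pool dedicated to a specific agent would be launched along the same lines as the Usage example in the module docstring (all values here are examples):

    python -m a2a_server.hosted_worker --workers 4 \
        --db-url postgresql://user:pass@localhost:5432/a2a_server \
        --api-url http://localhost:9001 \
        --agent-name build-agent --capabilities python,docker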