codetether 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_server/__init__.py +29 -0
- a2a_server/a2a_agent_card.py +365 -0
- a2a_server/a2a_errors.py +1133 -0
- a2a_server/a2a_executor.py +926 -0
- a2a_server/a2a_router.py +1033 -0
- a2a_server/a2a_types.py +344 -0
- a2a_server/agent_card.py +408 -0
- a2a_server/agents_server.py +271 -0
- a2a_server/auth_api.py +349 -0
- a2a_server/billing_api.py +638 -0
- a2a_server/billing_service.py +712 -0
- a2a_server/billing_webhooks.py +501 -0
- a2a_server/config.py +96 -0
- a2a_server/database.py +2165 -0
- a2a_server/email_inbound.py +398 -0
- a2a_server/email_notifications.py +486 -0
- a2a_server/enhanced_agents.py +919 -0
- a2a_server/enhanced_server.py +160 -0
- a2a_server/hosted_worker.py +1049 -0
- a2a_server/integrated_agents_server.py +347 -0
- a2a_server/keycloak_auth.py +750 -0
- a2a_server/livekit_bridge.py +439 -0
- a2a_server/marketing_tools.py +1364 -0
- a2a_server/mcp_client.py +196 -0
- a2a_server/mcp_http_server.py +2256 -0
- a2a_server/mcp_server.py +191 -0
- a2a_server/message_broker.py +725 -0
- a2a_server/mock_mcp.py +273 -0
- a2a_server/models.py +494 -0
- a2a_server/monitor_api.py +5904 -0
- a2a_server/opencode_bridge.py +1594 -0
- a2a_server/redis_task_manager.py +518 -0
- a2a_server/server.py +726 -0
- a2a_server/task_manager.py +668 -0
- a2a_server/task_queue.py +742 -0
- a2a_server/tenant_api.py +333 -0
- a2a_server/tenant_middleware.py +219 -0
- a2a_server/tenant_service.py +760 -0
- a2a_server/user_auth.py +721 -0
- a2a_server/vault_client.py +576 -0
- a2a_server/worker_sse.py +873 -0
- agent_worker/__init__.py +8 -0
- agent_worker/worker.py +4877 -0
- codetether/__init__.py +10 -0
- codetether/__main__.py +4 -0
- codetether/cli.py +112 -0
- codetether/worker_cli.py +57 -0
- codetether-1.2.2.dist-info/METADATA +570 -0
- codetether-1.2.2.dist-info/RECORD +66 -0
- codetether-1.2.2.dist-info/WHEEL +5 -0
- codetether-1.2.2.dist-info/entry_points.txt +4 -0
- codetether-1.2.2.dist-info/licenses/LICENSE +202 -0
- codetether-1.2.2.dist-info/top_level.txt +5 -0
- codetether_voice_agent/__init__.py +6 -0
- codetether_voice_agent/agent.py +445 -0
- codetether_voice_agent/codetether_mcp.py +345 -0
- codetether_voice_agent/config.py +16 -0
- codetether_voice_agent/functiongemma_caller.py +380 -0
- codetether_voice_agent/session_playback.py +247 -0
- codetether_voice_agent/tools/__init__.py +21 -0
- codetether_voice_agent/tools/definitions.py +135 -0
- codetether_voice_agent/tools/handlers.py +380 -0
- run_server.py +314 -0
- ui/monitor-tailwind.html +1790 -0
- ui/monitor.html +1775 -0
- ui/monitor.js +2662 -0
|
@@ -0,0 +1,1049 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hosted Worker - Managed task execution for mid-market users.
|
|
3
|
+
|
|
4
|
+
This module provides:
|
|
5
|
+
- Worker pool that claims and executes tasks from the queue
|
|
6
|
+
- Lease-based job locking with heartbeat renewal
|
|
7
|
+
- Automatic expired lease reclamation
|
|
8
|
+
- Graceful shutdown support
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
python -m a2a_server.hosted_worker --workers 2 --db-url postgresql://...
|
|
12
|
+
|
|
13
|
+
The worker process runs N concurrent workers that poll the task_runs queue.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import argparse
|
|
17
|
+
import asyncio
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
import os
|
|
21
|
+
import signal
|
|
22
|
+
import socket
|
|
23
|
+
import sys
|
|
24
|
+
import uuid
|
|
25
|
+
from datetime import datetime, timezone
|
|
26
|
+
from typing import Any, Dict, List, Optional
|
|
27
|
+
|
|
28
|
+
import asyncpg
|
|
29
|
+
import httpx
|
|
30
|
+
|
|
31
|
+
from .task_queue import TaskRunStatus
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)

# Configuration defaults (durations/intervals are in seconds)
DEFAULT_WORKERS = 2  # HostedWorker instances started per pool
DEFAULT_POLL_INTERVAL = 2.0  # idle sleep between queue polls
DEFAULT_LEASE_DURATION = 10 * 60  # job lease length: 10 minutes
DEFAULT_HEARTBEAT_INTERVAL = 60  # lease renewal cadence: 1 minute
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class HostedWorker:
    """
    A single worker that claims and executes tasks from the queue.

    Multiple HostedWorker instances run concurrently in the same process and
    share one asyncpg connection pool. Each worker handles one task run at a
    time:

    1. claim a run with ``claim_next_task_run()`` (lease-based locking),
    2. keep the lease alive from a background heartbeat task,
    3. execute the task through the API's MCP JSON-RPC endpoint,
    4. record the outcome with ``complete_task_run()`` and then send any
       configured email/webhook notifications on a best-effort basis.
    """

    def __init__(
        self,
        worker_id: str,
        db_pool: asyncpg.Pool,
        api_base_url: str,
        poll_interval: float = DEFAULT_POLL_INTERVAL,
        lease_duration: int = DEFAULT_LEASE_DURATION,
        heartbeat_interval: int = DEFAULT_HEARTBEAT_INTERVAL,
        agent_name: Optional[str] = None,
        capabilities: Optional[List[str]] = None,
    ):
        """
        Args:
            worker_id: Identifier passed to the claim/renew/complete SQL
                functions as the lease owner.
            db_pool: Shared asyncpg connection pool.
            api_base_url: Base URL of the API serving ``/mcp/v1/rpc``
                (a trailing ``/`` is stripped).
            poll_interval: Seconds to sleep when no job was claimed or after
                an unexpected error in the worker loop.
            lease_duration: Lease length in seconds passed to the SQL
                functions; also used as the HTTP timeout for task execution.
            heartbeat_interval: Seconds between lease renewals.
            agent_name: Optional agent name used for targeted routing.
            capabilities: Capabilities advertised for capability routing.
        """
        self.worker_id = worker_id
        self._pool = db_pool
        self._api_base_url = api_base_url.rstrip('/')
        self._poll_interval = poll_interval
        self._lease_duration = lease_duration
        self._heartbeat_interval = heartbeat_interval
        # Agent identity for targeted routing
        self.agent_name = agent_name
        self.capabilities = capabilities or []

        self._running = False
        self._current_run_id: Optional[str] = None
        self._current_task_id: Optional[str] = None
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._http_client: Optional[httpx.AsyncClient] = None

        # Stats (summed by HostedWorkerPool when it unregisters)
        self.tasks_completed = 0
        self.tasks_failed = 0
        self.total_runtime_seconds = 0

    async def start(self) -> None:
        """
        Run the claim/execute loop until stop() is called or the task is
        cancelled. The HTTP client is always closed on exit.
        """
        self._running = True
        # 5 min default timeout; task execution overrides it per request.
        self._http_client = httpx.AsyncClient(timeout=300.0)

        logger.info(f'Worker {self.worker_id} starting')

        try:
            while self._running:
                try:
                    if await self._claim_next_job():
                        await self._execute_current_task()
                    else:
                        # No work available, wait before polling again
                        await asyncio.sleep(self._poll_interval)

                except asyncio.CancelledError:
                    logger.info(f'Worker {self.worker_id} cancelled')
                    break
                except Exception as e:
                    logger.error(
                        f'Worker {self.worker_id} error: {e}', exc_info=True
                    )
                    await asyncio.sleep(self._poll_interval)
        finally:
            await self._cleanup()

    async def stop(self) -> None:
        """
        Stop the worker loop after its current iteration and cancel any
        running lease heartbeat.
        """
        logger.info(f'Worker {self.worker_id} stopping')
        self._running = False
        await self._stop_heartbeat()

    async def _stop_heartbeat(self) -> None:
        """Cancel the heartbeat task (if any) and wait for it to finish."""
        task = self._heartbeat_task
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # Awaiting a task that already died re-raises its exception;
            # that must not abort cleanup of the current run.
            logger.warning(
                f'Heartbeat for worker {self.worker_id} ended with error: {e}'
            )
        if self._heartbeat_task is task:
            self._heartbeat_task = None

    async def _claim_next_job(self) -> bool:
        """
        Attempt to claim the next available job from the queue.

        Uses the claim_next_task_run() SQL function for atomic claiming
        with concurrency limit enforcement and agent-targeted routing.

        The agent_name and capabilities are passed to filter tasks:
        - Tasks with target_agent_name set will only be claimed by matching workers
        - Tasks with required_capabilities will only be claimed by workers with ALL required caps

        On success the claimed run and task IDs are stored on the worker for
        _execute_current_task().

        Returns True if a job was claimed.
        """
        # Capabilities travel as a JSON array and are cast to jsonb in SQL.
        capabilities_json = json.dumps(self.capabilities or [])

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                'SELECT * FROM claim_next_task_run($1, $2, $3, $4::jsonb)',
                self.worker_id,
                self._lease_duration,
                self.agent_name,  # Pass agent_name for targeted routing
                capabilities_json,  # Pass capabilities for capability-based routing
            )

        if not row or not row['run_id']:
            return False

        self._current_run_id = row['run_id']
        self._current_task_id = row['task_id']
        target_agent = row.get('target_agent_name')
        target_info = f', targeted_at={target_agent}' if target_agent else ''
        logger.info(
            f'Worker {self.worker_id} (agent={self.agent_name}) claimed run {self._current_run_id} '
            f'(task={self._current_task_id}, priority={row["priority"]}{target_info})'
        )
        return True

    async def _execute_current_task(self) -> None:
        """
        Execute the currently claimed run and record its outcome.

        The lease is renewed by a heartbeat task for the duration of the run.
        Any exception while fetching or running the task marks the run as
        failed. The heartbeat is stopped and the current run/task IDs are
        cleared on every path.
        """
        run_id = self._current_run_id
        task_id = self._current_task_id
        if not run_id or not task_id:
            return

        started_at = datetime.now(timezone.utc)

        # Start heartbeat to keep lease alive
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop(run_id))

        try:
            logger.info(f'Worker {self.worker_id} executing task {task_id}')

            # Get task details from API
            task_data = await self._get_task_details(task_id)
            if not task_data:
                raise Exception(f'Task {task_id} not found')

            # Execute the task via API (this triggers the actual agent work)
            result = await self._run_task(task_id, task_data)

            # Mark completed
            runtime = int(
                (datetime.now(timezone.utc) - started_at).total_seconds()
            )
            await self._complete_run(
                run_id,
                status='completed',
                result_summary=result.get('summary', 'Task completed'),
                result_full=result,
            )

            self.tasks_completed += 1
            self.total_runtime_seconds += runtime
            logger.info(
                f'Worker {self.worker_id} completed task {task_id} in {runtime}s'
            )

        except Exception as e:
            logger.error(
                f'Worker {self.worker_id} failed task {task_id}: {e}',
                exc_info=True,
            )

            await self._complete_run(
                run_id,
                status='failed',
                error=str(e),
            )

            self.tasks_failed += 1

        finally:
            await self._stop_heartbeat()
            self._current_run_id = None
            self._current_task_id = None

    async def _get_task_details(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch task details via the ``get_task`` MCP tool.

        Returns the decoded task data. Returns None if the HTTP client is not
        initialised, the API reports an error, or the request/decoding fails
        (failures are logged, not raised).
        """
        if not self._http_client:
            return None
        try:
            response = await self._http_client.post(
                f'{self._api_base_url}/mcp/v1/rpc',
                json={
                    'jsonrpc': '2.0',
                    'method': 'tools/call',
                    'params': {
                        'name': 'get_task',
                        'arguments': {'task_id': task_id},
                    },
                    'id': str(uuid.uuid4()),
                },
            )
            response.raise_for_status()
            data = response.json()

            if 'error' in data:
                logger.error(
                    f'API error getting task {task_id}: {data["error"]}'
                )
                return None

            result = data.get('result', {})
            # Handle MCP response format: decode the first text content item.
            if isinstance(result, dict) and 'content' in result:
                for content in result['content']:
                    if content.get('type') == 'text':
                        return json.loads(content['text'])
            return result

        except Exception as e:
            logger.error(f'Failed to get task {task_id}: {e}')
            return None

    async def _run_task(
        self, task_id: str, task_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Execute the task by invoking the appropriate agent.

        For now, this calls the existing execution path through the
        ``continue_task`` MCP tool, passing the task's description (or
        prompt) as input. In the future, this could run agents directly.

        Returns:
            Always a dict, so callers can rely on ``.get('summary')``:
            a decoded JSON object is returned as-is; non-JSON text becomes
            ``{'summary': text, 'raw': text}``; any other JSON value becomes
            ``{'summary': text, 'result': value}``; a response without a text
            item becomes ``{'summary': 'Task completed', 'result': result}``.

        Raises:
            Exception: if the HTTP client is not initialised, the API returns
                a JSON-RPC error, or the request times out and the task is not
                reported as completed afterwards. HTTP status errors from
                ``raise_for_status()`` also propagate.
        """
        # Get the task prompt
        prompt = task_data.get('description') or task_data.get('prompt', '')
        codebase_id = task_data.get('codebase_id', 'global')
        agent_type = task_data.get('agent_type', 'build')

        logger.info(
            f'Running task {task_id}: agent={agent_type}, codebase={codebase_id}'
        )

        # Use the continue_task tool to trigger execution; this works with
        # the existing worker infrastructure.
        if not self._http_client:
            raise Exception('HTTP client not initialized')
        try:
            response = await self._http_client.post(
                f'{self._api_base_url}/mcp/v1/rpc',
                json={
                    'jsonrpc': '2.0',
                    'method': 'tools/call',
                    'params': {
                        'name': 'continue_task',
                        'arguments': {
                            'task_id': task_id,
                            'input': prompt,  # Pass the prompt as continuation
                        },
                    },
                    'id': str(uuid.uuid4()),
                },
                timeout=self._lease_duration,  # Don't timeout before lease
            )
            response.raise_for_status()
            data = response.json()

            if 'error' in data:
                raise Exception(f'Task execution error: {data["error"]}')

            result = data.get('result', {})

            # Extract result from MCP format
            if isinstance(result, dict) and 'content' in result:
                for content in result['content']:
                    if content.get('type') == 'text':
                        text = content['text']
                        try:
                            parsed = json.loads(text)
                        except json.JSONDecodeError:
                            return {
                                'summary': text,
                                'raw': text,
                            }
                        if isinstance(parsed, dict):
                            return parsed
                        # A JSON list/string/number is wrapped so the caller
                        # does not fail a successful task on `.get()`.
                        return {'summary': text, 'result': parsed}

            return {'summary': 'Task completed', 'result': result}

        except httpx.TimeoutException:
            # Task took too long - it may still be running
            # Check task status
            task_details = await self._get_task_details(task_id)
            if task_details and task_details.get('status') == 'completed':
                return {
                    'summary': 'Task completed (timeout during response)',
                    'result': task_details,
                }
            raise Exception('Task execution timed out')

    async def _heartbeat_loop(self, run_id: str) -> None:
        """
        Periodically renew the lease on ``run_id`` until cancelled.

        Stops on its own only when renew_task_run_lease() reports that the
        lease could not be renewed. Errors while talking to the database are
        logged and retried on the next interval, so one transient failure
        does not silently let the lease expire.
        """
        try:
            while True:
                await asyncio.sleep(self._heartbeat_interval)

                try:
                    async with self._pool.acquire() as conn:
                        renewed = await conn.fetchval(
                            'SELECT renew_task_run_lease($1, $2, $3)',
                            run_id,
                            self.worker_id,
                            self._lease_duration,
                        )
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    logger.warning(
                        f'Error renewing lease for run {run_id}: {e}'
                    )
                    continue

                if not renewed:
                    logger.warning(
                        f'Failed to renew lease for run {run_id}'
                    )
                    break

                logger.debug(f'Renewed lease for run {run_id}')

        except asyncio.CancelledError:
            pass

    async def _complete_run(
        self,
        run_id: str,
        status: str,
        result_summary: Optional[str] = None,
        result_full: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> None:
        """
        Mark a task run as completed or failed, then send notifications.

        Errors from complete_task_run() propagate to the caller.
        Notification errors are logged and swallowed: by then the run's
        outcome is already persisted. Propagating them would make
        _execute_current_task() re-record a completed run as failed.
        """
        async with self._pool.acquire() as conn:
            await conn.fetchval(
                'SELECT complete_task_run($1, $2, $3, $4, $5, $6)',
                run_id,
                self.worker_id,
                status,
                result_summary,
                json.dumps(result_full) if result_full else None,
                error,
            )

        logger.info(f'Run {run_id} marked as {status}')

        try:
            await self._send_completion_notification(run_id, status)
        except Exception as e:
            logger.error(
                f'Error sending completion notification for run {run_id}: {e}',
                exc_info=True,
            )

    async def _send_completion_notification(
        self, run_id: str, status: str
    ) -> None:
        """
        Send completion notification (email/webhook) with retry-safe 3-state flow.

        Flow:
        1. Atomically claim notification for send (increments attempts)
        2. Try to send
        3. On success: mark_notification_sent()
        4. On failure: mark_notification_failed() with backoff for retry

        This prevents both duplicate sends AND permanent silence.
        """
        # Get notification settings and task details for this run
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT tr.notify_email, tr.notify_webhook_url,
                       tr.notification_status, tr.webhook_status,
                       tr.result_summary, tr.result_full, tr.task_id,
                       tr.runtime_seconds, tr.last_error,
                       t.title, t.prompt
                FROM task_runs tr
                LEFT JOIN tasks t ON tr.task_id = t.id
                WHERE tr.id = $1
                """,
                run_id,
            )

        # The connection is released before sending; each sender acquires
        # its own, so a worker never holds two pool connections at once.
        if not row:
            return

        # Send email notification with 3-state tracking
        if row['notify_email'] and row['notification_status'] != 'sent':
            await self._send_email_notification(run_id, row, status)

        # Send webhook notification with 3-state tracking
        if row['notify_webhook_url'] and row['webhook_status'] != 'sent':
            await self._send_webhook_notification(run_id, row, status)

    async def _send_email_notification(
        self, run_id: str, row: dict, status: str
    ) -> None:
        """
        Send the completion email with atomic claim and retry support.

        ``row`` is the task_runs/tasks record selected in
        _send_completion_notification().
        """
        # Atomically claim the notification (prevents double-send)
        async with self._pool.acquire() as conn:
            claimed = await conn.fetchval(
                'SELECT claim_notification_for_send($1, $2)',
                run_id,
                3,  # max_attempts
            )

        if not claimed:
            logger.debug(
                f'Notification already claimed or sent for run {run_id}'
            )
            return

        # Now try to send the email
        try:
            from .email_notifications import send_task_completion_email

            # Prefer the summary/result stored in result_full (JSON text or
            # an already-decoded value) over the plain result_summary column.
            result_text = row['result_summary']
            if row['result_full']:
                try:
                    result_full = (
                        json.loads(row['result_full'])
                        if isinstance(row['result_full'], str)
                        else row['result_full']
                    )
                    result_text = (
                        result_full.get('summary')
                        or result_full.get('result')
                        or result_text
                    )
                except (json.JSONDecodeError, TypeError, AttributeError):
                    pass

            email_sent = await send_task_completion_email(
                to_email=row['notify_email'],
                task_id=row['task_id'],
                title=row['title'] or 'Task',
                status=status,
                result=result_text,
                error=row['last_error'] if status == 'failed' else None,
                runtime_seconds=row['runtime_seconds'],
                worker_name=self.worker_id,
            )

            if email_sent:
                # Mark as sent - success!
                async with self._pool.acquire() as conn:
                    await conn.execute(
                        'SELECT mark_notification_sent($1)', run_id
                    )
                logger.info(
                    f'Completion email sent to {row["notify_email"]} for run {run_id}'
                )
            else:
                # Mark as failed with retry backoff
                async with self._pool.acquire() as conn:
                    await conn.execute(
                        'SELECT mark_notification_failed($1, $2, $3)',
                        run_id,
                        'SendGrid returned failure (check API key/config)',
                        3,
                    )
                logger.warning(
                    f'Failed to send completion email for run {run_id}, will retry'
                )

        except ImportError as e:
            async with self._pool.acquire() as conn:
                await conn.execute(
                    'SELECT mark_notification_failed($1, $2, $3)',
                    run_id,
                    f'email_notifications module not available: {e}',
                    3,
                )
            logger.warning('email_notifications module not available')
        except Exception as e:
            # Mark as failed with retry backoff
            async with self._pool.acquire() as conn:
                await conn.execute(
                    'SELECT mark_notification_failed($1, $2, $3)',
                    run_id,
                    str(e)[:500],  # Truncate error message
                    3,
                )
            logger.error(f'Error sending completion email: {e}')

    async def _send_webhook_notification(
        self, run_id: str, row: dict, status: str
    ) -> None:
        """
        Send the webhook notification with atomic claim and retry support.

        The webhook is marked sent only when _call_webhook() reports
        success; otherwise mark_webhook_failed() is recorded so it can be
        retried.
        """
        # Atomically claim the webhook notification
        async with self._pool.acquire() as conn:
            claimed = await conn.fetchval(
                'SELECT claim_webhook_for_send($1, $2)',
                run_id,
                3,  # max_attempts
            )

        if not claimed:
            logger.debug(
                f'Webhook already claimed or sent for run {run_id}'
            )
            return

        # Now try to call the webhook
        try:
            delivered = await self._call_webhook(
                url=row['notify_webhook_url'],
                run_id=run_id,
                task_id=row['task_id'],
                status=status,
                result=row['result_summary'],
                error=row['last_error'],
            )

            if delivered:
                # Mark as sent - success!
                async with self._pool.acquire() as conn:
                    await conn.execute('SELECT mark_webhook_sent($1)', run_id)
                logger.info(f'Webhook called successfully for run {run_id}')
            else:
                # _call_webhook already logged the cause; record the failure
                # with retry backoff instead of marking it sent.
                async with self._pool.acquire() as conn:
                    await conn.execute(
                        'SELECT mark_webhook_failed($1, $2, $3)',
                        run_id,
                        'Webhook call failed (non-2xx response or request error)',
                        3,
                    )
                logger.warning(
                    f'Webhook delivery failed for run {run_id}, will retry'
                )

        except Exception as e:
            # Mark as failed with retry backoff
            async with self._pool.acquire() as conn:
                await conn.execute(
                    'SELECT mark_webhook_failed($1, $2, $3)',
                    run_id,
                    str(e)[:500],
                    3,
                )
            logger.error(f'Error calling webhook: {e}')

    async def _call_webhook(
        self,
        url: str,
        run_id: str,
        task_id: str,
        status: str,
        result: Optional[str] = None,
        error: Optional[str] = None,
    ) -> bool:
        """
        POST task completion data to ``url`` (10 s timeout).

        Returns:
            True if the endpoint answered with a status code below 300.
            False if no HTTP client is available, the request raised, or the
            endpoint returned an error status. Failures are logged, never
            raised.
        """
        if not self._http_client:
            logger.warning(
                f'Cannot call webhook for run {run_id}: HTTP client not initialized'
            )
            return False

        payload = {
            'event': 'task_completed',
            'run_id': run_id,
            'task_id': task_id,
            'status': status,
            'result': result,
            'error': error,
            'timestamp': datetime.now(timezone.utc).isoformat(),
        }

        try:
            response = await self._http_client.post(
                url,
                json=payload,
                timeout=10.0,
            )
        except Exception as e:
            logger.error(f'Webhook call failed for run {run_id}: {e}')
            return False

        if response.status_code < 300:
            logger.info(f'Webhook called successfully for run {run_id}')
            return True

        logger.warning(
            f'Webhook returned {response.status_code} for run {run_id}'
        )
        return False

    async def _cleanup(self) -> None:
        """Close the HTTP client, if one is open."""
        if self._http_client:
            await self._http_client.aclose()
            self._http_client = None
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
class HostedWorkerPool:
|
|
628
|
+
"""
|
|
629
|
+
Manages a pool of hosted workers.
|
|
630
|
+
|
|
631
|
+
Handles:
|
|
632
|
+
- Starting N workers
|
|
633
|
+
- Periodic expired lease reclamation
|
|
634
|
+
- Graceful shutdown
|
|
635
|
+
- Health reporting
|
|
636
|
+
"""
|
|
637
|
+
|
|
638
|
+
def __init__(
    self,
    db_url: str,
    api_base_url: str,
    num_workers: int = DEFAULT_WORKERS,
    poll_interval: float = DEFAULT_POLL_INTERVAL,
    lease_duration: int = DEFAULT_LEASE_DURATION,
    agent_name: Optional[str] = None,
    capabilities: Optional[List[str]] = None,
):
    """
    Configure the pool; no database or network I/O happens until start().

    Args:
        db_url: Connection string handed to asyncpg.create_pool().
        api_base_url: Base URL handed to every HostedWorker.
        num_workers: Number of HostedWorker instances to run.
        poll_interval: Idle poll interval passed to each worker.
        lease_duration: Lease length (seconds) passed to each worker.
        agent_name: Agent name shared by all workers for targeted routing.
        capabilities: Capabilities shared by all workers.
    """
    # Connection and scheduling settings.
    self._db_url = db_url
    self._api_base_url = api_base_url
    self._num_workers = num_workers
    self._poll_interval = poll_interval
    self._lease_duration = lease_duration

    # Routing identity, identical for every worker in this pool.
    self._agent_name = agent_name
    self._capabilities = capabilities or []

    # Runtime state populated by start().
    self._pool: Optional[asyncpg.Pool] = None
    self._workers: List[HostedWorker] = []
    self._worker_tasks: List[asyncio.Task] = []
    self._reclaim_task: Optional[asyncio.Task] = None
    self._running = False

    # Pool identity: host name, process id and a short random suffix.
    hostname = socket.gethostname()
    pid = os.getpid()
    suffix = uuid.uuid4().hex[:8]
    self._pool_id = f'{hostname}-{pid}-{suffix}'
|
|
667
|
+
|
|
668
|
+
async def start(self) -> None:
    """
    Connect to the database, register the pool, and launch its workers.

    Creates the shared asyncpg pool and records this pool in
    ``hosted_workers``. It then schedules ``num_workers`` HostedWorker
    tasks plus the background reclaim loop, and returns once everything is
    scheduled.

    Raises:
        Any error from pool creation or registration. If registration
        fails, the freshly created connection pool is closed first so it
        does not leak.
    """
    logger.info(
        f'Starting hosted worker pool {self._pool_id} with {self._num_workers} workers'
    )

    # Connect to database
    self._pool = await asyncpg.create_pool(
        self._db_url,
        min_size=self._num_workers + 1,  # +1 for reclaim task
        max_size=self._num_workers * 2,
    )

    # Register this pool; don't leak the connection pool if that fails.
    try:
        await self._register_pool()
    except Exception:
        await self._pool.close()
        self._pool = None
        raise

    self._running = True

    # Start workers
    for i in range(self._num_workers):
        worker_id = f'{self._pool_id}-worker-{i}'
        worker = HostedWorker(
            worker_id=worker_id,
            db_pool=self._pool,
            api_base_url=self._api_base_url,
            poll_interval=self._poll_interval,
            lease_duration=self._lease_duration,
            agent_name=self._agent_name,
            capabilities=self._capabilities,
        )
        self._workers.append(worker)
        self._worker_tasks.append(asyncio.create_task(worker.start()))

    # Start expired lease reclamation
    self._reclaim_task = asyncio.create_task(self._reclaim_loop())

    logger.info(f'Worker pool {self._pool_id} started')
|
|
706
|
+
|
|
707
|
+
async def stop(self) -> None:
    """
    Stop the worker pool.

    Order of shutdown: ask each worker to stop, cancel the worker tasks and
    wait for them, cancel the reclaim loop, mark the pool stopped in the
    database, and finally close the connection pool.
    """
    logger.info(f'Stopping worker pool {self._pool_id}')
    self._running = False

    for hosted in self._workers:
        await hosted.stop()

    pending = self._worker_tasks
    for worker_task in pending:
        worker_task.cancel()
    if pending:
        # return_exceptions=True so one failed worker can't abort shutdown.
        await asyncio.gather(*pending, return_exceptions=True)

    reclaimer = self._reclaim_task
    if reclaimer:
        reclaimer.cancel()
        try:
            await reclaimer
        except asyncio.CancelledError:
            pass

    await self._unregister_pool()

    db_pool = self._pool
    if db_pool:
        await db_pool.close()

    logger.info(f'Worker pool {self._pool_id} stopped')
|
|
740
|
+
|
|
741
|
+
async def _register_pool(self) -> None:
    """
    Upsert this pool's row in ``hosted_workers``.

    A new row records hostname, process id and worker count. If the id
    already exists, the row is reactivated (status, heartbeat, stopped_at).
    Does nothing when the connection pool has not been created.
    """
    db_pool = self._pool
    if not db_pool:
        return

    upsert_sql = """
        INSERT INTO hosted_workers (id, hostname, process_id, max_concurrent_tasks, started_at)
        VALUES ($1, $2, $3, $4, NOW())
        ON CONFLICT (id) DO UPDATE SET
            status = 'active',
            last_heartbeat = NOW(),
            stopped_at = NULL
    """
    async with db_pool.acquire() as conn:
        await conn.execute(
            upsert_sql,
            self._pool_id,
            socket.gethostname(),
            os.getpid(),
            self._num_workers,
        )
|
|
760
|
+
|
|
761
|
+
async def _unregister_pool(self) -> None:
    """
    Mark this worker pool as stopped and store aggregated worker stats.

    Failures are logged rather than raised, so shutdown can continue.
    """
    db_pool = self._pool
    if not db_pool:
        return

    try:
        async with db_pool.acquire() as conn:
            # Totals across every worker this pool started.
            workers = self._workers
            completed = sum(w.tasks_completed for w in workers)
            failed = sum(w.tasks_failed for w in workers)
            runtime = sum(w.total_runtime_seconds for w in workers)

            await conn.execute(
                """
                UPDATE hosted_workers SET
                    status = 'stopped',
                    stopped_at = NOW(),
                    tasks_completed = $2,
                    tasks_failed = $3,
                    total_runtime_seconds = $4
                WHERE id = $1
                """,
                self._pool_id,
                completed,
                failed,
                runtime,
            )
    except Exception as exc:
        logger.error(f'Failed to unregister pool: {exc}')
|
|
792
|
+
|
|
793
|
+
async def _reclaim_loop(self) -> None:
    """
    Periodically reclaim expired leases and retry failed notifications.

    Every 60 seconds while the pool is running, it:
    - calls reclaim_expired_task_runs(),
    - calls fail_deadline_exceeded_tasks(),
    - updates this pool's heartbeat and current task count,
    - retries pending notifications.

    Errors are logged per iteration and the loop keeps going. Previously a
    single transient DB error ended the loop for the rest of the pool's
    lifetime. Cancellation ends the loop quietly.
    """
    if not self._pool:
        return
    try:
        while self._running:
            await asyncio.sleep(60)  # Check every 60 seconds

            try:
                async with self._pool.acquire() as conn:
                    reclaimed = await conn.fetchval(
                        'SELECT reclaim_expired_task_runs()'
                    )
                    if reclaimed and reclaimed > 0:
                        logger.info(f'Reclaimed {reclaimed} expired task runs')

                    # Fail tasks that exceeded their routing deadline
                    deadline_failed = await conn.fetchval(
                        'SELECT fail_deadline_exceeded_tasks()'
                    )
                    if deadline_failed and deadline_failed > 0:
                        logger.info(
                            f'Failed {deadline_failed} tasks that exceeded deadline'
                        )

                    # Update pool heartbeat
                    current_tasks = sum(
                        1 for w in self._workers if w._current_run_id
                    )
                    await conn.execute(
                        """
                        UPDATE hosted_workers SET
                            last_heartbeat = NOW(),
                            current_tasks = $2
                        WHERE id = $1
                        """,
                        self._pool_id,
                        current_tasks,
                    )

                # Retry failed notifications after releasing the connection
                # above, since the retry acquires its own.
                await self._retry_failed_notifications()

            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f'Reclaim loop error: {e}', exc_info=True)

    except asyncio.CancelledError:
        pass
|
|
839
|
+
|
|
840
|
+
async def _retry_failed_notifications(self) -> None:
|
|
841
|
+
"""
|
|
842
|
+
Process failed notifications that are ready for retry.
|
|
843
|
+
|
|
844
|
+
This runs periodically in the pool's reclaim loop to ensure
|
|
845
|
+
no notifications are permanently lost due to transient failures.
|
|
846
|
+
"""
|
|
847
|
+
if not self._pool:
|
|
848
|
+
return
|
|
849
|
+
|
|
850
|
+
try:
|
|
851
|
+
async with self._pool.acquire() as conn:
|
|
852
|
+
# Get notifications ready for retry
|
|
853
|
+
rows = await conn.fetch(
|
|
854
|
+
'SELECT * FROM get_pending_notification_retries($1)',
|
|
855
|
+
10, # Process up to 10 at a time
|
|
856
|
+
)
|
|
857
|
+
|
|
858
|
+
if not rows:
|
|
859
|
+
return
|
|
860
|
+
|
|
861
|
+
logger.info(f'Processing {len(rows)} notification retries')
|
|
862
|
+
|
|
863
|
+
# Use one of our workers to send the notifications
|
|
864
|
+
# (reuse their HTTP client and notification logic)
|
|
865
|
+
if self._workers:
|
|
866
|
+
worker = self._workers[0]
|
|
867
|
+
|
|
868
|
+
for row in rows:
|
|
869
|
+
run_id = row['run_id']
|
|
870
|
+
|
|
871
|
+
# Get full task details for notification
|
|
872
|
+
async with self._pool.acquire() as conn:
|
|
873
|
+
full_row = await conn.fetchrow(
|
|
874
|
+
"""
|
|
875
|
+
SELECT tr.notify_email, tr.notify_webhook_url,
|
|
876
|
+
tr.notification_status, tr.webhook_status,
|
|
877
|
+
tr.result_summary, tr.result_full, tr.task_id,
|
|
878
|
+
tr.runtime_seconds, tr.last_error, tr.status,
|
|
879
|
+
t.title, t.prompt
|
|
880
|
+
FROM task_runs tr
|
|
881
|
+
LEFT JOIN tasks t ON tr.task_id = t.id
|
|
882
|
+
WHERE tr.id = $1
|
|
883
|
+
""",
|
|
884
|
+
run_id,
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
if not full_row:
|
|
888
|
+
continue
|
|
889
|
+
|
|
890
|
+
task_status = full_row['status'] or 'completed'
|
|
891
|
+
|
|
892
|
+
# Retry email if needed
|
|
893
|
+
if (
|
|
894
|
+
row['notification_status'] == 'failed'
|
|
895
|
+
and full_row['notify_email']
|
|
896
|
+
):
|
|
897
|
+
logger.info(
|
|
898
|
+
f'Retrying email notification for run {run_id} '
|
|
899
|
+
f'(attempt {row["notification_attempts"] + 1})'
|
|
900
|
+
)
|
|
901
|
+
await worker._send_email_notification(
|
|
902
|
+
run_id, dict(full_row), task_status
|
|
903
|
+
)
|
|
904
|
+
|
|
905
|
+
# Retry webhook if needed
|
|
906
|
+
if (
|
|
907
|
+
row['webhook_status'] == 'failed'
|
|
908
|
+
and full_row['notify_webhook_url']
|
|
909
|
+
):
|
|
910
|
+
logger.info(
|
|
911
|
+
f'Retrying webhook notification for run {run_id} '
|
|
912
|
+
f'(attempt {row["webhook_attempts"] + 1})'
|
|
913
|
+
)
|
|
914
|
+
await worker._send_webhook_notification(
|
|
915
|
+
run_id, dict(full_row), task_status
|
|
916
|
+
)
|
|
917
|
+
|
|
918
|
+
except Exception as e:
|
|
919
|
+
logger.error(f'Error retrying notifications: {e}', exc_info=True)
|
|
920
|
+
|
|
921
|
+
def get_stats(self) -> Dict[str, Any]:
    """Summarize the pool and each of its workers.

    Returns:
        A dict containing the pool id, the configured worker count, the
        running flag, a ``workers`` list with one entry per worker (id,
        current run id, completed/failed task counts, accumulated runtime),
        and ``totals`` summing those counters over all workers.
    """
    per_worker = [
        {
            'id': worker.worker_id,
            'current_run': worker._current_run_id,
            'tasks_completed': worker.tasks_completed,
            'tasks_failed': worker.tasks_failed,
            'total_runtime': worker.total_runtime_seconds,
        }
        for worker in self._workers
    ]

    # Derive the aggregates from the snapshot just taken, in worker order.
    totals = {
        'completed': sum(entry['tasks_completed'] for entry in per_worker),
        'failed': sum(entry['tasks_failed'] for entry in per_worker),
        'runtime': sum(entry['total_runtime'] for entry in per_worker),
    }

    return {
        'pool_id': self._pool_id,
        'num_workers': self._num_workers,
        'running': self._running,
        'workers': per_worker,
        'totals': totals,
    }
|
|
943
|
+
|
|
944
|
+
|
|
945
|
+
async def main():
    """Main entry point for hosted worker process.

    Parses command-line options (several fall back to environment
    variables), configures logging, starts a ``HostedWorkerPool`` and waits
    until it stops running.

    SIGINT/SIGTERM schedule a single ``pool.stop()`` task. That task is
    awaited before returning, so its cleanup is not cancelled when
    ``asyncio.run()`` tears the event loop down. If the wait is cancelled
    without a signal, for example where signal handlers are unsupported,
    the pool is still stopped.

    A database URL is required: pass ``--db-url`` or set ``DATABASE_URL``.
    The argument parser exits with status 2 if neither is given.
    """
    parser = argparse.ArgumentParser(
        description='CodeTether Hosted Worker Pool'
    )
    parser.add_argument(
        '--workers',
        '-w',
        type=int,
        default=DEFAULT_WORKERS,
        help=f'Number of concurrent workers (default: {DEFAULT_WORKERS})',
    )
    parser.add_argument(
        '--db-url',
        # No built-in fallback: a hard-coded URL would ship a database
        # password and a private host address inside the package.
        default=os.environ.get('DATABASE_URL'),
        help='PostgreSQL connection URL (env: DATABASE_URL; required)',
    )
    parser.add_argument(
        '--api-url',
        default=os.environ.get('API_BASE_URL', 'http://localhost:9001'),
        help='CodeTether API base URL',
    )
    parser.add_argument(
        '--poll-interval',
        type=float,
        default=DEFAULT_POLL_INTERVAL,
        help=f'Poll interval in seconds (default: {DEFAULT_POLL_INTERVAL})',
    )
    parser.add_argument(
        '--lease-duration',
        type=int,
        default=DEFAULT_LEASE_DURATION,
        help=f'Lease duration in seconds (default: {DEFAULT_LEASE_DURATION})',
    )
    parser.add_argument(
        '--log-level',
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        help='Log level (default: INFO)',
    )
    parser.add_argument(
        '--agent-name',
        default=os.environ.get('AGENT_NAME'),
        help='Agent name for targeted task routing (env: AGENT_NAME)',
    )
    parser.add_argument(
        '--capabilities',
        default=os.environ.get('AGENT_CAPABILITIES', ''),
        help='Comma-separated list of capabilities (env: AGENT_CAPABILITIES)',
    )

    args = parser.parse_args()

    if not args.db_url:
        parser.error(
            'a database URL is required: pass --db-url or set DATABASE_URL'
        )

    # Parse capabilities (blank entries are dropped)
    capabilities = [
        c.strip() for c in (args.capabilities or '').split(',') if c.strip()
    ]

    # Configure logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
    )

    # Create worker pool
    pool = HostedWorkerPool(
        db_url=args.db_url,
        api_base_url=args.api_url,
        num_workers=args.workers,
        poll_interval=args.poll_interval,
        lease_duration=args.lease_duration,
        agent_name=args.agent_name,
        capabilities=capabilities,
    )

    # Handle shutdown signals
    loop = asyncio.get_running_loop()
    # The single stop() task. Keeping a reference stops it from being
    # garbage-collected mid-flight and lets repeated signals be ignored.
    stop_task = None

    def signal_handler():
        nonlocal stop_task
        logger.info('Received shutdown signal')
        if stop_task is None:
            stop_task = loop.create_task(pool.stop())

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, signal_handler)
        except NotImplementedError:
            # Some event loops (e.g. on Windows) do not support this; fall
            # back to the default interrupt behaviour.
            logger.warning(
                f'Cannot install handler for {sig.name}; '
                f'relying on default signal handling'
            )

    # Start pool
    await pool.start()

    # Wait for shutdown
    try:
        while pool._running:
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        pass

    if stop_task is not None:
        # NOTE(review): stop() presumably clears _running before its cleanup
        # finishes. Await it so asyncio.run() does not cancel that cleanup.
        await stop_task
    elif pool._running:
        # Woken by cancellation rather than a signal: still shut down cleanly.
        await pool.stop()

    logger.info('Hosted worker pool shutdown complete')
|
|
1046
|
+
|
|
1047
|
+
|
|
1048
|
+
if __name__ == '__main__':
    # Script entry point: run the async main() to completion on a fresh
    # event loop created by asyncio.run().
    asyncio.run(main())
|