agyqueue 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agyqueue/mcp_server.py ADDED
@@ -0,0 +1,438 @@
1
+ import uuid
2
+ import json
3
+ import logging
4
+ import sys
5
+ import os
6
+ from starlette.requests import Request
7
+ from starlette.responses import JSONResponse, HTMLResponse
8
+ from mcp.server.fastmcp import FastMCP
9
+ from agyqueue.models import Task, TaskStatus
10
+ from agyqueue.storage import TaskStore
11
+ from agyqueue.task_queue import TaskQueue
12
+
13
# With the stdio transport, MCP protocol messages travel over stdout, so every
# log record is routed to stderr to keep that channel clean.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stderr)],
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("agyqueue.mcp_server")

from agyqueue.config import settings

# The FastMCP server instance; host and port come from agyqueue.config settings.
mcp = FastMCP("AgyQueue", host=settings.host, port=settings.port)

# Module-wide persistence layer and work queue, shared by every tool and REST route below.
store = TaskStore()
queue = TaskQueue()
30
+
31
+ import asyncio
32
+ import threading
33
+ from datetime import datetime, timezone
34
+ import mcp.types as types
35
+ from mcp.server.session import ServerSession
36
+
37
# Registry of live MCP sessions. It is filled and emptied by the ServerSession
# __init__/__aexit__ monkey-patches, and every access goes through the lock.
active_sessions: set = set()
active_sessions_lock = threading.Lock()
40
+
41
async def send_catchup_notifications(session):
    """Replay the current state of every stored task to a newly created session.

    Sleeps one second first so the session can finish initializing. It then
    sends one ``notifications/task_updated`` notification per task returned by
    ``store.list_tasks()``.

    A failed send to this session is logged at DEBUG and the replay continues.
    Any other error (for example, listing tasks) is logged at ERROR and ends
    the replay.
    """
    await asyncio.sleep(1.0)
    try:
        snapshot = store.list_tasks()
        logger.info(f"[SSE Push] Sending {len(snapshot)} catch-up notification(s) to new session {id(session)}.")
        for item in snapshot:
            payload = {
                "task_id": item.task_id,
                "status": item.status.value,
                "progress": item.progress,
                "step": item.step,
            }
            message = types.JSONRPCNotification(
                jsonrpc="2.0",
                method="notifications/task_updated",
                params=payload,
            )
            # NOTE(review): `_send` is a private ServerSession attribute. Confirm it
            # exists in the pinned mcp version, because failures here are only
            # logged at DEBUG.
            try:
                await session._send(message)
            except Exception as send_err:
                logger.debug(f"[SSE Push] Failed catchup for session {id(session)}: {send_err}")
    except Exception as err:
        logger.error(f"Error in send_catchup_notifications: {err}")
64
+
65
# Strong references to in-flight catch-up tasks. The event loop keeps only weak
# references to tasks, so an otherwise unreferenced task may be
# garbage-collected before it finishes. Each task removes itself when done.
_catchup_tasks = set()

original_init = ServerSession.__init__


def new_init(self, *args, **kwargs):
    """Wrap ``ServerSession.__init__``: register the session, then schedule a catch-up.

    After the original constructor runs, the session is added to
    ``active_sessions``. When constructed inside a running event loop,
    ``send_catchup_notifications(self)`` is scheduled. When there is no running
    loop, the catch-up is skipped.
    """
    original_init(self, *args, **kwargs)
    with active_sessions_lock:
        active_sessions.add(self)
        logger.info(f"[MCP Session] Registered session {id(self)}. Total active: {len(active_sessions)}")

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Constructed outside an event loop; nothing can be scheduled.
        return
    catchup = loop.create_task(send_catchup_notifications(self))
    _catchup_tasks.add(catchup)
    catchup.add_done_callback(_catchup_tasks.discard)


original_aexit = ServerSession.__aexit__


async def new_aexit(self, exc_type, exc_val, exc_tb):
    """Wrap ``ServerSession.__aexit__`` and drop the session from the registry.

    The original ``__aexit__`` return value is forwarded unchanged. A truthy
    value tells ``async with`` to suppress the in-flight exception, and
    dropping it would turn a suppressed exception into a propagated one.
    The session is discarded even if the original ``__aexit__`` raises.
    """
    try:
        return await original_aexit(self, exc_type, exc_val, exc_tb)
    finally:
        with active_sessions_lock:
            active_sessions.discard(self)
            logger.info(f"[MCP Session] Discarded session {id(self)}. Total active: {len(active_sessions)}")


ServerSession.__init__ = new_init
ServerSession.__aexit__ = new_aexit
89
+
90
# Broadcast task updates to all connected clients
async def broadcast_task_notification(task_id: str, status: str, progress: int, step: str):
    """Push one ``notifications/task_updated`` message to every registered session.

    The session set is copied while holding the lock, so the sends happen
    without holding it. A failure to reach one session is logged at DEBUG and
    does not stop delivery to the others.
    """
    params = {"task_id": task_id, "status": status, "progress": progress, "step": step}
    message = types.JSONRPCNotification(
        jsonrpc="2.0",
        method="notifications/task_updated",
        params=params,
    )

    with active_sessions_lock:
        targets = tuple(active_sessions)
    if not targets:
        return

    logger.info(f"[SSE Push] Broadcasting task {task_id} status change ({status}) to {len(targets)} client(s).")
    for target in targets:
        try:
            await target._send(message)
        except Exception as exc:
            logger.debug(f"[SSE Push] Failed to notify session {id(target)}: {exc}")
111
+
112
# State last observed by the DB monitor for each task:
#   {task_id: (status value, progress, step)}
last_known_task_states: dict = dict()
# asyncio task running the DB monitor loop; None until it is first scheduled.
db_monitor_task = None
115
+
116
def trigger_notifications_wrapper(task: Task):
    """Fan out external notifications for ``task`` without blocking the event loop.

    The synchronous ``notifications.trigger_notifications`` call is handed to
    the running loop's default executor. This function returns immediately
    and does not wait for delivery.

    Scheduling errors, including calling this with no running event loop,
    are logged rather than raised.
    """
    try:
        from agyqueue.notifications import notifications

        loop = asyncio.get_running_loop()
        call_args = (task.task_id, task.status.value, task.progress, task.step, task.result, task.error)
        loop.run_in_executor(None, notifications.trigger_notifications, *call_args)
    except Exception as exc:
        logger.error(f"Error triggering notifications wrapper: {exc}")
132
+
133
# Task statuses after which no further transitions are expected.
_TERMINAL_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)


async def _broadcast_state_changes(tasks, *, notify_new_terminal):
    """Broadcast each task whose (status, progress, step) changed, and record the new state.

    External notifications go through ``trigger_notifications_wrapper`` in two
    cases:
    - a known task transitions into a terminal status;
    - a task is first observed already terminal, but only when
      ``notify_new_terminal`` is true.
    """
    for task in tasks:
        current_state = (task.status.value, task.progress, task.step)
        previous_state = last_known_task_states.get(task.task_id)
        if previous_state == current_state:
            continue

        last_known_task_states[task.task_id] = current_state
        await broadcast_task_notification(task.task_id, task.status.value, task.progress, task.step)

        if task.status not in _TERMINAL_STATUSES:
            continue
        if previous_state is None:
            if notify_new_terminal:
                trigger_notifications_wrapper(task)
        elif previous_state[0] != task.status.value:
            # Only on an actual status change; progress/step edits of an
            # already-terminal task do not re-notify.
            trigger_notifications_wrapper(task)


def _recover_stale_tasks(store, queue, tasks, now, heartbeat_timeout):
    """Mark RUNNING tasks FAILED when ``updated_at`` is older than ``heartbeat_timeout`` seconds.

    If a timed-out task's parent is WAITING, the parent is re-queued so it
    wakes up and aggregates its subtasks.
    """
    for task in tasks:
        if task.status != TaskStatus.RUNNING:
            continue
        # This handler also covers store failures, not only timestamp parsing.
        try:
            updated_at = datetime.fromisoformat(task.updated_at).replace(tzinfo=None)
            delta = (now - updated_at).total_seconds()
            if delta <= heartbeat_timeout:
                continue
            logger.warning(f"[Timeout Recovery] Task {task.task_id} heartbeat expired (last updated {delta:.1f}s ago). Marking as FAILED.")
            store.update_task(
                task_id=task.task_id,
                status=TaskStatus.FAILED,
                progress=100,
                step="Aborted due to worker heartbeat timeout.",
                error="Worker heartbeat timeout. The background worker executing this task may have crashed.",
            )
            if task.parent_id:
                parent = store.get_task(task.parent_id)
                if parent and parent.status == TaskStatus.WAITING:
                    store.update_task(
                        task_id=task.parent_id,
                        status=TaskStatus.QUEUED,
                        progress=60,
                        step="Subtask timeout detected. Re-queueing parent for aggregation...",
                    )
                    queue.enqueue(task.parent_id)
        except Exception as parse_err:
            logger.error(f"Error parsing updated_at for task {task.task_id}: {parse_err}")


async def db_monitor_and_broadcast_loop(
    store: TaskStore,
    queue: TaskQueue,
    *,
    poll_interval: float = 0.3,
    heartbeat_timeout: float = 15.0,
    stale_check_every: int = 15,
):
    """Poll the task store forever: push state changes to clients and recover stale tasks.

    Args:
        store: Task store to poll with ``list_tasks()``.
        queue: Queue used to re-enqueue parents of timed-out subtasks.
        poll_interval: Seconds to sleep between polls.
        heartbeat_timeout: Age in seconds of ``updated_at`` after which a
            RUNNING task is marked FAILED.
        stale_check_every: Run the heartbeat check once every this many
            successful polls (15 x 0.3 s, about 4.5 s, with the defaults).

    Exceptions raised during a poll are logged with a traceback and the loop
    continues. asyncio cancellation (a BaseException) is not caught, so
    cancelling the task stops the loop.
    """
    logger.info("Database monitor and push notifications loop started.")
    # On a cold start the state cache is empty, so every stored task looks
    # "new", including tasks that finished long ago. If the store persists
    # tasks across restarts (TODO confirm), notifying on that first snapshot
    # would re-send external alerts for every historical terminal task after
    # each restart. The first snapshot is therefore recorded silently. If the
    # loop is restarted with a warm cache, newly seen terminal tasks still notify.
    notify_new_terminal = bool(last_known_task_states)
    polls_since_stale_check = 0
    while True:
        try:
            tasks = store.list_tasks()
            # Naive UTC, matching the naive UTC ISO timestamps stored on Task.
            now = datetime.now(timezone.utc).replace(tzinfo=None)

            await _broadcast_state_changes(tasks, notify_new_terminal=notify_new_terminal)
            notify_new_terminal = True

            polls_since_stale_check += 1
            if polls_since_stale_check >= stale_check_every:
                polls_since_stale_check = 0
                _recover_stale_tasks(store, queue, tasks, now, heartbeat_timeout)
        except Exception as e:
            logger.exception(f"Error in DB monitor loop: {e}")

        await asyncio.sleep(poll_interval)
194
+
195
def start_db_monitor_if_needed():
    """Schedule the DB monitor loop on the running event loop unless it is already alive.

    A monitor task that has finished (including by crashing) is replaced. With
    no running event loop this does nothing; every MCP tool calls this
    function, so a later call will retry.
    """
    global db_monitor_task
    if db_monitor_task is not None and not db_monitor_task.done():
        return
    try:
        running_loop = asyncio.get_running_loop()
        db_monitor_task = running_loop.create_task(db_monitor_and_broadcast_loop(store, queue))
        logger.info("Successfully scheduled background DB monitor loop on running event loop.")
    except RuntimeError:
        return
204
+
205
# NOTE: FastMCP presumably publishes this docstring as the tool description,
# so its wording is kept unchanged.
@mcp.tool()
def submit_task(prompt: str, task_type: str = "generic") -> str:
    """Submit a long-running task to the background queue.

    Args:
        prompt: The task instruction or compliance prompt (e.g. 'Validate this k8s manifest')
        task_type: The type of task (e.g., 'manifest_compliance', 'fastapi_gen', 'generic')

    Returns:
        A JSON string containing the generated task_id and initial status.
    """
    start_db_monitor_if_needed()

    # Short id: "agy-" followed by the first 8 hex characters of a random UUID4.
    new_id = "agy-" + uuid.uuid4().hex[:8]
    # The task is saved before it is enqueued (presumably so a worker can
    # always look the id up).
    store.save_task(
        Task(
            task_id=new_id,
            prompt=prompt,
            task_type=task_type,
            status=TaskStatus.QUEUED,
            progress=0,
            step="Queued in AgyQueue",
        )
    )
    queue.enqueue(new_id)
    logger.info(f"Submitted task {new_id} of type {task_type}")

    receipt = {
        "task_id": new_id,
        "status": "QUEUED",
        "message": f"Task {new_id} successfully submitted. Use get_task_status to monitor progress.",
    }
    return json.dumps(receipt, indent=2)
235
+
236
# NOTE: FastMCP presumably publishes this docstring as the tool description,
# so its wording is kept unchanged.
@mcp.tool()
def get_task_status(task_id: str) -> str:
    """Check the execution status, progress, and current step of a task.

    Args:
        task_id: The unique task ID returned by submit_task.

    Returns:
        A JSON string containing the status, progress, and current step details.
    """
    start_db_monitor_if_needed()
    found = store.get_task(task_id)
    if found:
        snapshot = {
            "task_id": found.task_id,
            "prompt": found.prompt,
            "task_type": found.task_type,
            "status": found.status.value,
            "progress": found.progress,
            "step": found.step,
            "updated_at": found.updated_at,
            "result": found.result,
            "error": found.error,
        }
        return json.dumps(snapshot, indent=2)
    # Unknown id: an error object is returned instead of raising.
    return json.dumps({"error": f"Task {task_id} not found"}, indent=2)
262
+
263
# NOTE: FastMCP presumably publishes this docstring as the tool description,
# so its wording is kept unchanged.
@mcp.tool()
def get_task_result(task_id: str) -> str:
    """Retrieve the final execution result or error of a completed/failed task.

    Args:
        task_id: The unique task ID returned by submit_task.

    Returns:
        A JSON string containing the task result or failure reason.
    """
    start_db_monitor_if_needed()
    found = store.get_task(task_id)
    if not found:
        return json.dumps({"error": f"Task {task_id} not found"}, indent=2)

    # QUEUED / RUNNING / WAITING tasks have no final outcome yet.
    still_active = found.status in (TaskStatus.QUEUED, TaskStatus.RUNNING, TaskStatus.WAITING)
    if still_active:
        body = {
            "task_id": found.task_id,
            "status": found.status.value,
            "progress": found.progress,
            "message": "Task is still running. Please wait for completion before requesting results.",
        }
    else:
        body = {
            "task_id": found.task_id,
            "status": found.status.value,
            "result": found.result,
            "error": found.error,
        }
    return json.dumps(body, indent=2)
292
+
293
# NOTE: FastMCP presumably publishes this docstring as the tool description,
# so its wording is kept unchanged.
@mcp.tool()
def list_tasks() -> str:
    """List all submitted tasks and their current state summaries.

    Returns:
        A JSON string listing all tasks.
    """
    start_db_monitor_if_needed()
    all_tasks = store.list_tasks()
    if not all_tasks:
        # Note the different shape: an object, not an empty list.
        return json.dumps({"message": "No tasks in the queue"}, indent=2)

    def _summary(entry):
        # Prompts longer than 60 characters are cut to 60 and suffixed with "...".
        short_prompt = entry.prompt if len(entry.prompt) <= 60 else entry.prompt[:60] + "..."
        return {
            "task_id": entry.task_id,
            "prompt": short_prompt,
            "task_type": entry.task_type,
            "status": entry.status.value,
            "progress": entry.progress,
            "step": entry.step,
            "created_at": entry.created_at,
        }

    return json.dumps([_summary(entry) for entry in all_tasks], indent=2)
316
+
317
# NOTE: FastMCP presumably publishes this docstring as the tool description,
# so its wording is kept unchanged.
@mcp.tool()
def cancel_task(task_id: str) -> str:
    """Request cancellation of a queued or running task.

    Args:
        task_id: The unique task ID to cancel.

    Returns:
        A JSON string confirming the cancellation status.
    """
    from collections import deque

    start_db_monitor_if_needed()
    task = store.get_task(task_id)
    if not task:
        return json.dumps({"error": f"Task {task_id} not found"}, indent=2)

    terminal = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)
    if task.status in terminal:
        return json.dumps({
            "task_id": task.task_id,
            "status": task.status.value,
            "message": f"Task is already in a terminal state ({task.status.value})."
        }, indent=2)

    # Cancel every descendant, not only direct children. get_subtasks() is
    # called again for each subtask so nested orchestrators are covered too.
    # Traversal is breadth-first, so direct children are handled first and in
    # their original order. `visited` stops infinite loops if parent links
    # ever form a cycle.
    visited = {task_id}
    pending = deque(store.get_subtasks(task_id))
    while pending:
        sub = pending.popleft()
        if sub.task_id in visited:
            continue
        visited.add(sub.task_id)
        if sub.status not in terminal:
            store.update_task(
                task_id=sub.task_id,
                status=TaskStatus.CANCELLED,
                progress=100,
                step="Cancelled by parent orchestrator cancellation request",
                error="Cancelled"
            )
            logger.info(f"Cancelled subtask {sub.task_id} of parent task {sub.parent_id or task_id}")
        pending.extend(store.get_subtasks(sub.task_id))

    # Cancel the main task
    store.update_task(
        task_id=task_id,
        status=TaskStatus.CANCELLED,
        progress=100,
        step="Cancelled by user request.",
        error="Cancelled"
    )
    logger.info(f"Cancelled task {task_id}")

    return json.dumps({
        "task_id": task_id,
        "status": "CANCELLED",
        "message": f"Task {task_id} cancellation requested. Active workloads will be aborted."
    }, indent=2)
367
+
368
# REST API Custom Routes
@mcp.custom_route("/api/tasks", methods=["POST"])
async def api_submit_task(request: Request) -> JSONResponse:
    """REST: POST /api/tasks submits a task from a JSON body.

    Expected body: ``{"prompt": str, "task_type": str}``. ``task_type`` is
    optional, and a missing or null value means "generic".

    Responds 400 when any of the following holds:
    - the body is not valid JSON, or is not a JSON object;
    - ``prompt`` is missing/empty or is not a string;
    - ``task_type`` is not a string.

    Otherwise returns the same payload as the ``submit_task`` MCP tool.
    """
    try:
        data = await request.json()
    except Exception:
        return JSONResponse({"error": "Invalid JSON payload"}, status_code=400)

    # A JSON array/number/string body would otherwise fail on `.get` with a 500.
    if not isinstance(data, dict):
        return JSONResponse({"error": "JSON payload must be an object"}, status_code=400)

    prompt = data.get("prompt")
    if not prompt:
        return JSONResponse({"error": "Missing required field: prompt"}, status_code=400)
    # A non-string prompt would be stored as-is and could later break
    # consumers such as list_tasks(), which calls len() on prompts.
    if not isinstance(prompt, str):
        return JSONResponse({"error": "Field 'prompt' must be a string"}, status_code=400)

    task_type = data.get("task_type")
    if task_type is None:
        task_type = "generic"
    elif not isinstance(task_type, str):
        return JSONResponse({"error": "Field 'task_type' must be a string"}, status_code=400)

    return JSONResponse(json.loads(submit_task(prompt, task_type)))
383
+
384
@mcp.custom_route("/api/tasks/{task_id}", methods=["GET"])
async def api_get_task_status(request: Request) -> JSONResponse:
    """REST: GET /api/tasks/{task_id} returns the ``get_task_status`` payload, or 404 if unknown."""
    wanted = request.path_params.get("task_id")
    if not store.get_task(wanted):
        return JSONResponse({"error": f"Task {wanted} not found"}, status_code=404)
    return JSONResponse(json.loads(get_task_status(wanted)))
392
+
393
@mcp.custom_route("/api/tasks/{task_id}/result", methods=["GET"])
async def api_get_task_result(request: Request) -> JSONResponse:
    """REST: GET /api/tasks/{task_id}/result returns the ``get_task_result`` payload, or 404 if unknown."""
    wanted = request.path_params.get("task_id")
    if not store.get_task(wanted):
        return JSONResponse({"error": f"Task {wanted} not found"}, status_code=404)
    return JSONResponse(json.loads(get_task_result(wanted)))
401
+
402
@mcp.custom_route("/api/tasks", methods=["GET"])
async def api_list_tasks(request: Request) -> JSONResponse:
    """REST: GET /api/tasks returns the same payload as the ``list_tasks`` MCP tool."""
    return JSONResponse(json.loads(list_tasks()))
406
+
407
@mcp.custom_route("/api/tasks/{task_id}/cancel", methods=["POST"])
async def api_cancel_task(request: Request) -> JSONResponse:
    """REST: POST /api/tasks/{task_id}/cancel returns the ``cancel_task`` payload, or 404 if unknown."""
    wanted = request.path_params.get("task_id")
    if not store.get_task(wanted):
        return JSONResponse({"error": f"Task {wanted} not found"}, status_code=404)
    return JSONResponse(json.loads(cancel_task(wanted)))
415
@mcp.custom_route("/dashboard", methods=["GET"])
async def serve_dashboard(request: Request) -> HTMLResponse:
    """Serve ``dashboard.html`` from the directory containing this module.

    If the file cannot be read, the error is logged with a traceback and a
    500 page is returned. Its message is HTML-escaped so exception text
    (for example, file paths) is never interpreted as markup.
    """
    import html

    current_dir = os.path.dirname(os.path.abspath(__file__))
    html_path = os.path.join(current_dir, "dashboard.html")
    try:
        with open(html_path, "r", encoding="utf-8") as f:
            html_content = f.read()
    except Exception as e:
        logger.exception(f"Error loading dashboard from {html_path}")
        return HTMLResponse(f"<h3>Error loading dashboard: {html.escape(str(e))}</h3>", status_code=500)
    return HTMLResponse(html_content)
425
+
426
def main():
    """Run the MCP server on the transport named by ``settings.transport``.

    ``"sse"`` selects the SSE transport; any other value falls back to stdio.
    """
    if settings.transport != "sse":
        logger.info("Starting AgyQueue MCP server via STDIO transport...")
        mcp.run(transport="stdio")
        return
    logger.info("Starting AgyQueue MCP server via SSE transport...")
    mcp.run(transport="sse")


if __name__ == "__main__":
    main()
438
+
agyqueue/models.py ADDED
@@ -0,0 +1,38 @@
1
+ import json
2
+ from dataclasses import dataclass, asdict, field
3
+ from datetime import datetime, timezone
4
+ from enum import Enum
5
+ from typing import Optional, Any
6
+
7
class TaskStatus(str, Enum):
    """Lifecycle state of a Task.

    Subclasses ``str`` so members compare equal to their plain string values
    (``TaskStatus.QUEUED == "QUEUED"``). Member order is the enum's
    iteration order.
    """

    QUEUED = "QUEUED"
    RUNNING = "RUNNING"
    WAITING = "WAITING"  # parent task parked while its subtasks run
    # Terminal states:
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


def _utcnow_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string (no UTC offset suffix)."""
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


@dataclass
class Task:
    """A unit of queued work plus its live progress and final outcome."""

    task_id: str
    prompt: str
    task_type: str
    status: TaskStatus = TaskStatus.QUEUED
    progress: int = 0  # percentage, 0-100
    step: str = "Queued"  # human-readable description of the current step
    result: Optional[str] = None
    error: Optional[str] = None
    parent_id: Optional[str] = None  # id of the orchestrator task for subtasks
    # Naive-UTC ISO timestamps, filled in at construction time.
    created_at: str = field(default_factory=_utcnow_iso)
    updated_at: str = field(default_factory=_utcnow_iso)

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a dict (via ``dataclasses.asdict``); ``status`` stays a TaskStatus."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Task":
        """Build a Task from a mapping such as ``to_dict()`` output.

        - A string ``status`` is converted to TaskStatus.
        - Keys that are not Task fields (for example, extra storage columns)
          are ignored instead of raising TypeError.
        - ``data`` itself is never modified.

        Raises:
            ValueError: ``status`` is a string that is not a valid TaskStatus.
            TypeError: a required field (task_id, prompt, task_type) is missing.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        d = {key: value for key, value in data.items() if key in known}
        if isinstance(d.get("status"), str):
            d["status"] = TaskStatus(d["status"])
        return cls(**d)
agyqueue/notifications.py ADDED
@@ -0,0 +1,187 @@
1
+ import os
2
+ import logging
3
+ import json
4
+ import urllib.request
5
+ import smtplib
6
+ from email.mime.text import MIMEText
7
+ from abc import ABC, abstractmethod
8
+ from typing import List, Optional
9
+
10
logger = logging.getLogger(__name__)


class BaseNotificationBackend(ABC):
    """Interface implemented by every notification channel.

    A backend delivers one task-status update to a single external destination.
    """

    @abstractmethod
    def send_notification(self, task_id: str, status: str, progress: int, step: str, result: Optional[str] = None, error: Optional[str] = None) -> None:
        """Deliver one status update for ``task_id``.

        Args:
            task_id: Id of the task that changed.
            status: Status value as a plain string (e.g. "COMPLETED").
            progress: Progress percentage.
            step: Human-readable description of the current step.
            result: Task output, if any.
            error: Failure message, if any.
        """
19
+
20
+
21
class SlackWebhookBackend(BaseNotificationBackend):
    """Post task updates to a Slack incoming webhook as one coloured attachment."""

    # status -> (emoji, attachment colour). Unlisted statuses use _DEFAULT_STYLE.
    _STYLE_BY_STATUS = {
        "COMPLETED": ("✅", "#2eb886"),
        "FAILED": ("❌", "#a30200"),
        "CANCELLED": ("⚠️", "#e0a115"),
        "RUNNING": ("🔄", "#1d9bd1"),
    }
    _DEFAULT_STYLE = ("ℹ️", "#36a64f")

    def __init__(self, webhook_url: str):
        self.webhook_url = webhook_url

    def _build_payload(self, task_id: str, status: str, progress: int, step: str, error: Optional[str]) -> dict:
        """Assemble the Slack ``attachments`` payload describing one update."""
        emoji, color = self._STYLE_BY_STATUS.get(status, self._DEFAULT_STYLE)
        fields = [
            {"title": "Task ID", "value": task_id, "short": True},
            {"title": "Status", "value": status, "short": True},
            {"title": "Progress", "value": f"{progress}%", "short": True},
            {"title": "Current Step", "value": step, "short": False},
        ]
        if error:
            fields.append({"title": "Error Message", "value": error, "short": False})
        attachment = {
            "color": color,
            "title": f"{emoji} AgyQueue Task {status}: {task_id}",
            "fields": fields,
            "fallback": f"AgyQueue Task {task_id} updated: {status} ({progress}%) - {step}",
        }
        return {"attachments": [attachment]}

    def send_notification(self, task_id: str, status: str, progress: int, step: str, result: Optional[str] = None, error: Optional[str] = None) -> None:
        """POST the update to the webhook; an empty ``webhook_url`` makes this a no-op.

        ``result`` is accepted for interface compatibility but is not sent.
        Delivery errors (network, HTTP status, encoding) are logged, never raised.
        """
        if not self.webhook_url:
            return

        payload = self._build_payload(task_id, status, progress, step, error)
        try:
            request = urllib.request.Request(
                self.webhook_url,
                data=json.dumps(payload).encode("utf-8"),
                headers={"Content-Type": "application/json"},
                method="POST",
            )
            with urllib.request.urlopen(request, timeout=5.0) as response:
                response.read()
            logger.info(f"[Slack Notification] Successfully sent update for task {task_id}")
        except Exception as exc:
            logger.error(f"[Slack Notification] Failed to send update for task {task_id}: {exc}")
83
+
84
+
85
class EmailSMTPBackend(BaseNotificationBackend):
    """Sends task updates via email using SMTP.

    ``email_to`` may hold one address or several separated by commas.

    When both ``smtp_user`` and ``smtp_pass`` are set, the connection is
    upgraded with STARTTLS before logging in; the server certificate and
    hostname are verified. Otherwise the mail is sent with neither TLS nor
    authentication.
    """

    def __init__(self, smtp_host: str, smtp_port: int, smtp_user: Optional[str], smtp_pass: Optional[str], email_from: str, email_to: str):
        self.smtp_host = smtp_host
        self.smtp_port = smtp_port
        self.smtp_user = smtp_user
        self.smtp_pass = smtp_pass
        self.email_from = email_from
        self.email_to = email_to

    def _recipients(self) -> List[str]:
        """Return the individual addresses in ``email_to`` (comma-separated), blanks dropped."""
        return [addr.strip() for addr in self.email_to.split(",") if addr.strip()]

    def send_notification(self, task_id: str, status: str, progress: int, step: str, result: Optional[str] = None, error: Optional[str] = None) -> None:
        """Email one status update; delivery errors are logged, never raised.

        Does nothing if ``smtp_host`` or ``email_to`` is empty. The body
        includes ``error`` when present, otherwise ``result`` for COMPLETED
        tasks.
        """
        import ssl

        if not self.smtp_host or not self.email_to:
            return
        recipients = self._recipients()
        if not recipients:
            # e.g. email_to == " , ": nothing deliverable.
            return

        subject = f"[AgyQueue] Task {status}: {task_id}"

        body_parts = [
            "AgyQueue Task Update",
            "----------------------------------------",
            f"Task ID: {task_id}",
            f"Status: {status}",
            f"Progress: {progress}%",
            f"Current Step: {step}",
            "----------------------------------------",
        ]

        if error:
            body_parts.append(f"Error details:\n{error}\n")
        elif result and status == "COMPLETED":
            body_parts.append(f"Execution Output:\n{result}\n")

        msg = MIMEText("\n".join(body_parts))
        msg["Subject"] = subject
        msg["From"] = self.email_from
        msg["To"] = ", ".join(recipients)

        try:
            with smtplib.SMTP(self.smtp_host, self.smtp_port, timeout=5.0) as server:
                if self.smtp_user and self.smtp_pass:
                    # starttls() without an explicit context does not verify the
                    # server certificate. That would expose the credentials sent
                    # by login() to a man-in-the-middle, so pass a verifying
                    # default context.
                    server.starttls(context=ssl.create_default_context())
                    server.login(self.smtp_user, self.smtp_pass)
                # Each address is a separate envelope recipient; passing the
                # raw comma-joined string would be treated as one address.
                server.sendmail(self.email_from, recipients, msg.as_string())
            logger.info(f"[Email Notification] Successfully sent update for task {task_id}")
        except Exception as e:
            logger.error(f"[Email Notification] Failed to send email for task {task_id}: {e}")
131
+
132
+
133
class NotificationManager:
    """Manages routing of task notifications across multiple active channels.

    Channels are enabled once, at construction, from the comma-separated
    ``AGYQUEUE_NOTIFICATIONS`` environment variable (e.g. ``"slack,email"``,
    case-insensitive).
    """

    def __init__(self):
        self.backends: List[BaseNotificationBackend] = []
        self._load_backends()

    def _load_backends(self) -> None:
        """Instantiate every backend named in ``AGYQUEUE_NOTIFICATIONS``."""
        # NOTE(review): `settings` is unused here. The import is kept in case
        # importing agyqueue.config has side effects (e.g. loading an env file)
        # that the os.environ lookups below rely on. Confirm before removing.
        from agyqueue.config import settings  # noqa: F401

        raw_channels = os.environ.get("AGYQUEUE_NOTIFICATIONS", "")
        active_channels = [c.strip() for c in raw_channels.lower().split(",") if c.strip()]

        if "slack" in active_channels:
            self._load_slack()
        if "email" in active_channels:
            self._load_email()

    def _load_slack(self) -> None:
        """Enable the Slack channel if SLACK_WEBHOOK_URL is set; otherwise log a warning."""
        webhook = os.environ.get("SLACK_WEBHOOK_URL")
        if webhook:
            self.backends.append(SlackWebhookBackend(webhook))
            logger.info("[Notification Manager] Slack channel enabled.")
        else:
            logger.warning("[Notification Manager] Slack enabled but SLACK_WEBHOOK_URL is missing.")

    def _load_email(self) -> None:
        """Enable the SMTP email channel if SMTP_HOST and SMTP_TO are set; otherwise log a warning."""
        smtp_host = os.environ.get("SMTP_HOST")
        email_to = os.environ.get("SMTP_TO")
        if not (smtp_host and email_to):
            logger.warning("[Notification Manager] Email enabled but SMTP_HOST or SMTP_TO is missing.")
            return

        raw_port = os.environ.get("SMTP_PORT", "587")
        try:
            port = int(raw_port)
        except ValueError:
            # Keep the 587 fallback, but make the misconfiguration visible.
            logger.warning(f"[Notification Manager] Invalid SMTP_PORT {raw_port!r}; falling back to 587.")
            port = 587

        self.backends.append(
            EmailSMTPBackend(
                smtp_host=smtp_host,
                smtp_port=port,
                smtp_user=os.environ.get("SMTP_USER"),
                smtp_pass=os.environ.get("SMTP_PASSWORD"),
                email_from=os.environ.get("SMTP_FROM", "noreply@agyqueue.internal"),
                email_to=email_to
            )
        )
        logger.info("[Notification Manager] Email SMTP channel enabled.")

    def trigger_notifications(self, task_id: str, status: str, progress: int, step: str, result: Optional[str] = None, error: Optional[str] = None) -> None:
        """Broadcasts updates to all active notification backends.

        Backends are isolated from each other: an exception from one is logged
        with its traceback, and the remaining backends still run.
        """
        for backend in self.backends:
            try:
                backend.send_notification(task_id, status, progress, step, result, error)
            except Exception as e:
                logger.exception(f"Error triggering notification on backend {backend.__class__.__name__}: {e}")


# Global notification manager instance
notifications = NotificationManager()