flowyml 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. flowyml/core/execution_status.py +1 -0
  2. flowyml/core/executor.py +175 -3
  3. flowyml/core/observability.py +7 -7
  4. flowyml/core/resources.py +12 -12
  5. flowyml/core/retry_policy.py +2 -2
  6. flowyml/core/scheduler.py +9 -9
  7. flowyml/core/scheduler_config.py +2 -3
  8. flowyml/core/submission_result.py +4 -4
  9. flowyml/stacks/bridge.py +9 -9
  10. flowyml/stacks/plugins.py +2 -2
  11. flowyml/stacks/registry.py +21 -0
  12. flowyml/storage/materializers/base.py +33 -0
  13. flowyml/storage/metadata.py +3 -1042
  14. flowyml/storage/remote.py +590 -0
  15. flowyml/storage/sql.py +951 -0
  16. flowyml/ui/backend/dependencies.py +28 -0
  17. flowyml/ui/backend/main.py +4 -79
  18. flowyml/ui/backend/routers/assets.py +170 -9
  19. flowyml/ui/backend/routers/client.py +6 -6
  20. flowyml/ui/backend/routers/execution.py +2 -2
  21. flowyml/ui/backend/routers/experiments.py +53 -6
  22. flowyml/ui/backend/routers/metrics.py +23 -68
  23. flowyml/ui/backend/routers/pipelines.py +19 -10
  24. flowyml/ui/backend/routers/runs.py +287 -9
  25. flowyml/ui/backend/routers/schedules.py +5 -21
  26. flowyml/ui/backend/routers/stats.py +14 -0
  27. flowyml/ui/backend/routers/traces.py +37 -53
  28. flowyml/ui/backend/routers/websocket.py +121 -0
  29. flowyml/ui/frontend/dist/assets/index-CBUXOWze.css +1 -0
  30. flowyml/ui/frontend/dist/assets/index-DF8dJaFL.js +629 -0
  31. flowyml/ui/frontend/dist/index.html +2 -2
  32. flowyml/ui/frontend/package-lock.json +289 -0
  33. flowyml/ui/frontend/package.json +1 -0
  34. flowyml/ui/frontend/src/app/compare/page.jsx +213 -0
  35. flowyml/ui/frontend/src/app/experiments/compare/page.jsx +289 -0
  36. flowyml/ui/frontend/src/app/experiments/page.jsx +61 -1
  37. flowyml/ui/frontend/src/app/runs/[runId]/page.jsx +418 -203
  38. flowyml/ui/frontend/src/app/runs/page.jsx +64 -3
  39. flowyml/ui/frontend/src/app/settings/page.jsx +1 -1
  40. flowyml/ui/frontend/src/app/tokens/page.jsx +8 -6
  41. flowyml/ui/frontend/src/components/ArtifactViewer.jsx +159 -0
  42. flowyml/ui/frontend/src/components/NavigationTree.jsx +26 -9
  43. flowyml/ui/frontend/src/components/PipelineGraph.jsx +26 -24
  44. flowyml/ui/frontend/src/components/RunDetailsPanel.jsx +42 -14
  45. flowyml/ui/frontend/src/router/index.jsx +4 -0
  46. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/METADATA +3 -1
  47. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/RECORD +50 -42
  48. flowyml/ui/frontend/dist/assets/index-DcYwrn2j.css +0 -1
  49. flowyml/ui/frontend/dist/assets/index-Dlz_ygOL.js +0 -592
  50. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/WHEEL +0 -0
  51. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/entry_points.txt +0 -0
  52. {flowyml-1.3.0.dist-info → flowyml-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,19 +2,15 @@ from fastapi import APIRouter, HTTPException
2
2
  from pydantic import BaseModel
3
3
  from flowyml.storage.metadata import SQLiteMetadataStore
4
4
  from flowyml.core.project import ProjectManager
5
- from typing import Optional
6
5
  import json
6
+ from flowyml.ui.backend.dependencies import get_store
7
7
 
8
8
  router = APIRouter()
9
9
 
10
10
 
11
- def get_store():
12
- return SQLiteMetadataStore()
13
-
14
-
15
11
  def _iter_metadata_stores():
16
12
  """Yield tuples of (project_name, store) including global and project stores."""
17
- stores: list[tuple[Optional[str], SQLiteMetadataStore]] = [(None, SQLiteMetadataStore())]
13
+ stores: list[tuple[str | None, SQLiteMetadataStore]] = [(None, get_store())]
18
14
  try:
19
15
  manager = ProjectManager()
20
16
  for project_meta in manager.list_projects():
@@ -50,16 +46,43 @@ def _sort_runs(runs):
50
46
 
51
47
 
52
48
  @router.get("/")
53
- async def list_runs(limit: int = 20, project: str = None):
54
- """List all runs, optionally filtered by project."""
49
+ async def list_runs(
50
+ limit: int = 20,
51
+ project: str = None,
52
+ pipeline_name: str = None,
53
+ status: str = None,
54
+ ):
55
+ """List all runs, optionally filtered by project, pipeline_name, and status."""
55
56
  try:
56
57
  combined = []
57
58
  for project_name, store in _iter_metadata_stores():
58
59
  # Skip other projects if filtering by project name
59
60
  if project and project_name and project != project_name:
60
61
  continue
61
- store_runs = store.list_runs(limit=limit)
62
+
63
+ # Use store's query method if available for better performance, or list_runs
64
+ # SQLMetadataStore has query method.
65
+ if hasattr(store, "query"):
66
+ filters = {}
67
+ if pipeline_name:
68
+ filters["pipeline_name"] = pipeline_name
69
+ if status:
70
+ filters["status"] = status
71
+
72
+ # We can't pass limit to query easily if it doesn't support it,
73
+ # but SQLMetadataStore.query usually returns all matching.
74
+ # We'll slice later.
75
+ store_runs = store.query(**filters)
76
+ else:
77
+ store_runs = store.list_runs(limit=limit)
78
+
62
79
  for run in store_runs:
80
+ # Apply filters if store didn't (e.g. if we used list_runs or store doesn't support query)
81
+ if pipeline_name and run.get("pipeline_name") != pipeline_name:
82
+ continue
83
+ if status and run.get("status") != status:
84
+ continue
85
+
63
86
  combined.append((run, project_name))
64
87
 
65
88
  runs = _deduplicate_runs(combined)
@@ -73,12 +96,74 @@ async def list_runs(limit: int = 20, project: str = None):
73
96
  return {"runs": [], "error": str(e)}
74
97
 
75
98
 
99
class RunCreate(BaseModel):
    """Request payload for creating or updating a pipeline run."""

    run_id: str  # Unique identifier of the run
    pipeline_name: str  # Pipeline this run belongs to
    status: str = "pending"  # Lifecycle state; defaults to pending
    start_time: str
    end_time: str | None = None
    duration: float | None = None  # Wall-clock duration (presumably seconds — confirm with executor)
    metadata: dict = {}  # Free-form extras; pydantic copies mutable defaults per instance
    project: str | None = None
    metrics: dict | None = None
    parameters: dict | None = None
110
+
111
+
112
@router.post("/")
async def create_run(run: RunCreate):
    """Create or update a run."""
    try:
        store = get_store()

        # Fold the core run fields into a copy of the caller-supplied metadata.
        metadata = {
            **run.metadata,
            "pipeline_name": run.pipeline_name,
            "status": run.status,
            "start_time": run.start_time,
            "end_time": run.end_time,
            "duration": run.duration,
            "project": run.project,
        }

        # Optional payloads are recorded only when provided and non-empty.
        for key, value in (("metrics", run.metrics), ("parameters", run.parameters)):
            if value:
                metadata[key] = value

        store.save_run(run.run_id, metadata)
        return {"status": "success", "run_id": run.run_id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
141
+
142
+
76
143
@router.get("/{run_id}")
async def get_run(run_id: str):
    """Get details for a specific run."""
    run, _ = _find_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    steps = run.get("steps", {})

    # Steps that stopped heartbeating while "running" are surfaced as dead.
    for name in _get_dead_steps(run_id):
        step = steps.get(name)
        if step is not None and step.get("status") == "running":
            step["status"] = "dead"
            step["success"] = False

    # Attach the most recent heartbeat timestamp to each tracked step.
    with _heartbeat_lock:
        for name, ts in _heartbeat_timestamps.get(run_id, {}).items():
            if name in steps:
                steps[name]["last_heartbeat"] = ts

    return run
83
168
 
84
169
 
@@ -206,3 +291,196 @@ async def get_cloud_status(run_id: str):
206
291
  "cloud_status": cloud_status,
207
292
  "cloud_error": cloud_error,
208
293
  }
294
+
295
+
296
class HeartbeatRequest(BaseModel):
    """Heartbeat payload sent by a running step."""

    step_name: str  # Step reporting the heartbeat
    status: str = "running"  # Reported step status; defaults to running
299
+
300
+
301
import threading
import time

# In-memory storage for heartbeat timestamps
# Format: {run_id: {step_name: last_heartbeat_timestamp}}
_heartbeat_timestamps: dict[str, dict[str, float]] = {}
# Plain import instead of the original __import__("threading") indirection.
_heartbeat_lock = threading.Lock()

# Heartbeat interval in seconds (should match executor's interval)
HEARTBEAT_INTERVAL = 5
# Number of missed heartbeats before marking step as dead
DEAD_THRESHOLD = 3


def _record_heartbeat(run_id: str, step_name: str) -> None:
    """Record the current wall-clock time as the latest heartbeat for a step."""
    with _heartbeat_lock:
        _heartbeat_timestamps.setdefault(run_id, {})[step_name] = time.time()


def _get_dead_steps(run_id: str) -> list[str]:
    """Return steps of *run_id* whose last heartbeat is older than the timeout.

    The timeout is HEARTBEAT_INTERVAL * DEAD_THRESHOLD seconds; a run with no
    recorded heartbeats yields an empty list.
    """
    timeout = HEARTBEAT_INTERVAL * DEAD_THRESHOLD
    now = time.time()
    with _heartbeat_lock:
        beats = _heartbeat_timestamps.get(run_id)
        if not beats:
            return []
        return [step for step, last in beats.items() if now - last > timeout]


def _cleanup_heartbeats(run_id: str) -> None:
    """Remove heartbeat tracking for a completed run."""
    with _heartbeat_lock:
        _heartbeat_timestamps.pop(run_id, None)
345
+
346
+
347
@router.post("/{run_id}/steps/{step_name}/heartbeat")
async def step_heartbeat(run_id: str, step_name: str, heartbeat: HeartbeatRequest):
    """Receive heartbeat from a running step.

    Returns:
        dict: Instructions for the step (e.g., {"action": "continue"} or {"action": "stop"})
    """
    store = _find_store_for_run(run_id)
    _record_heartbeat(run_id, step_name)

    run = store.load_run(run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    # A run heading toward shutdown tells the step to stop at the next beat.
    halting = {"stopping", "stopped", "cancelled", "cancelling"}
    action = "stop" if run.get("status") in halting else "continue"
    return {"action": action}
369
+
370
+
371
@router.get("/{run_id}/dead-steps")
async def get_dead_steps(run_id: str):
    """Get list of steps that appear to be dead (missed heartbeats)."""
    return {"dead_steps": _get_dead_steps(run_id)}
376
+
377
+
378
@router.post("/{run_id}/stop")
async def stop_run(run_id: str):
    """Signal a run to stop."""
    store = _find_store_for_run(run_id)

    try:
        # The "stopping" status is observed by the next step heartbeat.
        store.update_run_status(run_id, "stopping")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "message": "Stop signal sent"}
390
+
391
+
392
class LogChunk(BaseModel):
    """A chunk of log output emitted by a running step."""

    content: str  # Raw log text
    level: str = "INFO"  # Log level label
    timestamp: str | None = None  # Optional timestamp string supplied by the sender
396
+
397
+
398
@router.post("/{run_id}/steps/{step_name}/logs")
async def post_step_logs(run_id: str, step_name: str, log_chunk: LogChunk):
    """Receive log chunk from a running step."""
    import anyio

    from flowyml.utils.config import get_config

    # Logs live under <runs_dir>/<run_id>/logs/<step_name>.log
    log_dir = get_config().runs_dir / run_id / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{step_name}.log"

    timestamp = log_chunk.timestamp or ""
    line = f"[{timestamp}] [{log_chunk.level}] {log_chunk.content}\n"

    def write_log():
        with open(log_file, "a") as f:
            f.write(line)

    # File I/O is pushed to a worker thread to keep the event loop responsive.
    await anyio.to_thread.run_sync(write_log)

    # Best-effort fan-out to any connected WebSocket subscribers.
    try:
        from flowyml.ui.backend.routers.websocket import manager

        await manager.broadcast_log(run_id, step_name, log_chunk.content)
    except Exception:
        pass  # Ignore WebSocket broadcast failures

    return {"status": "success"}
431
+
432
+
433
@router.get("/{run_id}/steps/{step_name}/logs")
async def get_step_logs(run_id: str, step_name: str, offset: int = 0):
    """Get logs for a specific step.

    Args:
        run_id: Run whose step logs are requested.
        step_name: Step whose log file is read.
        offset: Character offset already consumed by the caller; only content
            past this point is returned.

    Returns:
        dict with the new log content, the next offset to poll from, and a
        has_more flag (currently always False: all available content is sent).
    """
    import anyio

    from flowyml.utils.config import get_config

    runs_dir = get_config().runs_dir
    log_file = runs_dir / run_id / "logs" / f"{step_name}.log"

    if not log_file.exists():
        return {"logs": "", "offset": 0, "has_more": False}

    def read_log():
        with open(log_file) as f:
            return f.read()

    content = await anyio.to_thread.run_sync(read_log)

    # Slice off what the caller has already seen. A plain slice also covers
    # offset >= len(content) (yields ""); the previous bounds check skipped
    # the slice in that case, re-sending the full log and inflating the
    # returned offset.
    if offset > 0:
        content = content[offset:]

    return {
        "logs": content,
        "offset": offset + len(content),
        "has_more": False,  # For now, always return all available
    }
461
+
462
+
463
@router.get("/{run_id}/logs")
async def get_run_logs(run_id: str):
    """Get all logs for a run."""
    import anyio

    from flowyml.utils.config import get_config

    log_dir = get_config().runs_dir / run_id / "logs"

    if not log_dir.exists():
        return {"logs": {}}

    def read_all_logs():
        # Map step name (file stem) -> full log text.
        return {f.stem: f.read_text() for f in log_dir.glob("*.log")}

    logs = await anyio.to_thread.run_sync(read_all_logs)
    return {"logs": logs}
@@ -71,16 +71,13 @@ async def create_schedule(schedule: ScheduleRequest):
71
71
  else:
72
72
  # Check if it's a historical pipeline (in metadata but not registry)
73
73
  # This means we can't run it because we don't have the code loaded
74
- from flowyml.storage.metadata import SQLiteMetadataStore
74
+ from flowyml.ui.backend.dependencies import get_store
75
75
 
76
- store = SQLiteMetadataStore()
76
+ store = get_store()
77
77
  pipelines = store.list_pipelines()
78
78
 
79
79
  if schedule.pipeline_name in pipelines:
80
80
  # Try to load pipeline definition
81
- from flowyml.storage.metadata import SQLiteMetadataStore
82
-
83
- store = SQLiteMetadataStore()
84
81
  pipeline_def = store.get_pipeline_definition(schedule.pipeline_name)
85
82
 
86
83
  if pipeline_def:
@@ -188,7 +185,6 @@ async def get_schedule_history(schedule_name: str, limit: int = 50):
188
185
  async def list_registered_pipelines(project: str = None):
189
186
  """List all pipelines available for scheduling."""
190
187
  from flowyml.core.templates import list_templates
191
- from flowyml.storage.metadata import SQLiteMetadataStore
192
188
 
193
189
  registered = pipeline_registry.list_pipelines()
194
190
  templates = list_templates()
@@ -196,22 +192,10 @@ async def list_registered_pipelines(project: str = None):
196
192
  # Also get pipelines from metadata store (historical runs)
197
193
  metadata_pipelines = []
198
194
  try:
199
- store = SQLiteMetadataStore()
200
- import sqlite3
201
-
202
- conn = sqlite3.connect(store.db_path)
203
- cursor = conn.cursor()
204
-
205
- if project:
206
- cursor.execute(
207
- "SELECT DISTINCT pipeline_name FROM runs WHERE project = ? ORDER BY pipeline_name",
208
- (project,),
209
- )
210
- else:
211
- cursor.execute("SELECT DISTINCT pipeline_name FROM runs ORDER BY pipeline_name")
195
+ from flowyml.ui.backend.dependencies import get_store
212
196
 
213
- metadata_pipelines = [row[0] for row in cursor.fetchall()]
214
- conn.close()
197
+ store = get_store()
198
+ metadata_pipelines = store.list_pipelines(project=project)
215
199
  except Exception as e:
216
200
  print(f"Failed to fetch pipelines from metadata store: {e}")
217
201
 
@@ -0,0 +1,14 @@
1
+ from fastapi import APIRouter, HTTPException
2
+ from flowyml.ui.backend.dependencies import get_store
3
+
4
+ router = APIRouter()
5
+
6
+
7
@router.get("/")
async def get_global_stats(project: str | None = None):
    """Get global statistics."""
    try:
        # Delegate entirely to the metadata store's aggregate query.
        return get_store().get_statistics(project=project)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -1,7 +1,6 @@
1
1
  from fastapi import APIRouter, HTTPException
2
- from flowyml.storage.metadata import SQLiteMetadataStore
3
- import contextlib
4
- import builtins
2
+ from pydantic import BaseModel
3
+ from flowyml.ui.backend.dependencies import get_store
5
4
 
6
5
  router = APIRouter()
7
6
 
@@ -14,61 +13,14 @@ async def list_traces(
14
13
  project: str | None = None,
15
14
  ):
16
15
  """List traces, optionally filtered by project."""
17
- store = SQLiteMetadataStore()
18
-
19
- # We need to implement list_traces in metadata store or query manually
20
- # For now, let's query manually via sqlite
21
- import sqlite3
22
-
23
- conn = sqlite3.connect(store.db_path)
24
- cursor = conn.cursor()
25
-
26
- query = "SELECT * FROM traces"
27
- params = []
28
- conditions = []
29
-
30
- if trace_id:
31
- conditions.append("trace_id = ?")
32
- params.append(trace_id)
33
-
34
- if event_type:
35
- conditions.append("event_type = ?")
36
- params.append(event_type)
37
-
38
- if project:
39
- conditions.append("project = ?")
40
- params.append(project)
41
-
42
- if conditions:
43
- query += " WHERE " + " AND ".join(conditions)
44
-
45
- query += " ORDER BY start_time DESC LIMIT ?"
46
- params.append(limit)
47
-
48
- cursor.execute(query, params)
49
- columns = [description[0] for description in cursor.description]
50
- rows = cursor.fetchall()
51
-
52
- traces = []
53
- import json
54
-
55
- for row in rows:
56
- trace = dict(zip(columns, row, strict=False))
57
- # Parse JSON fields
58
- for field in ["inputs", "outputs", "metadata"]:
59
- if trace[field]:
60
- with contextlib.suppress(builtins.BaseException):
61
- trace[field] = json.loads(trace[field])
62
- traces.append(trace)
63
-
64
- conn.close()
65
- return traces
16
+ store = get_store()
17
+ return store.list_traces(limit=limit, trace_id=trace_id, event_type=event_type, project=project)
66
18
 
67
19
 
68
20
  @router.get("/{trace_id}")
69
21
  async def get_trace(trace_id: str):
70
22
  """Get a specific trace tree."""
71
- store = SQLiteMetadataStore()
23
+ store = get_store()
72
24
  events = store.get_trace(trace_id)
73
25
  if not events:
74
26
  raise HTTPException(status_code=404, detail="Trace not found")
@@ -82,3 +34,35 @@ async def get_trace(trace_id: str):
82
34
  return event
83
35
 
84
36
  return [build_tree(root) for root in root_events]
37
+
38
+
39
class TraceEventCreate(BaseModel):
    """Request payload describing a single trace event to store."""

    # Identity and position in the trace tree
    event_id: str
    trace_id: str
    parent_id: str | None = None
    event_type: str
    name: str
    # Captured call data
    inputs: dict | None = None
    outputs: dict | None = None
    # Timing
    start_time: float | None = None
    end_time: float | None = None
    duration: float | None = None
    # Outcome
    status: str | None = None
    error: str | None = None
    metadata: dict | None = None
    # Token and cost accounting (presumably LLM usage — confirm with tracer)
    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None
    cost: float | None = None
    model: str | None = None
+
59
+
60
@router.post("/")
async def create_trace_event(event: TraceEventCreate):
    """Create or update a trace event."""
    try:
        # NOTE(review): .dict() is the pydantic v1 spelling; model_dump() is
        # the v2 equivalent — confirm which pydantic major version is pinned.
        get_store().save_trace_event(event.dict())
        return {"status": "success", "event_id": event.event_id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -0,0 +1,121 @@
1
+ """WebSocket router for real-time log streaming."""
2
+
3
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
4
+ from datetime import datetime
5
+ import asyncio
6
+ import contextlib
7
+
8
+ router = APIRouter()
9
+
10
+
11
class ConnectionManager:
    """Manage WebSocket connections for log streaming."""

    def __init__(self):
        # Format: {run_id: {step_name: [websocket, ...]}}
        self.active_connections: dict[str, dict[str, list[WebSocket]]] = {}
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, run_id: str, step_name: str = "__all__"):
        """Accept and track a WebSocket connection."""
        await websocket.accept()
        async with self._lock:
            per_run = self.active_connections.setdefault(run_id, {})
            per_run.setdefault(step_name, []).append(websocket)

    async def disconnect(self, websocket: WebSocket, run_id: str, step_name: str = "__all__"):
        """Remove a WebSocket connection."""
        async with self._lock:
            per_run = self.active_connections.get(run_id)
            if per_run is None:
                return
            sockets = per_run.get(step_name)
            if sockets is None:
                return
            with contextlib.suppress(ValueError):
                sockets.remove(websocket)
            # Drop empty containers so the registry does not grow unboundedly.
            if not sockets:
                del per_run[step_name]
            if not per_run:
                del self.active_connections[run_id]

    async def broadcast_log(self, run_id: str, step_name: str, log_content: str):
        """Broadcast log content to all connected clients for a run/step."""
        async with self._lock:
            per_run = self.active_connections.get(run_id, {})
            # Step-specific subscribers first, then "__all__" subscribers.
            targets = list(per_run.get(step_name, [])) + list(per_run.get("__all__", []))

        # Send to all relevant connections (outside lock to avoid blocking)
        message = {
            "type": "log",
            "step": step_name,
            "content": log_content,
            "timestamp": datetime.now().isoformat(),
        }
        for ws in targets:
            with contextlib.suppress(Exception):
                await ws.send_json(message)


# Global connection manager instance
manager = ConnectionManager()
68
+
69
+
70
@router.websocket("/ws/runs/{run_id}/logs")
async def websocket_logs(websocket: WebSocket, run_id: str, step_name: str = "__all__"):
    """WebSocket endpoint for streaming logs.

    Query params:
        step_name: Optional. Subscribe to specific step logs only.

    Messages sent:
        {"type": "log", "step": "step_name", "content": "...", "timestamp": "..."}
        {"type": "dead_steps", "steps": ["step1", "step2"]}
    """
    await manager.connect(websocket, run_id, step_name)
    try:
        while True:
            try:
                # Block for up to 30s waiting on a client message.
                incoming = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
                # Only the ping/pong keep-alive command is recognized today.
                if incoming == "ping":
                    await websocket.send_text("pong")
            except asyncio.TimeoutError:
                # No traffic: emit a heartbeat; a failed send means the peer is gone.
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    break
    except WebSocketDisconnect:
        pass
    finally:
        await manager.disconnect(websocket, run_id, step_name)
101
+
102
+
103
@router.websocket("/ws/runs/{run_id}/steps/{step_name}/logs")
async def websocket_step_logs(websocket: WebSocket, run_id: str, step_name: str):
    """WebSocket endpoint for streaming logs of a specific step."""
    await manager.connect(websocket, run_id, step_name)
    try:
        while True:
            try:
                incoming = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
                # Respond to the client's keep-alive probe.
                if incoming == "ping":
                    await websocket.send_text("pong")
            except asyncio.TimeoutError:
                # Idle for 30s: send a heartbeat; a failed send means the peer is gone.
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    break
    except WebSocketDisconnect:
        pass
    finally:
        await manager.disconnect(websocket, run_id, step_name)