aegra-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. aegra_api/__init__.py +3 -0
  2. aegra_api/api/__init__.py +1 -0
  3. aegra_api/api/assistants.py +235 -0
  4. aegra_api/api/runs.py +1110 -0
  5. aegra_api/api/store.py +200 -0
  6. aegra_api/api/threads.py +761 -0
  7. aegra_api/config.py +204 -0
  8. aegra_api/constants.py +5 -0
  9. aegra_api/core/__init__.py +0 -0
  10. aegra_api/core/app_loader.py +91 -0
  11. aegra_api/core/auth_ctx.py +65 -0
  12. aegra_api/core/auth_deps.py +186 -0
  13. aegra_api/core/auth_handlers.py +248 -0
  14. aegra_api/core/auth_middleware.py +331 -0
  15. aegra_api/core/database.py +123 -0
  16. aegra_api/core/health.py +131 -0
  17. aegra_api/core/orm.py +165 -0
  18. aegra_api/core/route_merger.py +69 -0
  19. aegra_api/core/serializers/__init__.py +7 -0
  20. aegra_api/core/serializers/base.py +22 -0
  21. aegra_api/core/serializers/general.py +54 -0
  22. aegra_api/core/serializers/langgraph.py +102 -0
  23. aegra_api/core/sse.py +178 -0
  24. aegra_api/main.py +303 -0
  25. aegra_api/middleware/__init__.py +4 -0
  26. aegra_api/middleware/double_encoded_json.py +74 -0
  27. aegra_api/middleware/logger_middleware.py +95 -0
  28. aegra_api/models/__init__.py +76 -0
  29. aegra_api/models/assistants.py +81 -0
  30. aegra_api/models/auth.py +62 -0
  31. aegra_api/models/enums.py +29 -0
  32. aegra_api/models/errors.py +29 -0
  33. aegra_api/models/runs.py +124 -0
  34. aegra_api/models/store.py +67 -0
  35. aegra_api/models/threads.py +152 -0
  36. aegra_api/observability/__init__.py +1 -0
  37. aegra_api/observability/base.py +88 -0
  38. aegra_api/observability/otel.py +133 -0
  39. aegra_api/observability/setup.py +27 -0
  40. aegra_api/observability/targets/__init__.py +11 -0
  41. aegra_api/observability/targets/base.py +18 -0
  42. aegra_api/observability/targets/langfuse.py +33 -0
  43. aegra_api/observability/targets/otlp.py +38 -0
  44. aegra_api/observability/targets/phoenix.py +24 -0
  45. aegra_api/services/__init__.py +0 -0
  46. aegra_api/services/assistant_service.py +569 -0
  47. aegra_api/services/base_broker.py +59 -0
  48. aegra_api/services/broker.py +141 -0
  49. aegra_api/services/event_converter.py +157 -0
  50. aegra_api/services/event_store.py +196 -0
  51. aegra_api/services/graph_streaming.py +433 -0
  52. aegra_api/services/langgraph_service.py +456 -0
  53. aegra_api/services/streaming_service.py +362 -0
  54. aegra_api/services/thread_state_service.py +128 -0
  55. aegra_api/settings.py +124 -0
  56. aegra_api/utils/__init__.py +3 -0
  57. aegra_api/utils/assistants.py +23 -0
  58. aegra_api/utils/run_utils.py +60 -0
  59. aegra_api/utils/setup_logging.py +122 -0
  60. aegra_api/utils/sse_utils.py +26 -0
  61. aegra_api/utils/status_compat.py +57 -0
  62. aegra_api-0.1.0.dist-info/METADATA +244 -0
  63. aegra_api-0.1.0.dist-info/RECORD +64 -0
  64. aegra_api-0.1.0.dist-info/WHEEL +4 -0
aegra_api/api/runs.py ADDED
@@ -0,0 +1,1110 @@
1
+ """Run endpoints for Agent Protocol"""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ from collections.abc import AsyncIterator
6
+ from datetime import UTC, datetime
7
+ from typing import Any
8
+ from uuid import uuid4
9
+
10
+ import structlog
11
+ from fastapi import APIRouter, Depends, Header, HTTPException, Query
12
+ from fastapi.responses import StreamingResponse
13
+ from langgraph.types import Command, Send
14
+ from sqlalchemy import delete, select, update
15
+ from sqlalchemy.ext.asyncio import AsyncSession
16
+
17
+ from aegra_api.core.auth_ctx import with_auth_ctx
18
+ from aegra_api.core.auth_deps import get_current_user
19
+ from aegra_api.core.auth_handlers import build_auth_context, handle_event
20
+ from aegra_api.core.orm import Assistant as AssistantORM
21
+ from aegra_api.core.orm import Run as RunORM
22
+ from aegra_api.core.orm import Thread as ThreadORM
23
+ from aegra_api.core.orm import _get_session_maker, get_session
24
+ from aegra_api.core.serializers import GeneralSerializer
25
+ from aegra_api.core.sse import create_end_event, get_sse_headers
26
+ from aegra_api.models import Run, RunCreate, RunStatus, User
27
+ from aegra_api.services.broker import broker_manager
28
+ from aegra_api.services.graph_streaming import stream_graph_events
29
+ from aegra_api.services.langgraph_service import create_run_config, get_langgraph_service
30
+ from aegra_api.services.streaming_service import streaming_service
31
+ from aegra_api.utils.assistants import resolve_assistant_id
32
+ from aegra_api.utils.run_utils import (
33
+ _merge_jsonb,
34
+ )
35
+ from aegra_api.utils.status_compat import validate_run_status
36
+
37
+ router = APIRouter(tags=["Runs"])
38
+
39
+ logger = structlog.getLogger(__name__)
40
+ serializer = GeneralSerializer()
41
+
42
+
43
+ # NOTE: We keep only an in-memory task registry for asyncio.Task handles.
44
+ # All run metadata/state is persisted via ORM.
45
+ active_runs: dict[str, asyncio.Task] = {}
46
+
47
+ # Default stream modes for background run execution
48
+ DEFAULT_STREAM_MODES = ["values"]
49
+
50
+
51
+ def map_command_to_langgraph(cmd: dict[str, Any]) -> Command:
52
+ """Convert API command to LangGraph Command"""
53
+ goto = cmd.get("goto")
54
+ if goto is not None and not isinstance(goto, list):
55
+ goto = [goto]
56
+
57
+ update = cmd.get("update")
58
+ if isinstance(update, (tuple, list)) and all(
59
+ isinstance(t, (tuple, list)) and len(t) == 2 and isinstance(t[0], str) for t in update
60
+ ):
61
+ update = [tuple(t) for t in update]
62
+
63
+ return Command(
64
+ update=update,
65
+ goto=([it if isinstance(it, str) else Send(it["node"], it["input"]) for it in goto] if goto else None),
66
+ resume=cmd.get("resume"),
67
+ )
68
+
69
+
70
+ async def set_thread_status(session: AsyncSession, thread_id: str, status: str) -> None:
71
+ """Update the status column of a thread.
72
+
73
+ Status is validated to ensure it conforms to API specification.
74
+ """
75
+ # Validate status conforms to API specification
76
+ from aegra_api.utils.status_compat import validate_thread_status
77
+
78
+ validated_status = validate_thread_status(status)
79
+ result = await session.execute(
80
+ update(ThreadORM)
81
+ .where(ThreadORM.thread_id == thread_id)
82
+ .values(status=validated_status, updated_at=datetime.now(UTC))
83
+ )
84
+ await session.commit()
85
+
86
+ # Verify thread was updated (matching row exists)
87
+ if result.rowcount == 0:
88
+ raise HTTPException(404, f"Thread '{thread_id}' not found")
89
+
90
+
91
+ async def update_thread_metadata(
92
+ session: AsyncSession,
93
+ thread_id: str,
94
+ assistant_id: str,
95
+ graph_id: str,
96
+ user_id: str | None = None,
97
+ ) -> None:
98
+ """Update thread metadata with assistant and graph information (dialect agnostic).
99
+
100
+ If thread doesn't exist, auto-creates it.
101
+ """
102
+ # Read-modify-write to avoid DB-specific JSON concat operators
103
+ thread = await session.scalar(select(ThreadORM).where(ThreadORM.thread_id == thread_id))
104
+
105
+ if not thread:
106
+ # Auto-create thread if it doesn't exist
107
+ if not user_id:
108
+ raise HTTPException(400, "Cannot auto-create thread: user_id is required")
109
+
110
+ metadata = {
111
+ "owner": user_id,
112
+ "assistant_id": str(assistant_id),
113
+ "graph_id": graph_id,
114
+ "thread_name": "",
115
+ }
116
+
117
+ thread_orm = ThreadORM(
118
+ thread_id=thread_id,
119
+ status="idle",
120
+ metadata_json=metadata,
121
+ user_id=user_id,
122
+ )
123
+ session.add(thread_orm)
124
+ await session.commit()
125
+ return
126
+
127
+ md = dict(getattr(thread, "metadata_json", {}) or {})
128
+ md.update(
129
+ {
130
+ "assistant_id": str(assistant_id),
131
+ "graph_id": graph_id,
132
+ }
133
+ )
134
+ await session.execute(
135
+ update(ThreadORM).where(ThreadORM.thread_id == thread_id).values(metadata_json=md, updated_at=datetime.now(UTC))
136
+ )
137
+ await session.commit()
138
+
139
+
140
+ async def _validate_resume_command(session: AsyncSession, thread_id: str, command: dict[str, Any] | None) -> None:
141
+ """Validate resume command requirements."""
142
+ if command and command.get("resume") is not None:
143
+ # Check if thread exists and is in interrupted state
144
+ thread_stmt = select(ThreadORM).where(ThreadORM.thread_id == thread_id)
145
+ thread = await session.scalar(thread_stmt)
146
+ if not thread:
147
+ raise HTTPException(404, f"Thread '{thread_id}' not found")
148
+ if thread.status != "interrupted":
149
+ raise HTTPException(400, "Cannot resume: thread is not in interrupted state")
150
+
151
+
152
+ @router.post("/threads/{thread_id}/runs", response_model=Run)
153
+ async def create_run(
154
+ thread_id: str,
155
+ request: RunCreate,
156
+ user: User = Depends(get_current_user),
157
+ session: AsyncSession = Depends(get_session),
158
+ ) -> Run:
159
+ """Create and execute a new run (persisted)."""
160
+ # Authorization check (create_run action on threads resource)
161
+ ctx = build_auth_context(user, "threads", "create_run")
162
+ value = {**request.model_dump(), "thread_id": thread_id}
163
+ filters = await handle_event(ctx, value)
164
+
165
+ # If handler modified config/context, update request
166
+ if filters:
167
+ if "config" in filters:
168
+ request.config = {**(request.config or {}), **filters["config"]}
169
+ if "context" in filters:
170
+ request.context = {**(request.context or {}), **filters["context"]}
171
+ elif value.get("config"):
172
+ request.config = {**(request.config or {}), **value["config"]}
173
+ elif value.get("context"):
174
+ request.context = {**(request.context or {}), **value["context"]}
175
+
176
+ # Validate resume command requirements early
177
+ await _validate_resume_command(session, thread_id, request.command)
178
+
179
+ run_id = str(uuid4())
180
+
181
+ # Get LangGraph service
182
+ langgraph_service = get_langgraph_service()
183
+ logger.info(f"[create_run] scheduling background task run_id={run_id} thread_id={thread_id} user={user.identity}")
184
+
185
+ # Validate assistant exists and get its graph_id. If a graph_id was provided
186
+ # instead of an assistant UUID, map it deterministically and fall back to the
187
+ # default assistant created at startup.
188
+ requested_id = str(request.assistant_id)
189
+ available_graphs = langgraph_service.list_graphs()
190
+ resolved_assistant_id = resolve_assistant_id(requested_id, available_graphs)
191
+
192
+ config = request.config
193
+ context = request.context
194
+ configurable = config.get("configurable", {})
195
+
196
+ if config.get("configurable") and context:
197
+ raise HTTPException(
198
+ status_code=400,
199
+ detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long-term planned replacement for configurable.",
200
+ )
201
+
202
+ if context:
203
+ configurable = context.copy()
204
+ config["configurable"] = configurable
205
+ else:
206
+ context = configurable.copy()
207
+
208
+ assistant_stmt = select(AssistantORM).where(
209
+ AssistantORM.assistant_id == resolved_assistant_id,
210
+ )
211
+ assistant = await session.scalar(assistant_stmt)
212
+ if not assistant:
213
+ raise HTTPException(404, f"Assistant '{request.assistant_id}' not found")
214
+
215
+ config = _merge_jsonb(assistant.config, config)
216
+ context = _merge_jsonb(assistant.context, context)
217
+
218
+ # Validate the assistant's graph exists
219
+ available_graphs = langgraph_service.list_graphs()
220
+ if assistant.graph_id not in available_graphs:
221
+ raise HTTPException(404, f"Graph '{assistant.graph_id}' not found for assistant")
222
+
223
+ # Mark thread as busy and update metadata with assistant/graph info
224
+ # update_thread_metadata will auto-create thread if it doesn't exist
225
+ await update_thread_metadata(session, thread_id, assistant.assistant_id, assistant.graph_id, user.identity)
226
+ await set_thread_status(session, thread_id, "busy")
227
+
228
+ # Persist run record via ORM model in core.orm (Run table)
229
+ now = datetime.now(UTC)
230
+ run_orm = RunORM(
231
+ run_id=run_id, # explicitly set (DB can also default-generate if omitted)
232
+ thread_id=thread_id,
233
+ assistant_id=resolved_assistant_id,
234
+ status="pending",
235
+ input=request.input or {},
236
+ config=config,
237
+ context=context,
238
+ user_id=user.identity,
239
+ created_at=now,
240
+ updated_at=now,
241
+ output=None,
242
+ error_message=None,
243
+ )
244
+ session.add(run_orm)
245
+ await session.commit()
246
+
247
+ # Build response from ORM -> Pydantic
248
+ run = Run.model_validate(run_orm)
249
+
250
+ # Start execution asynchronously
251
+ # Don't pass the session to avoid transaction conflicts
252
+ task = asyncio.create_task(
253
+ execute_run_async(
254
+ run_id,
255
+ thread_id,
256
+ assistant.graph_id,
257
+ request.input or {},
258
+ user,
259
+ config,
260
+ context,
261
+ request.stream_mode,
262
+ None, # Don't pass session to avoid conflicts
263
+ request.checkpoint,
264
+ request.command,
265
+ request.interrupt_before,
266
+ request.interrupt_after,
267
+ request.multitask_strategy,
268
+ request.stream_subgraphs,
269
+ )
270
+ )
271
+ logger.info(f"[create_run] background task created task_id={id(task)} for run_id={run_id}")
272
+ active_runs[run_id] = task
273
+
274
+ return run
275
+
276
+
277
+ @router.post("/threads/{thread_id}/runs/stream")
278
+ async def create_and_stream_run(
279
+ thread_id: str,
280
+ request: RunCreate,
281
+ user: User = Depends(get_current_user),
282
+ session: AsyncSession = Depends(get_session),
283
+ ) -> StreamingResponse:
284
+ """Create a new run and stream its execution - persisted + SSE."""
285
+
286
+ # Validate resume command requirements early
287
+ await _validate_resume_command(session, thread_id, request.command)
288
+
289
+ run_id = str(uuid4())
290
+
291
+ # Get LangGraph service
292
+ langgraph_service = get_langgraph_service()
293
+ logger.info(
294
+ f"[create_and_stream_run] scheduling background task run_id={run_id} thread_id={thread_id} user={user.identity}"
295
+ )
296
+
297
+ # Validate assistant exists and get its graph_id. Allow passing a graph_id
298
+ # by mapping it to a deterministic assistant ID.
299
+ requested_id = str(request.assistant_id)
300
+ available_graphs = langgraph_service.list_graphs()
301
+
302
+ resolved_assistant_id = resolve_assistant_id(requested_id, available_graphs)
303
+
304
+ config = request.config
305
+ context = request.context
306
+ configurable = config.get("configurable", {})
307
+
308
+ if config.get("configurable") and context:
309
+ raise HTTPException(
310
+ status_code=400,
311
+ detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long-term planned replacement for configurable.",
312
+ )
313
+
314
+ if context:
315
+ configurable = context.copy()
316
+ config["configurable"] = configurable
317
+ else:
318
+ context = configurable.copy()
319
+
320
+ assistant_stmt = select(AssistantORM).where(
321
+ AssistantORM.assistant_id == resolved_assistant_id,
322
+ )
323
+ assistant = await session.scalar(assistant_stmt)
324
+ if not assistant:
325
+ raise HTTPException(404, f"Assistant '{request.assistant_id}' not found")
326
+
327
+ config = _merge_jsonb(assistant.config, config)
328
+ context = _merge_jsonb(assistant.context, context)
329
+
330
+ # Validate the assistant's graph exists
331
+ available_graphs = langgraph_service.list_graphs()
332
+ if assistant.graph_id not in available_graphs:
333
+ raise HTTPException(404, f"Graph '{assistant.graph_id}' not found for assistant")
334
+
335
+ # Mark thread as busy and update metadata with assistant/graph info
336
+ # update_thread_metadata will auto-create thread if it doesn't exist
337
+ await update_thread_metadata(session, thread_id, assistant.assistant_id, assistant.graph_id, user.identity)
338
+ await set_thread_status(session, thread_id, "busy")
339
+
340
+ # Persist run record
341
+ now = datetime.now(UTC)
342
+ run_orm = RunORM(
343
+ run_id=run_id,
344
+ thread_id=thread_id,
345
+ assistant_id=resolved_assistant_id,
346
+ status="running",
347
+ input=request.input or {},
348
+ config=config,
349
+ context=context,
350
+ user_id=user.identity,
351
+ created_at=now,
352
+ updated_at=now,
353
+ output=None,
354
+ error_message=None,
355
+ )
356
+ session.add(run_orm)
357
+ await session.commit()
358
+
359
+ # Build response model for stream context
360
+ run = Run.model_validate(run_orm)
361
+
362
+ # Start background execution that will populate the broker
363
+ # Don't pass the session to avoid transaction conflicts
364
+ task = asyncio.create_task(
365
+ execute_run_async(
366
+ run_id,
367
+ thread_id,
368
+ assistant.graph_id,
369
+ request.input or {},
370
+ user,
371
+ config,
372
+ context,
373
+ request.stream_mode,
374
+ None, # Don't pass session to avoid conflicts
375
+ request.checkpoint,
376
+ request.command,
377
+ request.interrupt_before,
378
+ request.interrupt_after,
379
+ request.multitask_strategy,
380
+ request.stream_subgraphs,
381
+ )
382
+ )
383
+ logger.info(f"[create_and_stream_run] background task created task_id={id(task)} for run_id={run_id}")
384
+ active_runs[run_id] = task
385
+
386
+ # Extract requested stream mode(s)
387
+ stream_mode = request.stream_mode
388
+ if not stream_mode and config and "stream_mode" in config:
389
+ stream_mode = config["stream_mode"]
390
+
391
+ # Stream immediately from broker (which will also include replay of any early events)
392
+ # Default to cancel on disconnect - this matches user expectation that clicking
393
+ # "Cancel" in the frontend will stop the backend task. Users can explicitly
394
+ # set on_disconnect="continue" if they want the task to continue.
395
+ cancel_on_disconnect = (request.on_disconnect or "cancel").lower() == "cancel"
396
+
397
+ return StreamingResponse(
398
+ streaming_service.stream_run_execution(
399
+ run,
400
+ None,
401
+ cancel_on_disconnect=cancel_on_disconnect,
402
+ ),
403
+ media_type="text/event-stream",
404
+ headers={
405
+ **get_sse_headers(),
406
+ "Location": f"/threads/{thread_id}/runs/{run_id}/stream",
407
+ "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
408
+ },
409
+ )
410
+
411
+
412
+ @router.get("/threads/{thread_id}/runs/{run_id}", response_model=Run)
413
+ async def get_run(
414
+ thread_id: str,
415
+ run_id: str,
416
+ user: User = Depends(get_current_user),
417
+ session: AsyncSession = Depends(get_session),
418
+ ) -> Run:
419
+ """Get run by ID (persisted)."""
420
+ # Authorization check (read action on runs resource)
421
+ ctx = build_auth_context(user, "runs", "read")
422
+ value = {"run_id": run_id, "thread_id": thread_id}
423
+ await handle_event(ctx, value)
424
+
425
+ stmt = select(RunORM).where(
426
+ RunORM.run_id == str(run_id),
427
+ RunORM.thread_id == thread_id,
428
+ RunORM.user_id == user.identity,
429
+ )
430
+ logger.info(f"[get_run] querying DB run_id={run_id} thread_id={thread_id} user={user.identity}")
431
+ run_orm = await session.scalar(stmt)
432
+ if not run_orm:
433
+ raise HTTPException(404, f"Run '{run_id}' not found")
434
+
435
+ # Refresh to ensure we have the latest data (in case background task updated it)
436
+ await session.refresh(run_orm)
437
+
438
+ logger.info(
439
+ f"[get_run] found run status={run_orm.status} user={user.identity} thread_id={thread_id} run_id={run_id}"
440
+ )
441
+ # Convert to Pydantic
442
+ return Run.model_validate(run_orm)
443
+
444
+
445
+ @router.get("/threads/{thread_id}/runs", response_model=list[Run])
446
+ async def list_runs(
447
+ thread_id: str,
448
+ limit: int = Query(10, ge=1, description="Maximum number of runs to return"),
449
+ offset: int = Query(0, ge=0, description="Number of runs to skip for pagination"),
450
+ status: str | None = Query(None, description="Filter by run status"),
451
+ user: User = Depends(get_current_user),
452
+ session: AsyncSession = Depends(get_session),
453
+ ) -> list[Run]:
454
+ """List runs for a specific thread (persisted)."""
455
+ stmt = (
456
+ select(RunORM)
457
+ .where(
458
+ RunORM.thread_id == thread_id,
459
+ RunORM.user_id == user.identity,
460
+ *([RunORM.status == status] if status else []),
461
+ )
462
+ .limit(limit)
463
+ .offset(offset)
464
+ .order_by(RunORM.created_at.desc())
465
+ )
466
+ logger.info(f"[list_runs] querying DB thread_id={thread_id} user={user.identity}")
467
+ result = await session.scalars(stmt)
468
+ rows = result.all()
469
+ runs = [Run.model_validate(r) for r in rows]
470
+ logger.info(f"[list_runs] total={len(runs)} user={user.identity} thread_id={thread_id}")
471
+ return runs
472
+
473
+
474
+ @router.patch("/threads/{thread_id}/runs/{run_id}")
475
+ async def update_run(
476
+ thread_id: str,
477
+ run_id: str,
478
+ request: RunStatus,
479
+ user: User = Depends(get_current_user),
480
+ session: AsyncSession = Depends(get_session),
481
+ ) -> Run:
482
+ """Update run status (for cancellation/interruption, persisted)."""
483
+ logger.info(f"[update_run] fetch for update run_id={run_id} thread_id={thread_id} user={user.identity}")
484
+ run_orm = await session.scalar(
485
+ select(RunORM).where(
486
+ RunORM.run_id == str(run_id),
487
+ RunORM.thread_id == thread_id,
488
+ RunORM.user_id == user.identity,
489
+ )
490
+ )
491
+ if not run_orm:
492
+ raise HTTPException(404, f"Run '{run_id}' not found")
493
+
494
+ # Handle interruption/cancellation
495
+ # Validate status conforms to API specification
496
+ validated_status = validate_run_status(request.status)
497
+
498
+ if validated_status == "interrupted":
499
+ logger.info(f"[update_run] cancelling/interrupting run_id={run_id} user={user.identity} thread_id={thread_id}")
500
+ # Handle interruption - use interrupt_run for cooperative interruption
501
+ await streaming_service.interrupt_run(run_id)
502
+ logger.info(f"[update_run] set DB status=interrupted run_id={run_id}")
503
+ await session.execute(
504
+ update(RunORM)
505
+ .where(RunORM.run_id == str(run_id))
506
+ .values(status="interrupted", updated_at=datetime.now(UTC))
507
+ )
508
+ await session.commit()
509
+ logger.info(f"[update_run] commit done (interrupted) run_id={run_id}")
510
+
511
+ # Return final run state
512
+ run_orm = await session.scalar(select(RunORM).where(RunORM.run_id == run_id))
513
+ if run_orm:
514
+ # Refresh to ensure we have the latest data after our own update
515
+ await session.refresh(run_orm)
516
+         return Run.model_validate(run_orm)
+     # Guard: if the run vanished between the update and the re-fetch, surface an error instead of returning null
+     raise HTTPException(404, f"Run '{run_id}' not found after update")
517
+
518
+
519
+ @router.get("/threads/{thread_id}/runs/{run_id}/join")
520
+ async def join_run(
521
+ thread_id: str,
522
+ run_id: str,
523
+ user: User = Depends(get_current_user),
524
+ session: AsyncSession = Depends(get_session),
525
+ ) -> dict[str, Any]:
526
+ """Join a run (wait for completion and return final output) - persisted."""
527
+ # Get run and validate it exists
528
+ run_orm = await session.scalar(
529
+ select(RunORM).where(
530
+ RunORM.run_id == str(run_id),
531
+ RunORM.thread_id == thread_id,
532
+ RunORM.user_id == user.identity,
533
+ )
534
+ )
535
+ if not run_orm:
536
+ raise HTTPException(404, f"Run '{run_id}' not found")
537
+
538
+ # If already completed, return output immediately
539
+ # Check if run is in a terminal state
540
+ terminal_states = ["success", "error", "interrupted"]
541
+ if run_orm.status in terminal_states:
542
+ # Refresh to ensure we have the latest data
543
+ await session.refresh(run_orm)
544
+ output = getattr(run_orm, "output", None) or {}
545
+ return output
546
+
547
+ # Wait for background task to complete
548
+ task = active_runs.get(run_id)
549
+ if task:
550
+ try:
551
+ await asyncio.wait_for(task, timeout=30.0)
552
+ except TimeoutError:
553
+ # Task is taking too long, but that's okay - we'll check DB status
554
+ pass
555
+ except asyncio.CancelledError:
556
+ # Task was cancelled, that's also okay
557
+ pass
558
+
559
+ # Return final output from database
560
+ run_orm = await session.scalar(select(RunORM).where(RunORM.run_id == run_id))
561
+ if run_orm:
562
+ await session.refresh(run_orm) # Refresh to get latest data from DB
563
+ output = getattr(run_orm, "output", None) or {}
564
+ return output
565
+
566
+
567
+ @router.post("/threads/{thread_id}/runs/wait")
568
+ async def wait_for_run(
569
+ thread_id: str,
570
+ request: RunCreate,
571
+ user: User = Depends(get_current_user),
572
+ session: AsyncSession = Depends(get_session),
573
+ ) -> dict[str, Any]:
574
+ """Create a run, execute it, and wait for completion (Agent Protocol).
575
+
576
+ This endpoint combines run creation and execution with synchronous waiting.
577
+ Returns the final output directly (not the Run object).
578
+
579
+ Compatible with LangGraph SDK's runs.wait() method and Agent Protocol spec.
580
+ """
581
+ # Validate resume command requirements early
582
+ await _validate_resume_command(session, thread_id, request.command)
583
+
584
+ run_id = str(uuid4())
585
+
586
+ # Get LangGraph service
587
+ langgraph_service = get_langgraph_service()
588
+ logger.info(f"[wait_for_run] creating run run_id={run_id} thread_id={thread_id} user={user.identity}")
589
+
590
+ # Validate assistant exists and get its graph_id
591
+ requested_id = str(request.assistant_id)
592
+ available_graphs = langgraph_service.list_graphs()
593
+ resolved_assistant_id = resolve_assistant_id(requested_id, available_graphs)
594
+
595
+ config = request.config
596
+ context = request.context
597
+ configurable = config.get("configurable", {})
598
+
599
+ if config.get("configurable") and context:
600
+ raise HTTPException(
601
+ status_code=400,
602
+ detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long-term planned replacement for configurable.",
603
+ )
604
+
605
+ if context:
606
+ configurable = context.copy()
607
+ config["configurable"] = configurable
608
+ else:
609
+ context = configurable.copy()
610
+
611
+ assistant_stmt = select(AssistantORM).where(
612
+ AssistantORM.assistant_id == resolved_assistant_id,
613
+ )
614
+ assistant = await session.scalar(assistant_stmt)
615
+ if not assistant:
616
+ raise HTTPException(404, f"Assistant '{request.assistant_id}' not found")
617
+
618
+ config = _merge_jsonb(assistant.config, config)
619
+ context = _merge_jsonb(assistant.context, context)
620
+
621
+ # Validate the assistant's graph exists
622
+ available_graphs = langgraph_service.list_graphs()
623
+ if assistant.graph_id not in available_graphs:
624
+ raise HTTPException(404, f"Graph '{assistant.graph_id}' not found for assistant")
625
+
626
+ # Mark thread as busy and update metadata with assistant/graph info
627
+ # update_thread_metadata will auto-create thread if it doesn't exist
628
+ await update_thread_metadata(session, thread_id, assistant.assistant_id, assistant.graph_id, user.identity)
629
+ await set_thread_status(session, thread_id, "busy")
630
+
631
+ # Persist run record
632
+ now = datetime.now(UTC)
633
+ run_orm = RunORM(
634
+ run_id=run_id,
635
+ thread_id=thread_id,
636
+ assistant_id=resolved_assistant_id,
637
+ status="pending",
638
+ input=request.input or {},
639
+ config=config,
640
+ context=context,
641
+ user_id=user.identity,
642
+ created_at=now,
643
+ updated_at=now,
644
+ output=None,
645
+ error_message=None,
646
+ )
647
+ session.add(run_orm)
648
+ await session.commit()
649
+
650
+ # Start execution asynchronously
651
+ task = asyncio.create_task(
652
+ execute_run_async(
653
+ run_id,
654
+ thread_id,
655
+ assistant.graph_id,
656
+ request.input or {},
657
+ user,
658
+ config,
659
+ context,
660
+ request.stream_mode,
661
+ None, # Don't pass session to avoid conflicts
662
+ request.checkpoint,
663
+ request.command,
664
+ request.interrupt_before,
665
+ request.interrupt_after,
666
+ request.multitask_strategy,
667
+ request.stream_subgraphs,
668
+ )
669
+ )
670
+ logger.info(f"[wait_for_run] background task created task_id={id(task)} for run_id={run_id}")
671
+ active_runs[run_id] = task
672
+
673
+ # Wait for task to complete with timeout
674
+ try:
675
+ await asyncio.wait_for(task, timeout=300.0) # 5 minute timeout
676
+ except TimeoutError:
677
+ logger.warning(f"[wait_for_run] timeout waiting for run_id={run_id}")
678
+ # Don't raise, just return current state
679
+ except asyncio.CancelledError:
680
+ logger.info(f"[wait_for_run] cancelled run_id={run_id}")
681
+ # Task was cancelled, continue to return final state
682
+ except Exception as e:
683
+ logger.error(f"[wait_for_run] exception in run_id={run_id}: {e}")
684
+ # Exception already handled by execute_run_async
685
+
686
+ # Get final output from database
687
+ run_orm = await session.scalar(
688
+ select(RunORM).where(
689
+ RunORM.run_id == run_id,
690
+ RunORM.thread_id == thread_id,
691
+ RunORM.user_id == user.identity,
692
+ )
693
+ )
694
+ if not run_orm:
695
+ raise HTTPException(500, f"Run '{run_id}' disappeared during execution")
696
+
697
+ await session.refresh(run_orm)
698
+
699
+ # Return output based on final status
700
+ if run_orm.status == "success":
701
+ return run_orm.output or {}
702
+ elif run_orm.status == "error":
703
+ # For error runs, still return output if available, but log the error
704
+ logger.error(f"[wait_for_run] run failed run_id={run_id} error={run_orm.error_message}")
705
+ return run_orm.output or {}
706
+ elif run_orm.status == "interrupted":
707
+ # Return partial output for interrupted runs
708
+ return run_orm.output or {}
709
+ else:
710
+ # Still pending/running after timeout
711
+ return run_orm.output or {}
712
+
713
+
714
+ # TODO: check if this method is actually required because the implementation doesn't seem correct.
715
+ @router.get("/threads/{thread_id}/runs/{run_id}/stream")
716
+ async def stream_run(
717
+ thread_id: str,
718
+ run_id: str,
719
+ last_event_id: str | None = Header(None, alias="Last-Event-ID"),
720
+ _stream_mode: str | None = Query(None),
721
+ user: User = Depends(get_current_user),
722
+ session: AsyncSession = Depends(get_session),
723
+ ) -> StreamingResponse:
724
+ """Stream run execution with SSE and reconnection support - persisted metadata."""
725
+ logger.info(f"[stream_run] fetch for stream run_id={run_id} thread_id={thread_id} user={user.identity}")
726
+ run_orm = await session.scalar(
727
+ select(RunORM).where(
728
+ RunORM.run_id == str(run_id),
729
+ RunORM.thread_id == thread_id,
730
+ RunORM.user_id == user.identity,
731
+ )
732
+ )
733
+ if not run_orm:
734
+ raise HTTPException(404, f"Run '{run_id}' not found")
735
+
736
+ logger.info(f"[stream_run] status={run_orm.status} user={user.identity} thread_id={thread_id} run_id={run_id}")
737
+ # If already terminal, emit a final end event
738
+ terminal_states = ["success", "error", "interrupted"]
739
+ if run_orm.status in terminal_states:
740
+
741
+ async def generate_final() -> AsyncIterator[str]:
742
+ yield create_end_event()
743
+
744
+ logger.info(f"[stream_run] starting terminal stream run_id={run_id} status={run_orm.status}")
745
+ return StreamingResponse(
746
+ generate_final(),
747
+ media_type="text/event-stream",
748
+ headers={
749
+ **get_sse_headers(),
750
+ "Location": f"/threads/{thread_id}/runs/{run_id}/stream",
751
+ "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
752
+ },
753
+ )
754
+
755
+ # Stream active or pending runs via broker
756
+
757
+ # Build a lightweight Pydantic Run from ORM for streaming context (IDs already strings)
758
+ run_model = Run.model_validate(run_orm)
759
+
760
+ return StreamingResponse(
761
+ streaming_service.stream_run_execution(run_model, last_event_id, cancel_on_disconnect=False),
762
+ media_type="text/event-stream",
763
+ headers={
764
+ **get_sse_headers(),
765
+ "Location": f"/threads/{thread_id}/runs/{run_id}/stream",
766
+ "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
767
+ },
768
+ )
769
+
770
+
771
+ @router.post("/threads/{thread_id}/runs/{run_id}/cancel")
772
+ async def cancel_run_endpoint(
773
+ thread_id: str,
774
+ run_id: str,
775
+ wait: int = Query(0, ge=0, le=1, description="Whether to wait for the run task to settle"),
776
+ action: str = Query("cancel", pattern="^(cancel|interrupt)$", description="Cancellation action"),
777
+ user: User = Depends(get_current_user),
778
+ session: AsyncSession = Depends(get_session),
779
+ ) -> Run:
780
+ """
781
+ Cancel or interrupt a run (client-compatible endpoint).
782
+
783
+ Matches client usage:
784
+ POST /v1/threads/{thread_id}/runs/{run_id}/cancel?wait=0&action=interrupt
785
+
786
+ - action=cancel => hard cancel
787
+ - action=interrupt => cooperative interrupt if supported
788
+ - wait=1 => await background task to finish settling
789
+ """
790
+ logger.info(f"[cancel_run] fetch run run_id={run_id} thread_id={thread_id} user={user.identity}")
791
+ run_orm = await session.scalar(
792
+ select(RunORM).where(
793
+ RunORM.run_id == run_id,
794
+ RunORM.thread_id == thread_id,
795
+ RunORM.user_id == user.identity,
796
+ )
797
+ )
798
+ if not run_orm:
799
+ raise HTTPException(404, f"Run '{run_id}' not found")
800
+
801
+ if action == "interrupt":
802
+ logger.info(f"[cancel_run] interrupt run_id={run_id} user={user.identity} thread_id={thread_id}")
803
+ await streaming_service.interrupt_run(run_id)
804
+ # Persist status as interrupted
805
+ await session.execute(
806
+ update(RunORM)
807
+ .where(RunORM.run_id == str(run_id))
808
+ .values(status="interrupted", updated_at=datetime.now(UTC))
809
+ )
810
+ await session.commit()
811
+ else:
812
+ logger.info(f"[cancel_run] cancel run_id={run_id} user={user.identity} thread_id={thread_id}")
813
+ await streaming_service.cancel_run(run_id)
814
+ # Persist status as interrupted (a hard cancel is also recorded with the "interrupted" terminal status)
815
+ await session.execute(
816
+ update(RunORM)
817
+ .where(RunORM.run_id == str(run_id))
818
+ .values(status="interrupted", updated_at=datetime.now(UTC))
819
+ )
820
+ await session.commit()
821
+
822
+ # Optionally wait for background task
823
+ if wait:
824
+ task = active_runs.get(run_id)
825
+ if task:
826
+ with contextlib.suppress(asyncio.CancelledError, Exception):
827
+ await task
828
+
829
+ # Reload and return updated Run (do NOT delete here; deletion is a separate endpoint)
830
+ run_orm = await session.scalar(
831
+ select(RunORM).where(
832
+ RunORM.run_id == run_id,
833
+ RunORM.thread_id == thread_id,
834
+ RunORM.user_id == user.identity,
835
+ )
836
+ )
837
+ if not run_orm:
838
+ raise HTTPException(404, f"Run '{run_id}' not found after cancellation")
839
+ return Run.model_validate(run_orm)
840
+
841
+
842
+ async def execute_run_async(
843
+ run_id: str,
844
+ thread_id: str,
845
+ graph_id: str,
846
+ input_data: dict,
847
+ user: User,
848
+ config: dict | None = None,
849
+ context: dict | None = None,
850
+ stream_mode: list[str] | None = None,
851
+ session: AsyncSession | None = None,
852
+ checkpoint: dict | None = None,
853
+ command: dict[str, Any] | None = None,
854
+ interrupt_before: str | list[str] | None = None,
855
+ interrupt_after: str | list[str] | None = None,
856
+ _multitask_strategy: str | None = None,
857
+ subgraphs: bool | None = False,
858
+ ) -> None:
859
+ """Execute run asynchronously in background using streaming to capture all events"""
+ # Use provided session or get a new one
860
+ if session is None:
861
+ maker = _get_session_maker()
862
+ session = maker()
863
+
864
+ try:
865
+ # Update status
866
+ await update_run_status(run_id, "running", session=session)
867
+
868
+ # Get graph and execute
869
+ langgraph_service = get_langgraph_service()
870
+
871
+ run_config = create_run_config(run_id, thread_id, user, config or {}, checkpoint)
872
+
873
+ # Handle human-in-the-loop fields
874
+ if interrupt_before is not None:
875
+ run_config["interrupt_before"] = (
876
+ interrupt_before if isinstance(interrupt_before, list) else [interrupt_before]
877
+ )
878
+ if interrupt_after is not None:
879
+ run_config["interrupt_after"] = interrupt_after if isinstance(interrupt_after, list) else [interrupt_after]
880
+
881
+ # Note: multitask_strategy is handled at the run creation level, not execution level
882
+ # It controls concurrent run behavior, not graph execution behavior
883
+
884
+ # Determine input for execution (either input_data or command)
885
+ if command is not None:
886
+ # When command is provided, it replaces input entirely
887
+ execution_input = map_command_to_langgraph(command)
888
+ else:
889
+ # No command, use regular input
890
+ execution_input = input_data
891
+
892
+ # Execute using streaming to capture events for later replay
893
+ event_counter = 0
894
+ final_output = None
895
+ has_interrupt = False
896
+
897
+ # Prepare stream modes for execution
898
+ if stream_mode is None:
899
+ stream_mode_list = DEFAULT_STREAM_MODES.copy()
900
+ elif isinstance(stream_mode, str):
901
+ stream_mode_list = [stream_mode]
902
+ else:
903
+ stream_mode_list = stream_mode.copy()
904
+
905
+ async with (
906
+ langgraph_service.get_graph(graph_id) as graph,
907
+ with_auth_ctx(user, []),
908
+ ):
909
+ # Stream events using the graph_streaming service
910
+ try:
911
+ async for event_type, event_data in stream_graph_events(
912
+ graph=graph,
913
+ input_data=execution_input,
914
+ config=run_config,
915
+ stream_mode=stream_mode_list,
916
+ context=context,
917
+ subgraphs=subgraphs,
918
+ on_checkpoint=lambda _: None, # Can add checkpoint handling if needed
919
+ on_task_result=lambda _: None, # Can add task result handling if needed
920
+ ):
921
+ try:
922
+ # Increment event counter
923
+ event_counter += 1
924
+ event_id = f"{run_id}_event_{event_counter}"
925
+
926
+ # Create event tuple for broker/storage
927
+ event_tuple = (event_type, event_data)
928
+
929
+ # Forward to broker for live consumers (already filtered by graph_streaming)
930
+ await streaming_service.put_to_broker(run_id, event_id, event_tuple)
931
+
932
+ # Store for replay (already filtered by graph_streaming)
933
+ await streaming_service.store_event_from_raw(run_id, event_id, event_tuple)
934
+
935
+ # Check for interrupt
936
+ if isinstance(event_data, dict) and "__interrupt__" in event_data:
937
+ has_interrupt = True
938
+
939
+ # Track final output from values events (handles both "values" and "values|namespace")
940
+ if event_type.startswith("values"):
941
+ final_output = event_data
942
+
943
+ except Exception as event_error:
944
+ # Error processing individual event - send error to frontend immediately
945
+ logger.error(f"[execute_run_async] error processing event for run_id={run_id}: {event_error}")
946
+ error_type = type(event_error).__name__
947
+ await streaming_service.signal_run_error(run_id, str(event_error), error_type)
948
+ raise
949
+
950
+ except Exception as stream_error:
951
+ # Error during streaming (e.g., graph execution error)
952
+ # Send error to frontend before re-raising
953
+ logger.error(f"[execute_run_async] streaming error for run_id={run_id}: {stream_error}")
954
+ error_type = type(stream_error).__name__
955
+ await streaming_service.signal_run_error(run_id, str(stream_error), error_type)
956
+ raise
957
+
958
+ if has_interrupt:
959
+ await update_run_status(run_id, "interrupted", output=final_output or {}, session=session)
960
+ if not session:
961
+ raise RuntimeError(f"No database session available to update thread {thread_id} status")
962
+ await set_thread_status(session, thread_id, "interrupted")
963
+
964
+ else:
965
+ # Update with results - use standard status
966
+ await update_run_status(run_id, "success", output=final_output or {}, session=session)
967
+ # Mark thread back to idle
968
+ if not session:
969
+ raise RuntimeError(f"No database session available to update thread {thread_id} status")
970
+ await set_thread_status(session, thread_id, "idle")
971
+
972
+ except asyncio.CancelledError:
973
+ # Store empty output to avoid JSON serialization issues - use standard status
974
+ await update_run_status(run_id, "interrupted", output={}, session=session)
975
+ if not session:
976
+ raise RuntimeError(f"No database session available to update thread {thread_id} status") from None
977
+ await set_thread_status(session, thread_id, "idle")
978
+ # Signal cancellation to broker
979
+ await streaming_service.signal_run_cancelled(run_id)
980
+ raise
981
+ except Exception as e:
982
+ # Store empty output to avoid JSON serialization issues - use standard status
983
+ await update_run_status(run_id, "error", output={}, error=str(e), session=session)
984
+ if not session:
985
+ raise RuntimeError(f"No database session available to update thread {thread_id} status") from None
986
+ # Set thread status to "error" when run fails (matches API specification)
987
+ await set_thread_status(session, thread_id, "error")
988
+ # Note: Error event already sent to broker in inner exception handler
989
+ # Only signal if broker still exists (cleanup not yet called)
990
+ broker = broker_manager.get_broker(run_id)
991
+ if broker and not broker.is_finished():
992
+ error_type = type(e).__name__
993
+ await streaming_service.signal_run_error(run_id, str(e), error_type)
994
+ raise
995
+ finally:
996
+ # Clean up broker
997
+ await streaming_service.cleanup_run(run_id)
998
+ active_runs.pop(run_id, None)
999
+
1000
+
1001
+ async def update_run_status(
1002
+ run_id: str,
1003
+ status: str,
1004
+ output: Any = None,
1005
+ error: str | None = None,
1006
+ session: AsyncSession | None = None,
1007
+ ) -> None:
1008
+ """Update run status in database (persisted). If session not provided, opens a short-lived session.
1009
+
1010
+ Status is validated to ensure it conforms to API specification.
1011
+ """
1012
+ # Validate status conforms to API specification
1013
+ validated_status = validate_run_status(status)
1014
+
1015
+ owns_session = False
1016
+ if session is None:
1017
+ maker = _get_session_maker()
1018
+ session = maker() # type: ignore[assignment]
1019
+ owns_session = True
1020
+ try:
1021
+ values = {"status": validated_status, "updated_at": datetime.now(UTC)}
1022
+ if output is not None:
1023
+ # Serialize output to ensure JSON compatibility
1024
+ try:
1025
+ serialized_output = serializer.serialize(output)
1026
+ values["output"] = serialized_output
1027
+ except Exception as e:
1028
+ logger.warning(f"Failed to serialize output for run {run_id}: {e}")
1029
+ values["output"] = {
1030
+ "error": "Output serialization failed",
1031
+ "original_type": str(type(output)),
1032
+ }
1033
+ if error is not None:
1034
+ values["error_message"] = error
1035
+ logger.info(f"[update_run_status] updating DB run_id={run_id} status={validated_status}")
1036
+ await session.execute(update(RunORM).where(RunORM.run_id == str(run_id)).values(**values)) # type: ignore[arg-type]
1037
+ await session.commit()
1038
+ logger.info(f"[update_run_status] commit done run_id={run_id}")
1039
+ finally:
1040
+ # Close only if we created it here
1041
+ if owns_session:
1042
+ await session.close() # type: ignore[func-returns-value]
1043
+
1044
+
1045
+ @router.delete("/threads/{thread_id}/runs/{run_id}", status_code=204)
1046
+ async def delete_run(
1047
+ thread_id: str,
1048
+ run_id: str,
1049
+ force: int = Query(0, ge=0, le=1, description="Force cancel active run before delete (1=yes)"),
1050
+ user: User = Depends(get_current_user),
1051
+ session: AsyncSession = Depends(get_session),
1052
+ ) -> None:
1053
+ """Delete a run record.
+ 
+ Behavior:
+ - If the run is active (pending/running/streaming) and force=0, returns 409 Conflict.
+ - If force=1 and the run is active, cancels it first (best-effort) and then deletes.
+ - Always returns 204 No Content on successful deletion.
+ """
1054
+ # Authorization check (delete action on runs resource)
1055
+ ctx = build_auth_context(user, "runs", "delete")
1056
+ value = {"run_id": run_id, "thread_id": thread_id}
1057
+ await handle_event(ctx, value)
1066
+ logger.info(f"[delete_run] fetch run run_id={run_id} thread_id={thread_id} user={user.identity}")
1067
+ run_orm = await session.scalar(
1068
+ select(RunORM).where(
1069
+ RunORM.run_id == str(run_id),
1070
+ RunORM.thread_id == thread_id,
1071
+ RunORM.user_id == user.identity,
1072
+ )
1073
+ )
1074
+ if not run_orm:
1075
+ raise HTTPException(404, f"Run '{run_id}' not found")
1076
+
1077
+ # If active and not forcing, reject deletion
1078
+ if run_orm.status in ["pending", "running"] and not force:
1079
+ raise HTTPException(
1080
+ status_code=409,
1081
+ detail="Run is active. Retry with force=1 to cancel and delete.",
1082
+ )
1083
+
1084
+ # If forcing and active, cancel first
1085
+ if force and run_orm.status in ["pending", "running"]:
1086
+ logger.info(f"[delete_run] force-cancelling active run run_id={run_id}")
1087
+ await streaming_service.cancel_run(run_id)
1088
+ # Best-effort: wait for bg task to settle
1089
+ task = active_runs.get(run_id)
1090
+ if task:
1091
+ with contextlib.suppress(asyncio.CancelledError, Exception):
1092
+ await task
1093
+
1094
+ # Delete the record
1095
+ await session.execute(
1096
+ delete(RunORM).where(
1097
+ RunORM.run_id == str(run_id),
1098
+ RunORM.thread_id == thread_id,
1099
+ RunORM.user_id == user.identity,
1100
+ )
1101
+ )
1102
+ await session.commit()
1103
+
1104
+ # Clean up active task if exists
1105
+ task = active_runs.pop(run_id, None)
1106
+ if task and not task.done():
1107
+ task.cancel()
1108
+
1109
+ # 204 No Content
1110
+ return
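The runs.py module above defines the full Agent Protocol run surface: create (POST /threads/{thread_id}/runs), create-and-stream over SSE (POST .../runs/stream), get/list, wait (POST .../runs/wait, which returns the final output rather than the Run record), cancel/interrupt (POST .../runs/{run_id}/cancel), and delete (DELETE .../runs/{run_id}). The sketch below is a minimal client-side illustration of how those endpoints might be called with httpx; it is not part of the package. The base URL, the absence of a route prefix (the cancel docstring hints the server may be mounted under /v1), the auth setup, the "agent" assistant/graph name, and the shape of the input payload are all assumptions that depend on deployment and on the graphs registered with the server.

"""Hypothetical client sketch for the run endpoints in aegra_api/api/runs.py (not part of the package)."""
import httpx

BASE_URL = "http://localhost:8000"  # assumed deployment address; adjust for any /v1 prefix or auth headers
THREAD_ID = "demo-thread"           # update_thread_metadata auto-creates the thread on the first run

payload = {
    "assistant_id": "agent",  # assistant UUID or graph_id; resolve_assistant_id accepts either
    "input": {"messages": [{"role": "user", "content": "hello"}]},  # input shape depends on the graph
    "stream_mode": ["values"],
}

with httpx.Client(base_url=BASE_URL, timeout=None) as client:
    # 1. Create a background run; the persisted Run record comes back immediately.
    run = client.post(f"/threads/{THREAD_ID}/runs", json=payload).json()
    run_id = run["run_id"]

    # 2. Poll its status via GET .../runs/{run_id} (or wait on it via GET .../runs/{run_id}/join).
    print(client.get(f"/threads/{THREAD_ID}/runs/{run_id}").json()["status"])

    # 3. Create-and-stream in one request; the response is a text/event-stream of SSE lines.
    with client.stream("POST", f"/threads/{THREAD_ID}/runs/stream", json=payload) as resp:
        for line in resp.iter_lines():
            if line:
                print(line)  # "event: ..." / "data: ..." lines emitted by streaming_service

    # 4. Block until completion and receive only the final output (Agent Protocol runs.wait semantics).
    output = client.post(f"/threads/{THREAD_ID}/runs/wait", json=payload).json()
    print(output)

    # 5. Cooperative interrupt, then delete; force=1 cancels an active run before deleting.
    client.post(f"/threads/{THREAD_ID}/runs/{run_id}/cancel", params={"action": "interrupt", "wait": 1})
    client.delete(f"/threads/{THREAD_ID}/runs/{run_id}", params={"force": 1})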