aegra_api-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aegra_api/__init__.py +3 -0
- aegra_api/api/__init__.py +1 -0
- aegra_api/api/assistants.py +235 -0
- aegra_api/api/runs.py +1110 -0
- aegra_api/api/store.py +200 -0
- aegra_api/api/threads.py +761 -0
- aegra_api/config.py +204 -0
- aegra_api/constants.py +5 -0
- aegra_api/core/__init__.py +0 -0
- aegra_api/core/app_loader.py +91 -0
- aegra_api/core/auth_ctx.py +65 -0
- aegra_api/core/auth_deps.py +186 -0
- aegra_api/core/auth_handlers.py +248 -0
- aegra_api/core/auth_middleware.py +331 -0
- aegra_api/core/database.py +123 -0
- aegra_api/core/health.py +131 -0
- aegra_api/core/orm.py +165 -0
- aegra_api/core/route_merger.py +69 -0
- aegra_api/core/serializers/__init__.py +7 -0
- aegra_api/core/serializers/base.py +22 -0
- aegra_api/core/serializers/general.py +54 -0
- aegra_api/core/serializers/langgraph.py +102 -0
- aegra_api/core/sse.py +178 -0
- aegra_api/main.py +303 -0
- aegra_api/middleware/__init__.py +4 -0
- aegra_api/middleware/double_encoded_json.py +74 -0
- aegra_api/middleware/logger_middleware.py +95 -0
- aegra_api/models/__init__.py +76 -0
- aegra_api/models/assistants.py +81 -0
- aegra_api/models/auth.py +62 -0
- aegra_api/models/enums.py +29 -0
- aegra_api/models/errors.py +29 -0
- aegra_api/models/runs.py +124 -0
- aegra_api/models/store.py +67 -0
- aegra_api/models/threads.py +152 -0
- aegra_api/observability/__init__.py +1 -0
- aegra_api/observability/base.py +88 -0
- aegra_api/observability/otel.py +133 -0
- aegra_api/observability/setup.py +27 -0
- aegra_api/observability/targets/__init__.py +11 -0
- aegra_api/observability/targets/base.py +18 -0
- aegra_api/observability/targets/langfuse.py +33 -0
- aegra_api/observability/targets/otlp.py +38 -0
- aegra_api/observability/targets/phoenix.py +24 -0
- aegra_api/services/__init__.py +0 -0
- aegra_api/services/assistant_service.py +569 -0
- aegra_api/services/base_broker.py +59 -0
- aegra_api/services/broker.py +141 -0
- aegra_api/services/event_converter.py +157 -0
- aegra_api/services/event_store.py +196 -0
- aegra_api/services/graph_streaming.py +433 -0
- aegra_api/services/langgraph_service.py +456 -0
- aegra_api/services/streaming_service.py +362 -0
- aegra_api/services/thread_state_service.py +128 -0
- aegra_api/settings.py +124 -0
- aegra_api/utils/__init__.py +3 -0
- aegra_api/utils/assistants.py +23 -0
- aegra_api/utils/run_utils.py +60 -0
- aegra_api/utils/setup_logging.py +122 -0
- aegra_api/utils/sse_utils.py +26 -0
- aegra_api/utils/status_compat.py +57 -0
- aegra_api-0.1.0.dist-info/METADATA +244 -0
- aegra_api-0.1.0.dist-info/RECORD +64 -0
- aegra_api-0.1.0.dist-info/WHEEL +4 -0
aegra_api/api/runs.py
ADDED
@@ -0,0 +1,1110 @@
"""Run endpoints for Agent Protocol"""

import asyncio
import contextlib
from collections.abc import AsyncIterator
from datetime import UTC, datetime
from typing import Any
from uuid import uuid4

import structlog
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from fastapi.responses import StreamingResponse
from langgraph.types import Command, Send
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from aegra_api.core.auth_ctx import with_auth_ctx
from aegra_api.core.auth_deps import get_current_user
from aegra_api.core.auth_handlers import build_auth_context, handle_event
from aegra_api.core.orm import Assistant as AssistantORM
from aegra_api.core.orm import Run as RunORM
from aegra_api.core.orm import Thread as ThreadORM
from aegra_api.core.orm import _get_session_maker, get_session
from aegra_api.core.serializers import GeneralSerializer
from aegra_api.core.sse import create_end_event, get_sse_headers
from aegra_api.models import Run, RunCreate, RunStatus, User
from aegra_api.services.broker import broker_manager
from aegra_api.services.graph_streaming import stream_graph_events
from aegra_api.services.langgraph_service import create_run_config, get_langgraph_service
from aegra_api.services.streaming_service import streaming_service
from aegra_api.utils.assistants import resolve_assistant_id
from aegra_api.utils.run_utils import (
    _merge_jsonb,
)
from aegra_api.utils.status_compat import validate_run_status

router = APIRouter(tags=["Runs"])

logger = structlog.getLogger(__name__)
serializer = GeneralSerializer()


# NOTE: We keep only an in-memory task registry for asyncio.Task handles.
# All run metadata/state is persisted via ORM.
active_runs: dict[str, asyncio.Task] = {}

# Default stream modes for background run execution
DEFAULT_STREAM_MODES = ["values"]


def map_command_to_langgraph(cmd: dict[str, Any]) -> Command:
    """Convert API command to LangGraph Command"""
    goto = cmd.get("goto")
    if goto is not None and not isinstance(goto, list):
        goto = [goto]

    update = cmd.get("update")
    if isinstance(update, (tuple, list)) and all(
        isinstance(t, (tuple, list)) and len(t) == 2 and isinstance(t[0], str) for t in update
    ):
        update = [tuple(t) for t in update]

    return Command(
        update=update,
        goto=([it if isinstance(it, str) else Send(it["node"], it["input"]) for it in goto] if goto else None),
        resume=cmd.get("resume"),
    )
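
As an illustration of the mapping above (not part of the package itself), a resume-style command payload of the kind the run endpoints accept would convert roughly as follows; only the goto, update, and resume keys are read, and a single goto target is wrapped into a list and, when it is a dict, into a Send:

# Illustrative sketch only: a hypothetical command payload and its LangGraph equivalent.
cmd = {
    "resume": {"approved": True},                              # value delivered to the interrupted node
    "goto": {"node": "reviewer", "input": {"draft": "v2"}},    # dict target becomes Send("reviewer", {...})
}
command = map_command_to_langgraph(cmd)
# command.resume == {"approved": True}
# command.goto == [Send("reviewer", {"draft": "v2"})]
# command.update is None (no "update" key was supplied)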


async def set_thread_status(session: AsyncSession, thread_id: str, status: str) -> None:
    """Update the status column of a thread.

    Status is validated to ensure it conforms to API specification.
    """
    # Validate status conforms to API specification
    from aegra_api.utils.status_compat import validate_thread_status

    validated_status = validate_thread_status(status)
    result = await session.execute(
        update(ThreadORM)
        .where(ThreadORM.thread_id == thread_id)
        .values(status=validated_status, updated_at=datetime.now(UTC))
    )
    await session.commit()

    # Verify thread was updated (matching row exists)
    if result.rowcount == 0:
        raise HTTPException(404, f"Thread '{thread_id}' not found")


async def update_thread_metadata(
    session: AsyncSession,
    thread_id: str,
    assistant_id: str,
    graph_id: str,
    user_id: str | None = None,
) -> None:
    """Update thread metadata with assistant and graph information (dialect agnostic).

    If thread doesn't exist, auto-creates it.
    """
    # Read-modify-write to avoid DB-specific JSON concat operators
    thread = await session.scalar(select(ThreadORM).where(ThreadORM.thread_id == thread_id))

    if not thread:
        # Auto-create thread if it doesn't exist
        if not user_id:
            raise HTTPException(400, "Cannot auto-create thread: user_id is required")

        metadata = {
            "owner": user_id,
            "assistant_id": str(assistant_id),
            "graph_id": graph_id,
            "thread_name": "",
        }

        thread_orm = ThreadORM(
            thread_id=thread_id,
            status="idle",
            metadata_json=metadata,
            user_id=user_id,
        )
        session.add(thread_orm)
        await session.commit()
        return

    md = dict(getattr(thread, "metadata_json", {}) or {})
    md.update(
        {
            "assistant_id": str(assistant_id),
            "graph_id": graph_id,
        }
    )
    await session.execute(
        update(ThreadORM).where(ThreadORM.thread_id == thread_id).values(metadata_json=md, updated_at=datetime.now(UTC))
    )
    await session.commit()


async def _validate_resume_command(session: AsyncSession, thread_id: str, command: dict[str, Any] | None) -> None:
    """Validate resume command requirements."""
    if command and command.get("resume") is not None:
        # Check if thread exists and is in interrupted state
        thread_stmt = select(ThreadORM).where(ThreadORM.thread_id == thread_id)
        thread = await session.scalar(thread_stmt)
        if not thread:
            raise HTTPException(404, f"Thread '{thread_id}' not found")
        if thread.status != "interrupted":
            raise HTTPException(400, "Cannot resume: thread is not in interrupted state")


@router.post("/threads/{thread_id}/runs", response_model=Run)
async def create_run(
    thread_id: str,
    request: RunCreate,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> Run:
    """Create and execute a new run (persisted)."""
    # Authorization check (create_run action on threads resource)
    ctx = build_auth_context(user, "threads", "create_run")
    value = {**request.model_dump(), "thread_id": thread_id}
    filters = await handle_event(ctx, value)

    # If handler modified config/context, update request
    if filters:
        if "config" in filters:
            request.config = {**(request.config or {}), **filters["config"]}
        if "context" in filters:
            request.context = {**(request.context or {}), **filters["context"]}
    elif value.get("config"):
        request.config = {**(request.config or {}), **value["config"]}
    elif value.get("context"):
        request.context = {**(request.context or {}), **value["context"]}

    # Validate resume command requirements early
    await _validate_resume_command(session, thread_id, request.command)

    run_id = str(uuid4())

    # Get LangGraph service
    langgraph_service = get_langgraph_service()
    logger.info(f"[create_run] scheduling background task run_id={run_id} thread_id={thread_id} user={user.identity}")

    # Validate assistant exists and get its graph_id. If a graph_id was provided
    # instead of an assistant UUID, map it deterministically and fall back to the
    # default assistant created at startup.
    requested_id = str(request.assistant_id)
    available_graphs = langgraph_service.list_graphs()
    resolved_assistant_id = resolve_assistant_id(requested_id, available_graphs)

    config = request.config
    context = request.context
    configurable = config.get("configurable", {})

    if config.get("configurable") and context:
        raise HTTPException(
            status_code=400,
            detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
        )

    if context:
        configurable = context.copy()
        config["configurable"] = configurable
    else:
        context = configurable.copy()

    assistant_stmt = select(AssistantORM).where(
        AssistantORM.assistant_id == resolved_assistant_id,
    )
    assistant = await session.scalar(assistant_stmt)
    if not assistant:
        raise HTTPException(404, f"Assistant '{request.assistant_id}' not found")

    config = _merge_jsonb(assistant.config, config)
    context = _merge_jsonb(assistant.context, context)

    # Validate the assistant's graph exists
    available_graphs = langgraph_service.list_graphs()
    if assistant.graph_id not in available_graphs:
        raise HTTPException(404, f"Graph '{assistant.graph_id}' not found for assistant")

    # Mark thread as busy and update metadata with assistant/graph info
    # update_thread_metadata will auto-create thread if it doesn't exist
    await update_thread_metadata(session, thread_id, assistant.assistant_id, assistant.graph_id, user.identity)
    await set_thread_status(session, thread_id, "busy")

    # Persist run record via ORM model in core.orm (Run table)
    now = datetime.now(UTC)
    run_orm = RunORM(
        run_id=run_id,  # explicitly set (DB can also default-generate if omitted)
        thread_id=thread_id,
        assistant_id=resolved_assistant_id,
        status="pending",
        input=request.input or {},
        config=config,
        context=context,
        user_id=user.identity,
        created_at=now,
        updated_at=now,
        output=None,
        error_message=None,
    )
    session.add(run_orm)
    await session.commit()

    # Build response from ORM -> Pydantic
    run = Run.model_validate(run_orm)

    # Start execution asynchronously
    # Don't pass the session to avoid transaction conflicts
    task = asyncio.create_task(
        execute_run_async(
            run_id,
            thread_id,
            assistant.graph_id,
            request.input or {},
            user,
            config,
            context,
            request.stream_mode,
            None,  # Don't pass session to avoid conflicts
            request.checkpoint,
            request.command,
            request.interrupt_before,
            request.interrupt_after,
            request.multitask_strategy,
            request.stream_subgraphs,
        )
    )
    logger.info(f"[create_run] background task created task_id={id(task)} for run_id={run_id}")
    active_runs[run_id] = task

    return run
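
A minimal client-side sketch of calling this endpoint over HTTP, assuming a local deployment and the RunCreate field names used above (assistant_id, input); the base URL, thread id, and input shape are placeholders, and any authentication headers required by the configured auth backend are omitted:

import httpx

BASE, thread_id = "http://localhost:8000", "<thread-id>"  # placeholders

# Create a background run; the response is the persisted Run record with status "pending".
resp = httpx.post(
    f"{BASE}/threads/{thread_id}/runs",
    json={
        "assistant_id": "agent",  # a graph_id or an assistant UUID; resolved server-side
        "input": {"messages": [{"role": "user", "content": "hello"}]},
    },
)
resp.raise_for_status()
run = resp.json()
print(run["run_id"], run["status"])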


@router.post("/threads/{thread_id}/runs/stream")
async def create_and_stream_run(
    thread_id: str,
    request: RunCreate,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> StreamingResponse:
    """Create a new run and stream its execution - persisted + SSE."""

    # Validate resume command requirements early
    await _validate_resume_command(session, thread_id, request.command)

    run_id = str(uuid4())

    # Get LangGraph service
    langgraph_service = get_langgraph_service()
    logger.info(
        f"[create_and_stream_run] scheduling background task run_id={run_id} thread_id={thread_id} user={user.identity}"
    )

    # Validate assistant exists and get its graph_id. Allow passing a graph_id
    # by mapping it to a deterministic assistant ID.
    requested_id = str(request.assistant_id)
    available_graphs = langgraph_service.list_graphs()

    resolved_assistant_id = resolve_assistant_id(requested_id, available_graphs)

    config = request.config
    context = request.context
    configurable = config.get("configurable", {})

    if config.get("configurable") and context:
        raise HTTPException(
            status_code=400,
            detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
        )

    if context:
        configurable = context.copy()
        config["configurable"] = configurable
    else:
        context = configurable.copy()

    assistant_stmt = select(AssistantORM).where(
        AssistantORM.assistant_id == resolved_assistant_id,
    )
    assistant = await session.scalar(assistant_stmt)
    if not assistant:
        raise HTTPException(404, f"Assistant '{request.assistant_id}' not found")

    config = _merge_jsonb(assistant.config, config)
    context = _merge_jsonb(assistant.context, context)

    # Validate the assistant's graph exists
    available_graphs = langgraph_service.list_graphs()
    if assistant.graph_id not in available_graphs:
        raise HTTPException(404, f"Graph '{assistant.graph_id}' not found for assistant")

    # Mark thread as busy and update metadata with assistant/graph info
    # update_thread_metadata will auto-create thread if it doesn't exist
    await update_thread_metadata(session, thread_id, assistant.assistant_id, assistant.graph_id, user.identity)
    await set_thread_status(session, thread_id, "busy")

    # Persist run record
    now = datetime.now(UTC)
    run_orm = RunORM(
        run_id=run_id,
        thread_id=thread_id,
        assistant_id=resolved_assistant_id,
        status="running",
        input=request.input or {},
        config=config,
        context=context,
        user_id=user.identity,
        created_at=now,
        updated_at=now,
        output=None,
        error_message=None,
    )
    session.add(run_orm)
    await session.commit()

    # Build response model for stream context
    run = Run.model_validate(run_orm)

    # Start background execution that will populate the broker
    # Don't pass the session to avoid transaction conflicts
    task = asyncio.create_task(
        execute_run_async(
            run_id,
            thread_id,
            assistant.graph_id,
            request.input or {},
            user,
            config,
            context,
            request.stream_mode,
            None,  # Don't pass session to avoid conflicts
            request.checkpoint,
            request.command,
            request.interrupt_before,
            request.interrupt_after,
            request.multitask_strategy,
            request.stream_subgraphs,
        )
    )
    logger.info(f"[create_and_stream_run] background task created task_id={id(task)} for run_id={run_id}")
    active_runs[run_id] = task

    # Extract requested stream mode(s)
    stream_mode = request.stream_mode
    if not stream_mode and config and "stream_mode" in config:
        stream_mode = config["stream_mode"]

    # Stream immediately from broker (which will also include replay of any early events)
    # Default to cancel on disconnect - this matches user expectation that clicking
    # "Cancel" in the frontend will stop the backend task. Users can explicitly
    # set on_disconnect="continue" if they want the task to continue.
    cancel_on_disconnect = (request.on_disconnect or "cancel").lower() == "cancel"

    return StreamingResponse(
        streaming_service.stream_run_execution(
            run,
            None,
            cancel_on_disconnect=cancel_on_disconnect,
        ),
        media_type="text/event-stream",
        headers={
            **get_sse_headers(),
            "Location": f"/threads/{thread_id}/runs/{run_id}/stream",
            "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
        },
    )
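
For completeness, a sketch of consuming the SSE response from this endpoint with httpx; the exact event names and payload shapes come from streaming_service and are not defined in this file, so the client below just prints raw SSE lines (URL, ids, and input are placeholders):

import httpx

BASE, thread_id = "http://localhost:8000", "<thread-id>"  # placeholders

# Create the run and read the Server-Sent Events as they arrive.
with httpx.stream(
    "POST",
    f"{BASE}/threads/{thread_id}/runs/stream",
    json={"assistant_id": "agent", "input": {"messages": []}, "stream_mode": ["values"]},
    timeout=None,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line.startswith(("event:", "data:", "id:")):
            print(line)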


@router.get("/threads/{thread_id}/runs/{run_id}", response_model=Run)
async def get_run(
    thread_id: str,
    run_id: str,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> Run:
    """Get run by ID (persisted)."""
    # Authorization check (read action on runs resource)
    ctx = build_auth_context(user, "runs", "read")
    value = {"run_id": run_id, "thread_id": thread_id}
    await handle_event(ctx, value)

    stmt = select(RunORM).where(
        RunORM.run_id == str(run_id),
        RunORM.thread_id == thread_id,
        RunORM.user_id == user.identity,
    )
    logger.info(f"[get_run] querying DB run_id={run_id} thread_id={thread_id} user={user.identity}")
    run_orm = await session.scalar(stmt)
    if not run_orm:
        raise HTTPException(404, f"Run '{run_id}' not found")

    # Refresh to ensure we have the latest data (in case background task updated it)
    await session.refresh(run_orm)

    logger.info(
        f"[get_run] found run status={run_orm.status} user={user.identity} thread_id={thread_id} run_id={run_id}"
    )
    # Convert to Pydantic
    return Run.model_validate(run_orm)


@router.get("/threads/{thread_id}/runs", response_model=list[Run])
async def list_runs(
    thread_id: str,
    limit: int = Query(10, ge=1, description="Maximum number of runs to return"),
    offset: int = Query(0, ge=0, description="Number of runs to skip for pagination"),
    status: str | None = Query(None, description="Filter by run status"),
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> list[Run]:
    """List runs for a specific thread (persisted)."""
    stmt = (
        select(RunORM)
        .where(
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
            *([RunORM.status == status] if status else []),
        )
        .limit(limit)
        .offset(offset)
        .order_by(RunORM.created_at.desc())
    )
    logger.info(f"[list_runs] querying DB thread_id={thread_id} user={user.identity}")
    result = await session.scalars(stmt)
    rows = result.all()
    runs = [Run.model_validate(r) for r in rows]
    logger.info(f"[list_runs] total={len(runs)} user={user.identity} thread_id={thread_id}")
    return runs


@router.patch("/threads/{thread_id}/runs/{run_id}")
async def update_run(
    thread_id: str,
    run_id: str,
    request: RunStatus,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> Run:
    """Update run status (for cancellation/interruption, persisted)."""
    logger.info(f"[update_run] fetch for update run_id={run_id} thread_id={thread_id} user={user.identity}")
    run_orm = await session.scalar(
        select(RunORM).where(
            RunORM.run_id == str(run_id),
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    if not run_orm:
        raise HTTPException(404, f"Run '{run_id}' not found")

    # Handle interruption/cancellation
    # Validate status conforms to API specification
    validated_status = validate_run_status(request.status)

    if validated_status == "interrupted":
        logger.info(f"[update_run] cancelling/interrupting run_id={run_id} user={user.identity} thread_id={thread_id}")
        # Handle interruption - use interrupt_run for cooperative interruption
        await streaming_service.interrupt_run(run_id)
        logger.info(f"[update_run] set DB status=interrupted run_id={run_id}")
        await session.execute(
            update(RunORM)
            .where(RunORM.run_id == str(run_id))
            .values(status="interrupted", updated_at=datetime.now(UTC))
        )
        await session.commit()
        logger.info(f"[update_run] commit done (interrupted) run_id={run_id}")

    # Return final run state
    run_orm = await session.scalar(select(RunORM).where(RunORM.run_id == run_id))
    if run_orm:
        # Refresh to ensure we have the latest data after our own update
        await session.refresh(run_orm)
        return Run.model_validate(run_orm)


@router.get("/threads/{thread_id}/runs/{run_id}/join")
async def join_run(
    thread_id: str,
    run_id: str,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """Join a run (wait for completion and return final output) - persisted."""
    # Get run and validate it exists
    run_orm = await session.scalar(
        select(RunORM).where(
            RunORM.run_id == str(run_id),
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    if not run_orm:
        raise HTTPException(404, f"Run '{run_id}' not found")

    # If already completed, return output immediately
    # Check if run is in a terminal state
    terminal_states = ["success", "error", "interrupted"]
    if run_orm.status in terminal_states:
        # Refresh to ensure we have the latest data
        await session.refresh(run_orm)
        output = getattr(run_orm, "output", None) or {}
        return output

    # Wait for background task to complete
    task = active_runs.get(run_id)
    if task:
        try:
            await asyncio.wait_for(task, timeout=30.0)
        except TimeoutError:
            # Task is taking too long, but that's okay - we'll check DB status
            pass
        except asyncio.CancelledError:
            # Task was cancelled, that's also okay
            pass

    # Return final output from database
    run_orm = await session.scalar(select(RunORM).where(RunORM.run_id == run_id))
    if run_orm:
        await session.refresh(run_orm)  # Refresh to get latest data from DB
        output = getattr(run_orm, "output", None) or {}
        return output


@router.post("/threads/{thread_id}/runs/wait")
async def wait_for_run(
    thread_id: str,
    request: RunCreate,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """Create a run, execute it, and wait for completion (Agent Protocol).

    This endpoint combines run creation and execution with synchronous waiting.
    Returns the final output directly (not the Run object).

    Compatible with LangGraph SDK's runs.wait() method and Agent Protocol spec.
    """
    # Validate resume command requirements early
    await _validate_resume_command(session, thread_id, request.command)

    run_id = str(uuid4())

    # Get LangGraph service
    langgraph_service = get_langgraph_service()
    logger.info(f"[wait_for_run] creating run run_id={run_id} thread_id={thread_id} user={user.identity}")

    # Validate assistant exists and get its graph_id
    requested_id = str(request.assistant_id)
    available_graphs = langgraph_service.list_graphs()
    resolved_assistant_id = resolve_assistant_id(requested_id, available_graphs)

    config = request.config
    context = request.context
    configurable = config.get("configurable", {})

    if config.get("configurable") and context:
        raise HTTPException(
            status_code=400,
            detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
        )

    if context:
        configurable = context.copy()
        config["configurable"] = configurable
    else:
        context = configurable.copy()

    assistant_stmt = select(AssistantORM).where(
        AssistantORM.assistant_id == resolved_assistant_id,
    )
    assistant = await session.scalar(assistant_stmt)
    if not assistant:
        raise HTTPException(404, f"Assistant '{request.assistant_id}' not found")

    config = _merge_jsonb(assistant.config, config)
    context = _merge_jsonb(assistant.context, context)

    # Validate the assistant's graph exists
    available_graphs = langgraph_service.list_graphs()
    if assistant.graph_id not in available_graphs:
        raise HTTPException(404, f"Graph '{assistant.graph_id}' not found for assistant")

    # Mark thread as busy and update metadata with assistant/graph info
    # update_thread_metadata will auto-create thread if it doesn't exist
    await update_thread_metadata(session, thread_id, assistant.assistant_id, assistant.graph_id, user.identity)
    await set_thread_status(session, thread_id, "busy")

    # Persist run record
    now = datetime.now(UTC)
    run_orm = RunORM(
        run_id=run_id,
        thread_id=thread_id,
        assistant_id=resolved_assistant_id,
        status="pending",
        input=request.input or {},
        config=config,
        context=context,
        user_id=user.identity,
        created_at=now,
        updated_at=now,
        output=None,
        error_message=None,
    )
    session.add(run_orm)
    await session.commit()

    # Start execution asynchronously
    task = asyncio.create_task(
        execute_run_async(
            run_id,
            thread_id,
            assistant.graph_id,
            request.input or {},
            user,
            config,
            context,
            request.stream_mode,
            None,  # Don't pass session to avoid conflicts
            request.checkpoint,
            request.command,
            request.interrupt_before,
            request.interrupt_after,
            request.multitask_strategy,
            request.stream_subgraphs,
        )
    )
    logger.info(f"[wait_for_run] background task created task_id={id(task)} for run_id={run_id}")
    active_runs[run_id] = task

    # Wait for task to complete with timeout
    try:
        await asyncio.wait_for(task, timeout=300.0)  # 5 minute timeout
    except TimeoutError:
        logger.warning(f"[wait_for_run] timeout waiting for run_id={run_id}")
        # Don't raise, just return current state
    except asyncio.CancelledError:
        logger.info(f"[wait_for_run] cancelled run_id={run_id}")
        # Task was cancelled, continue to return final state
    except Exception as e:
        logger.error(f"[wait_for_run] exception in run_id={run_id}: {e}")
        # Exception already handled by execute_run_async

    # Get final output from database
    run_orm = await session.scalar(
        select(RunORM).where(
            RunORM.run_id == run_id,
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    if not run_orm:
        raise HTTPException(500, f"Run '{run_id}' disappeared during execution")

    await session.refresh(run_orm)

    # Return output based on final status
    if run_orm.status == "success":
        return run_orm.output or {}
    elif run_orm.status == "error":
        # For error runs, still return output if available, but log the error
        logger.error(f"[wait_for_run] run failed run_id={run_id} error={run_orm.error_message}")
        return run_orm.output or {}
    elif run_orm.status == "interrupted":
        # Return partial output for interrupted runs
        return run_orm.output or {}
    else:
        # Still pending/running after timeout
        return run_orm.output or {}
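
Given the docstring's note about LangGraph SDK compatibility, the plain-HTTP equivalent of runs.wait() is a single blocking POST; the server waits up to 300 seconds internally, so the client timeout below is set slightly higher (all identifiers are placeholders):

import httpx

BASE, thread_id = "http://localhost:8000", "<thread-id>"  # placeholders

# Create a run and block until it finishes; the body is the final output, not a Run object.
resp = httpx.post(
    f"{BASE}/threads/{thread_id}/runs/wait",
    json={"assistant_id": "agent", "input": {"messages": []}},
    timeout=320.0,  # slightly above the server-side 300 s wait
)
resp.raise_for_status()
final_output = resp.json()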


# TODO: check if this method is actually required because the implementation doesn't seem correct.
@router.get("/threads/{thread_id}/runs/{run_id}/stream")
async def stream_run(
    thread_id: str,
    run_id: str,
    last_event_id: str | None = Header(None, alias="Last-Event-ID"),
    _stream_mode: str | None = Query(None),
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> StreamingResponse:
    """Stream run execution with SSE and reconnection support - persisted metadata."""
    logger.info(f"[stream_run] fetch for stream run_id={run_id} thread_id={thread_id} user={user.identity}")
    run_orm = await session.scalar(
        select(RunORM).where(
            RunORM.run_id == str(run_id),
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    if not run_orm:
        raise HTTPException(404, f"Run '{run_id}' not found")

    logger.info(f"[stream_run] status={run_orm.status} user={user.identity} thread_id={thread_id} run_id={run_id}")
    # If already terminal, emit a final end event
    terminal_states = ["success", "error", "interrupted"]
    if run_orm.status in terminal_states:

        async def generate_final() -> AsyncIterator[str]:
            yield create_end_event()

        logger.info(f"[stream_run] starting terminal stream run_id={run_id} status={run_orm.status}")
        return StreamingResponse(
            generate_final(),
            media_type="text/event-stream",
            headers={
                **get_sse_headers(),
                "Location": f"/threads/{thread_id}/runs/{run_id}/stream",
                "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
            },
        )

    # Stream active or pending runs via broker

    # Build a lightweight Pydantic Run from ORM for streaming context (IDs already strings)
    run_model = Run.model_validate(run_orm)

    return StreamingResponse(
        streaming_service.stream_run_execution(run_model, last_event_id, cancel_on_disconnect=False),
        media_type="text/event-stream",
        headers={
            **get_sse_headers(),
            "Location": f"/threads/{thread_id}/runs/{run_id}/stream",
            "Content-Location": f"/threads/{thread_id}/runs/{run_id}",
        },
    )
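
Because event ids generated in execute_run_async below follow the pattern {run_id}_event_{n}, a client that lost its connection can resume this GET stream by echoing the last id it saw in the Last-Event-ID header; a sketch with placeholder ids and URL:

import httpx

BASE, thread_id, run_id = "http://localhost:8000", "<thread-id>", "<run-id>"  # placeholders

# Reconnect to an in-flight run, replaying events after the last one received.
with httpx.stream(
    "GET",
    f"{BASE}/threads/{thread_id}/runs/{run_id}/stream",
    headers={"Last-Event-ID": f"{run_id}_event_17"},  # hypothetical last id from the dropped stream
    timeout=None,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        print(line)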


@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
async def cancel_run_endpoint(
    thread_id: str,
    run_id: str,
    wait: int = Query(0, ge=0, le=1, description="Whether to wait for the run task to settle"),
    action: str = Query("cancel", pattern="^(cancel|interrupt)$", description="Cancellation action"),
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> Run:
    """
    Cancel or interrupt a run (client-compatible endpoint).

    Matches client usage:
        POST /v1/threads/{thread_id}/runs/{run_id}/cancel?wait=0&action=interrupt

    - action=cancel => hard cancel
    - action=interrupt => cooperative interrupt if supported
    - wait=1 => await background task to finish settling
    """
    logger.info(f"[cancel_run] fetch run run_id={run_id} thread_id={thread_id} user={user.identity}")
    run_orm = await session.scalar(
        select(RunORM).where(
            RunORM.run_id == run_id,
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    if not run_orm:
        raise HTTPException(404, f"Run '{run_id}' not found")

    if action == "interrupt":
        logger.info(f"[cancel_run] interrupt run_id={run_id} user={user.identity} thread_id={thread_id}")
        await streaming_service.interrupt_run(run_id)
        # Persist status as interrupted
        await session.execute(
            update(RunORM)
            .where(RunORM.run_id == str(run_id))
            .values(status="interrupted", updated_at=datetime.now(UTC))
        )
        await session.commit()
    else:
        logger.info(f"[cancel_run] cancel run_id={run_id} user={user.identity} thread_id={thread_id}")
        await streaming_service.cancel_run(run_id)
        # Persist status as interrupted
        await session.execute(
            update(RunORM)
            .where(RunORM.run_id == str(run_id))
            .values(status="interrupted", updated_at=datetime.now(UTC))
        )
        await session.commit()

    # Optionally wait for background task
    if wait:
        task = active_runs.get(run_id)
        if task:
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await task

    # Reload and return updated Run (do NOT delete here; deletion is a separate endpoint)
    run_orm = await session.scalar(
        select(RunORM).where(
            RunORM.run_id == run_id,
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    if not run_orm:
        raise HTTPException(404, f"Run '{run_id}' not found after cancellation")
    return Run.model_validate(run_orm)
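
The docstring above already shows the client call shape; the same requests via httpx look like the sketch below. The /v1 prefix in that example comes from how the client mounts the API and may not apply to a bare deployment, so it is left out here, and the ids are placeholders:

import httpx

BASE, thread_id, run_id = "http://localhost:8000", "<thread-id>", "<run-id>"  # placeholders

# Cooperative interrupt: ask the run to stop at the next opportunity.
httpx.post(
    f"{BASE}/threads/{thread_id}/runs/{run_id}/cancel",
    params={"action": "interrupt", "wait": 0},
).raise_for_status()

# Hard cancel, blocking until the background task has settled.
httpx.post(
    f"{BASE}/threads/{thread_id}/runs/{run_id}/cancel",
    params={"action": "cancel", "wait": 1},
).raise_for_status()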


async def execute_run_async(
    run_id: str,
    thread_id: str,
    graph_id: str,
    input_data: dict,
    user: User,
    config: dict | None = None,
    context: dict | None = None,
    stream_mode: list[str] | None = None,
    session: AsyncSession | None = None,
    checkpoint: dict | None = None,
    command: dict[str, Any] | None = None,
    interrupt_before: str | list[str] | None = None,
    interrupt_after: str | list[str] | None = None,
    _multitask_strategy: str | None = None,
    subgraphs: bool | None = False,
) -> None:
    """Execute run asynchronously in background using streaming to capture all events"""
    # Use provided session or get a new one
    if session is None:
        maker = _get_session_maker()
        session = maker()

    try:
        # Update status
        await update_run_status(run_id, "running", session=session)

        # Get graph and execute
        langgraph_service = get_langgraph_service()

        run_config = create_run_config(run_id, thread_id, user, config or {}, checkpoint)

        # Handle human-in-the-loop fields
        if interrupt_before is not None:
            run_config["interrupt_before"] = (
                interrupt_before if isinstance(interrupt_before, list) else [interrupt_before]
            )
        if interrupt_after is not None:
            run_config["interrupt_after"] = interrupt_after if isinstance(interrupt_after, list) else [interrupt_after]

        # Note: multitask_strategy is handled at the run creation level, not execution level
        # It controls concurrent run behavior, not graph execution behavior

        # Determine input for execution (either input_data or command)
        if command is not None:
            # When command is provided, it replaces input entirely
            execution_input = map_command_to_langgraph(command)
        else:
            # No command, use regular input
            execution_input = input_data

        # Execute using streaming to capture events for later replay
        event_counter = 0
        final_output = None
        has_interrupt = False

        # Prepare stream modes for execution
        if stream_mode is None:
            stream_mode_list = DEFAULT_STREAM_MODES.copy()
        elif isinstance(stream_mode, str):
            stream_mode_list = [stream_mode]
        else:
            stream_mode_list = stream_mode.copy()

        async with (
            langgraph_service.get_graph(graph_id) as graph,
            with_auth_ctx(user, []),
        ):
            # Stream events using the graph_streaming service
            try:
                async for event_type, event_data in stream_graph_events(
                    graph=graph,
                    input_data=execution_input,
                    config=run_config,
                    stream_mode=stream_mode_list,
                    context=context,
                    subgraphs=subgraphs,
                    on_checkpoint=lambda _: None,  # Can add checkpoint handling if needed
                    on_task_result=lambda _: None,  # Can add task result handling if needed
                ):
                    try:
                        # Increment event counter
                        event_counter += 1
                        event_id = f"{run_id}_event_{event_counter}"

                        # Create event tuple for broker/storage
                        event_tuple = (event_type, event_data)

                        # Forward to broker for live consumers (already filtered by graph_streaming)
                        await streaming_service.put_to_broker(run_id, event_id, event_tuple)

                        # Store for replay (already filtered by graph_streaming)
                        await streaming_service.store_event_from_raw(run_id, event_id, event_tuple)

                        # Check for interrupt
                        if isinstance(event_data, dict) and "__interrupt__" in event_data:
                            has_interrupt = True

                        # Track final output from values events (handles both "values" and "values|namespace")
                        if event_type.startswith("values"):
                            final_output = event_data

                    except Exception as event_error:
                        # Error processing individual event - send error to frontend immediately
                        logger.error(f"[execute_run_async] error processing event for run_id={run_id}: {event_error}")
                        error_type = type(event_error).__name__
                        await streaming_service.signal_run_error(run_id, str(event_error), error_type)
                        raise

            except Exception as stream_error:
                # Error during streaming (e.g., graph execution error)
                # Send error to frontend before re-raising
                logger.error(f"[execute_run_async] streaming error for run_id={run_id}: {stream_error}")
                error_type = type(stream_error).__name__
                await streaming_service.signal_run_error(run_id, str(stream_error), error_type)
                raise

        if has_interrupt:
            await update_run_status(run_id, "interrupted", output=final_output or {}, session=session)
            if not session:
                raise RuntimeError(f"No database session available to update thread {thread_id} status")
            await set_thread_status(session, thread_id, "interrupted")

        else:
            # Update with results - use standard status
            await update_run_status(run_id, "success", output=final_output or {}, session=session)
            # Mark thread back to idle
            if not session:
                raise RuntimeError(f"No database session available to update thread {thread_id} status")
            await set_thread_status(session, thread_id, "idle")

    except asyncio.CancelledError:
        # Store empty output to avoid JSON serialization issues - use standard status
        await update_run_status(run_id, "interrupted", output={}, session=session)
        if not session:
            raise RuntimeError(f"No database session available to update thread {thread_id} status") from None
        await set_thread_status(session, thread_id, "idle")
        # Signal cancellation to broker
        await streaming_service.signal_run_cancelled(run_id)
        raise
    except Exception as e:
        # Store empty output to avoid JSON serialization issues - use standard status
        await update_run_status(run_id, "error", output={}, error=str(e), session=session)
        if not session:
            raise RuntimeError(f"No database session available to update thread {thread_id} status") from None
        # Set thread status to "error" when run fails (matches API specification)
        await set_thread_status(session, thread_id, "error")
        # Note: Error event already sent to broker in inner exception handler
        # Only signal if broker still exists (cleanup not yet called)
        broker = broker_manager.get_broker(run_id)
        if broker and not broker.is_finished():
            error_type = type(e).__name__
            await streaming_service.signal_run_error(run_id, str(e), error_type)
        raise
    finally:
        # Clean up broker
        await streaming_service.cleanup_run(run_id)
        active_runs.pop(run_id, None)


async def update_run_status(
    run_id: str,
    status: str,
    output: Any = None,
    error: str | None = None,
    session: AsyncSession | None = None,
) -> None:
    """Update run status in database (persisted). If session not provided, opens a short-lived session.

    Status is validated to ensure it conforms to API specification.
    """
    # Validate status conforms to API specification
    validated_status = validate_run_status(status)

    owns_session = False
    if session is None:
        maker = _get_session_maker()
        session = maker()  # type: ignore[assignment]
        owns_session = True
    try:
        values = {"status": validated_status, "updated_at": datetime.now(UTC)}
        if output is not None:
            # Serialize output to ensure JSON compatibility
            try:
                serialized_output = serializer.serialize(output)
                values["output"] = serialized_output
            except Exception as e:
                logger.warning(f"Failed to serialize output for run {run_id}: {e}")
                values["output"] = {
                    "error": "Output serialization failed",
                    "original_type": str(type(output)),
                }
        if error is not None:
            values["error_message"] = error
        logger.info(f"[update_run_status] updating DB run_id={run_id} status={validated_status}")
        await session.execute(update(RunORM).where(RunORM.run_id == str(run_id)).values(**values))  # type: ignore[arg-type]
        await session.commit()
        logger.info(f"[update_run_status] commit done run_id={run_id}")
    finally:
        # Close only if we created it here
        if owns_session:
            await session.close()  # type: ignore[func-returns-value]


@router.delete("/threads/{thread_id}/runs/{run_id}", status_code=204)
async def delete_run(
    thread_id: str,
    run_id: str,
    force: int = Query(0, ge=0, le=1, description="Force cancel active run before delete (1=yes)"),
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
) -> None:
    """Delete a run record.

    Behavior:
    - If the run is active (pending/running/streaming) and force=0, returns 409 Conflict.
    - If force=1 and the run is active, cancels it first (best-effort) and then deletes.
    - Always returns 204 No Content on successful deletion.
    """
    # Authorization check (delete action on runs resource)
    ctx = build_auth_context(user, "runs", "delete")
    value = {"run_id": run_id, "thread_id": thread_id}
    await handle_event(ctx, value)

    logger.info(f"[delete_run] fetch run run_id={run_id} thread_id={thread_id} user={user.identity}")
    run_orm = await session.scalar(
        select(RunORM).where(
            RunORM.run_id == str(run_id),
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    if not run_orm:
        raise HTTPException(404, f"Run '{run_id}' not found")

    # If active and not forcing, reject deletion
    if run_orm.status in ["pending", "running"] and not force:
        raise HTTPException(
            status_code=409,
            detail="Run is active. Retry with force=1 to cancel and delete.",
        )

    # If forcing and active, cancel first
    if force and run_orm.status in ["pending", "running"]:
        logger.info(f"[delete_run] force-cancelling active run run_id={run_id}")
        await streaming_service.cancel_run(run_id)
        # Best-effort: wait for bg task to settle
        task = active_runs.get(run_id)
        if task:
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await task

    # Delete the record
    await session.execute(
        delete(RunORM).where(
            RunORM.run_id == str(run_id),
            RunORM.thread_id == thread_id,
            RunORM.user_id == user.identity,
        )
    )
    await session.commit()

    # Clean up active task if exists
    task = active_runs.pop(run_id, None)
    if task and not task.done():
        task.cancel()

    # 204 No Content
    return
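
Finally, a deletion sketch that matches the behavior described in the docstring above: deleting an active run without force returns 409, while force=1 cancels it first and the endpoint responds 204 No Content on success (URL and ids are placeholders):

import httpx

BASE, thread_id, run_id = "http://localhost:8000", "<thread-id>", "<run-id>"  # placeholders

# Force-cancel (if still active) and delete the run record.
resp = httpx.delete(f"{BASE}/threads/{thread_id}/runs/{run_id}", params={"force": 1})
assert resp.status_code == 204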