agent-api-server 2.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_api_server/__init__.py +0 -0
- agent_api_server/api/__init__.py +0 -0
- agent_api_server/api/v1/__init__.py +0 -0
- agent_api_server/api/v1/api.py +25 -0
- agent_api_server/api/v1/config.py +57 -0
- agent_api_server/api/v1/graph.py +59 -0
- agent_api_server/api/v1/schema.py +57 -0
- agent_api_server/api/v1/thread.py +563 -0
- agent_api_server/cache/__init__.py +0 -0
- agent_api_server/cache/redis_cache.py +385 -0
- agent_api_server/callback_handler.py +18 -0
- agent_api_server/client/css/styles.css +1202 -0
- agent_api_server/client/favicon.ico +0 -0
- agent_api_server/client/index.html +102 -0
- agent_api_server/client/js/app.js +1499 -0
- agent_api_server/client/js/index.umd.js +824 -0
- agent_api_server/config_center/config_center.py +239 -0
- agent_api_server/configs/__init__.py +3 -0
- agent_api_server/configs/config.py +163 -0
- agent_api_server/dynamic_llm/__init__.py +0 -0
- agent_api_server/dynamic_llm/dynamic_llm.py +331 -0
- agent_api_server/listener.py +530 -0
- agent_api_server/log/__init__.py +0 -0
- agent_api_server/log/formatters.py +122 -0
- agent_api_server/log/logging.json +50 -0
- agent_api_server/mcp_convert/__init__.py +0 -0
- agent_api_server/mcp_convert/mcp_convert.py +375 -0
- agent_api_server/memeory/__init__.py +0 -0
- agent_api_server/memeory/postgres.py +233 -0
- agent_api_server/register/__init__.py +0 -0
- agent_api_server/register/register.py +65 -0
- agent_api_server/service.py +354 -0
- agent_api_server/service_hub/service_hub.py +233 -0
- agent_api_server/service_hub/service_hub_test.py +700 -0
- agent_api_server/shared/__init__.py +0 -0
- agent_api_server/shared/ase.py +54 -0
- agent_api_server/shared/base_model.py +103 -0
- agent_api_server/shared/common.py +110 -0
- agent_api_server/shared/decode_token.py +107 -0
- agent_api_server/shared/detect_message.py +410 -0
- agent_api_server/shared/get_model_info.py +491 -0
- agent_api_server/shared/message.py +419 -0
- agent_api_server/shared/util_func.py +372 -0
- agent_api_server/sso_service/__init__.py +1 -0
- agent_api_server/sso_service/sdk/__init__.py +1 -0
- agent_api_server/sso_service/sdk/client.py +224 -0
- agent_api_server/sso_service/sdk/credential.py +11 -0
- agent_api_server/sso_service/sdk/encoding.py +22 -0
- agent_api_server/sso_service/sso_service.py +177 -0
- agent_api_server-2.1.7.dist-info/METADATA +130 -0
- agent_api_server-2.1.7.dist-info/RECORD +52 -0
- agent_api_server-2.1.7.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,563 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import traceback
|
|
3
|
+
import asyncio
|
|
4
|
+
from fastapi import Request
|
|
5
|
+
from http.cookies import SimpleCookie
|
|
6
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
|
7
|
+
from typing import Dict, Any, Optional, AsyncGenerator, Tuple, List
|
|
8
|
+
from fastapi.responses import StreamingResponse
|
|
9
|
+
from fastapi import APIRouter, Body, status, HTTPException
|
|
10
|
+
from agent_api_server.cache.redis_cache import ThreadState, AsyncRedisThreadStorage
|
|
11
|
+
from langgraph.types import Command
|
|
12
|
+
from agent_api_server.shared.base_model import (
|
|
13
|
+
ThreadInfo,
|
|
14
|
+
RunResponse,
|
|
15
|
+
error_response,
|
|
16
|
+
sse_response_example
|
|
17
|
+
)
|
|
18
|
+
from langchain_core.messages import AIMessage
|
|
19
|
+
from agent_api_server.shared.util_func import load_graph_config, load_graph, get_env
|
|
20
|
+
from agent_api_server.shared.message import (
|
|
21
|
+
message_generator,
|
|
22
|
+
langchain_to_chat_message
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Module-wide logger and the router every endpoint in this module registers on.
logger: logging.Logger = logging.getLogger(__name__)
api_router: APIRouter = APIRouter()

# Shared user-facing error messages.
THREAD_NOT_FOUND: str = "Specified thread does not exist"
CHECKPOINT_NOT_FOUND: str = "Checkpoint not found for this graph"
|
|
30
|
+
|
|
31
|
+
class TaskCancelledError(Exception):
    """Signals that a thread's run was stopped via its stop flag.

    Raised by the run/stream helpers in this module when storage reports a
    pending stop request, and caught by the endpoints to report cancellation.
    """
|
|
34
|
+
|
|
35
|
+
class ThreadCancellationManager:
    """In-process registry of per-thread cancellation events.

    Each thread id maps to an ``asyncio.Event``. ``cancel`` sets the event so
    any coroutine holding it (obtained from ``get_event``) can observe the
    cancellation. All bookkeeping is serialised by an ``asyncio.Lock``.
    """

    def __init__(self):
        self._events: Dict[str, asyncio.Event] = {}
        # Ids cancelled since their last event was handed out. ``cancel`` drops
        # the event from ``_events`` (so the next run gets a fresh one); without
        # this record ``is_cancelled`` would report False right after a cancel.
        self._cancelled = set()
        self._lock = asyncio.Lock()

    async def get_event(self, thread_id: str) -> asyncio.Event:
        """Return the thread's event, creating a fresh unset one if needed.

        Handing out a fresh event starts a new run, so any earlier
        cancellation record for the thread is cleared.
        """
        async with self._lock:
            event = self._events.get(thread_id)
            if event is None:
                event = asyncio.Event()
                self._events[thread_id] = event
                self._cancelled.discard(thread_id)
            return event

    async def cancel(self, thread_id: str):
        """Signal cancellation for ``thread_id`` if it has a registered event.

        The event is set (waking its holders) and then forgotten, so a later
        ``get_event`` starts clean. Ids without a registered event are ignored.
        """
        async with self._lock:
            event = self._events.pop(thread_id, None)
            if event is not None:
                event.set()
                self._cancelled.add(thread_id)

    async def is_cancelled(self, thread_id: str) -> bool:
        """Return True if the thread's current run has been cancelled."""
        async with self._lock:
            event = self._events.get(thread_id)
            if event is not None:
                return event.is_set()
            return thread_id in self._cancelled

    async def discard(self, thread_id: str) -> None:
        """Forget every record (event and cancellation) kept for ``thread_id``."""
        async with self._lock:
            self._events.pop(thread_id, None)
            self._cancelled.discard(thread_id)


cancellation_manager = ThreadCancellationManager()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
async def _get_thread_or_404(thread_id: str, graph_name: Optional[str] = '') -> ThreadState:
    """Fetch a thread's stored state, optionally creating it when missing.

    Args:
        thread_id: Id of the thread to look up.
        graph_name: If non-empty and the thread does not exist, a thread with
            this exact id is created for ``graph_name`` instead of failing.

    Returns:
        The existing or newly created ``ThreadState``.

    Raises:
        HTTPException: 404 when the thread is missing and no ``graph_name``
            was given; 500 when creating the thread fails.
    """
    storage = AsyncRedisThreadStorage.get_worker_instance()
    existing = await storage.get_thread(thread_id)
    if existing:
        return existing

    if graph_name:
        logger.warning(f"Thread {thread_id} not found, creating with graph: {graph_name}")
        try:
            created = await storage.create_thread_with_id(thread_id=thread_id, graph_name=graph_name)
            logger.info(f"Created thread {thread_id} for graph {graph_name}")
            return created
        except Exception as exc:
            logger.error(f"Failed to create thread {thread_id}: {str(exc)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={
                    "error": "creation_failed",
                    "message": f"Failed to create thread: {str(exc)}",
                    "graph_name": graph_name,
                    "thread_id": thread_id,
                },
            )

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={
            "error": "not_found",
            "message": THREAD_NOT_FOUND,
            "thread_id": thread_id,
        },
    )
|
|
94
|
+
|
|
95
|
+
def _get_run_config(
    thread_id: str,
    ts_tenant: Optional[str],
    ei_token: Optional[str],
    graph_name: str,
    files: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Build the LangGraph run config for a single invocation.

    The ``configurable`` section starts as a copy of ``get_env(ts_tenant=...)``.
    The request-specific values below are then written over it.

    Args:
        thread_id: Thread the run belongs to.
        ts_tenant: Value of the ``TSTenant`` cookie, or None when absent.
        ei_token: Value of the ``EIToken`` cookie, or None when absent.
        graph_name: Name of the graph being run.
        files: Files attached to the request; stored as ``[]`` when None/empty.

    Returns:
        ``{"configurable": {...}}`` suitable for ``config=`` on graph calls.
    """
    # Copy so request-specific keys never mutate whatever get_env returned.
    configurable: Dict[str, Any] = dict(get_env(ts_tenant=ts_tenant))
    configurable.update(
        graph_name=graph_name,
        thread_id=thread_id,
        TSTenant=ts_tenant,
        EIToken=ei_token,
        files=files or [],
    )
    return {"configurable": configurable}
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
async def _check_stop_flag(thread_id: str, storage: AsyncRedisThreadStorage):
    """Abort the current run if a stop was requested for ``thread_id``.

    When ``storage.should_stop`` reports a pending stop, the thread's status
    is set to ``"cancelled"`` before raising.

    Raises:
        TaskCancelledError: a stop was requested for the thread.
    """
    stop_requested = await storage.should_stop(thread_id)
    if not stop_requested:
        return
    await storage.update_thread(thread_id, status="cancelled")
    raise TaskCancelledError(f"Thread {thread_id} was cancelled")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
async def _monitored_stream(
    thread_id: str,
    generator: AsyncGenerator,
    storage: AsyncRedisThreadStorage
) -> AsyncGenerator:
    """Relay ``generator``'s chunks, ending early when a stop is requested.

    Before each chunk is forwarded, the thread's stop flag is checked. On a
    user-requested stop, a ``data: [CANCELED]`` event is emitted and the
    stream ends normally. On any other error, the same marker is emitted and
    the error is re-raised. In all cases the wrapped generator is closed so
    its own cleanup runs even if it was abandoned mid-stream (stop, error or
    client disconnect).
    """
    try:
        async for item in generator:
            # Short pause between chunks (originally: "avoid a tight loop").
            await asyncio.sleep(0.05)
            await _check_stop_flag(thread_id, storage)
            yield item
    except TaskCancelledError:
        logger.info(f"Stream for thread {thread_id} was cancelled")
        await storage.clear_stop_flag(thread_id)
        # A stop is an expected outcome: finish the response cleanly instead of
        # re-raising into the ASGI server, which would log it as an unhandled
        # error and abort the HTTP response without a proper end of stream.
        yield "data: [CANCELED]\n\n"
    except Exception as e:
        logger.error(f"Stream error for thread {thread_id}: {str(e)}", exc_info=True)
        await storage.clear_stop_flag(thread_id)
        yield "data: [CANCELED]\n\n"
        raise
    finally:
        # Explicitly close the inner generator rather than leaving it suspended
        # for the garbage collector; closing an exhausted generator is a no-op.
        aclose = getattr(generator, "aclose", None)
        if aclose is not None:
            try:
                await aclose()
            except Exception:
                logger.warning(
                    f"Failed to close stream generator for thread {thread_id}",
                    exc_info=True,
                )
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@api_router.get(
    "/{thread_id}/status",
    response_model=ThreadInfo,
    summary="Get thread status",
    responses={
        404: {"model": error_response, "description": "Thread not found"},
        500: {"model": error_response, "description": "Internal server error"},
    },
)
async def get_thread_status(thread_id: str) -> ThreadInfo:
    """Get the current status of a processing thread.
    \f
    Returns the thread's id, graph name and status. A missing thread surfaces
    as the 404 raised by the lookup helper. Any other failure is logged and
    reported as a 500 ``status_check_failed`` error.
    """
    try:
        thread = await _get_thread_or_404(thread_id)
        info = ThreadInfo(
            thread_id=thread.thread_id,
            graph_name=thread.graph_name,
            status=thread.status,
        )
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"Failed to get thread status: {str(err)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "error": "status_check_failed",
                "message": "Failed to retrieve thread status",
                "thread_id": thread_id,
            },
        )
    return info
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@api_router.post(
    "/",
    response_model=ThreadInfo,
    status_code=status.HTTP_201_CREATED,
    summary="Create new processing thread",
    responses={
        422: {"model": error_response, "description": "Invalid configuration"},
        500: {"model": error_response, "description": "Internal server error"}
    }
)
async def create_thread(graph_name: str, thread_id: Optional[str] = None) -> ThreadInfo:
    """Create a new processing thread with the specified graph.
    \f
    If ``thread_id`` is given and already exists, that thread is returned
    unchanged. If it is given but does not exist, the new thread is created
    under that exact id. Without ``thread_id``, storage assigns the id.

    Raises:
        HTTPException: 422 when storage rejects the configuration
            (``ValueError``); 500 for any other creation failure.
    """
    if thread_id:
        try:
            existing = await _get_thread_or_404(thread_id)
        except HTTPException:
            existing = None  # Not found: fall through and create it.
        if existing is not None:
            logger.info(f"{thread_id} already exists, returning existing thread")
            return ThreadInfo(
                thread_id=existing.thread_id,
                graph_name=existing.graph_name,
                status=existing.status,
            )

    validated_name, _, _ = await load_graph(graph_name, await load_graph_config(), False)
    try:
        storage = AsyncRedisThreadStorage.get_worker_instance()
        if thread_id:
            # Honour the caller's id; plain create_thread() would assign a
            # random one and the requested id would silently be dropped.
            thread_state = await storage.create_thread_with_id(
                thread_id=thread_id, graph_name=validated_name
            )
        else:
            thread_state = await storage.create_thread(graph_name=validated_name)

        return ThreadInfo(
            thread_id=thread_state.thread_id,
            graph_name=thread_state.graph_name,
            status=thread_state.status,
        )

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail={
                "error": "invalid_config",
                "message": str(e),
                "graph_name": graph_name
            }
        )
    except Exception:
        logger.critical(f"Thread creation failed: {traceback.format_exc()}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "error": "creation_failed",
                "message": "Internal server error",
                "graph_name": graph_name
            }
        )
|
|
222
|
+
|
|
223
|
+
@api_router.post(
    "/{thread_id}/stream",
    response_class=StreamingResponse,
    responses=sse_response_example(),
    summary="Stream messages for a thread",
)
async def stream(
    thread_id: str,
    request: Request,
    inputs: Dict[str, Any] = Body(..., embed=True),
    files: List[Dict[str, Any]] = Body(default=[], embed=True),
) -> StreamingResponse:
    """Stream messages for a processing thread.
    \f
    The ``X-Agent-Name`` header, when present, selects the graph. It also
    lets a missing thread be created on the fly. ``TSTenant``/``EIToken``
    cookies are forwarded to the message generator. The output is served as
    ``text/event-stream`` and honours the thread's stop flag.

    Raises:
        HTTPException: 400 if the thread is already running, plus whatever
            the thread lookup raises (404/500).
    """
    headers = request.headers
    agent_name = headers.get('X-Agent-Name', '')

    jar = SimpleCookie()
    jar.load(headers.get("cookie", ""))
    ts_tenant = jar["TSTenant"].value if "TSTenant" in jar else None
    ei_token = jar["EIToken"].value if "EIToken" in jar else None

    state = await _get_thread_or_404(thread_id, agent_name)
    if state.status == "running":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error": "already_running",
                "message": "Thread is already running",
                "thread_id": thread_id,
            },
        )

    if agent_name:
        logger.info(f"get graph name {agent_name} from header is not None, so use this graph name to load and chat")
        state.graph_name = agent_name

    storage = AsyncRedisThreadStorage.get_worker_instance()
    source = message_generator(state=state, ts_tenant=ts_tenant, ei_token=ei_token, inputs=inputs, files=files)
    return StreamingResponse(
        _monitored_stream(thread_id, source, storage),
        media_type="text/event-stream",
    )
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@api_router.post(
    "/{thread_id}/run",
    response_model=RunResponse,
    summary="Execute processing thread",
    responses={
        404: {"model": error_response, "description": "Thread not found"},
        408: {"model": error_response, "description": "Request timeout"},
        500: {"model": error_response, "description": "Execution failed"}
    },
    deprecated=True
)
async def run_thread(
    thread_id: str,
    request: Request,
    inputs: Dict[str, Any] = Body(..., embed=True),
    files: List[Dict[str, Any]] = Body(default=[], embed=True)
) -> RunResponse:
    """
    Execute processing thread with enhanced error handling and timeout control.
    \f
    Runs the thread's graph to completion (or to its next interrupt) and
    returns the final message. Loading the graph is limited to 10 s and
    executing it to 300 s. A stop requested via ``/stop`` aborts the run.

    Every failure handled here that occurs after the thread was switched to
    ``"running"`` switches it back to ``"completed"``, so the thread is not
    left permanently marked as running.
    """
    load_timeout_s = 10
    execute_timeout_s = 300

    get_graph_name = request.headers.get('X-Agent-Name', '')
    cookie = SimpleCookie()
    cookie.load(request.headers.get("cookie", ""))
    ts_tenant = cookie.get("TSTenant").value if "TSTenant" in cookie else None
    ei_token = cookie.get("EIToken").value if "EIToken" in cookie else None

    storage = AsyncRedisThreadStorage.get_worker_instance()
    # Timeout of the phase in progress, reported if a TimeoutError escapes.
    # Tracked explicitly rather than parsed from the exception text, which
    # can be re-labelled while unwinding nested timeout contexts.
    active_timeout_s = load_timeout_s
    # True once this request has switched the thread to "running".
    marked_running = False
    try:
        state = await _get_thread_or_404(thread_id, get_graph_name)
        if state.status == "running":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "already_running",
                    "message": "Thread is already running",
                    "thread_id": thread_id
                }
            )

        if get_graph_name != '':
            logger.info(
                f"get graph name {get_graph_name} from header is not None, so use this graph name to load and chat")
            state.graph_name = get_graph_name

        async with AsyncExitStack() as stack:
            _, graph_instance, _ = await stack.enter_async_context(
                timeout_context(float(load_timeout_s), load_graph(state.graph_name, await load_graph_config(), True))
            )

            try:
                await storage.clear_stop_flag(thread_id)
                status_updated = await storage.update_thread(thread_id, status="running")
            except Exception as e:
                logger.error(f"Redis operation failed: {str(e)}")
                raise HTTPException(status_code=500, detail="Storage operation failed")
            # Checked outside the try so this specific error is not masked by
            # the generic "Storage operation failed" handler above.
            if not status_updated:
                raise HTTPException(status_code=500, detail="Failed to update thread status")
            marked_running = True

            active_timeout_s = execute_timeout_s
            _, response_events = await stack.enter_async_context(
                timeout_context(float(execute_timeout_s), _execute_graph_with_redis_check(
                    graph_instance=graph_instance,
                    run_cfg=_get_run_config(thread_id=thread_id, ei_token=ei_token, ts_tenant=ts_tenant, graph_name=state.graph_name, files=files),
                    inputs=inputs,
                    thread_id=thread_id,
                    storage=storage
                ))
            )

            # Last (mode, state) pair emitted by the "values" stream mode.
            _, response = response_events[-1]
            if "__interrupt__" in response:
                # The graph paused: surface the interrupt payload as the reply.
                output = langchain_to_chat_message(
                    AIMessage(content=response["__interrupt__"][0].value)
                )
            else:
                output = langchain_to_chat_message(response["messages"][-1])

            await storage.update_thread(thread_id, status="completed")
            return RunResponse(
                status="completed",
                thread_id=thread_id,
                result=output
            )

    except TaskCancelledError:
        logger.info(f"Thread {thread_id} was cancelled")
        await storage.clear_stop_flag(thread_id)
        await storage.update_thread(thread_id, status="completed")
        raise HTTPException(
            status_code=status.HTTP_200_OK,
            detail={
                "error": "cancelled",
                "thread_id": thread_id,
                "message": "Execution was cancelled by user"
            }
        )

    except asyncio.TimeoutError as te:
        logger.warning(f"Thread {thread_id} timed out: {te}")
        await storage.update_thread(thread_id, status="completed")
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT,
            detail={
                "error": "timeout",
                "thread_id": thread_id,
                "message": f"Operation timed out after {active_timeout_s} seconds"
            }
        )
    except HTTPException:
        if marked_running:
            # Don't leave the thread stuck in "running" (it could never be
            # run or streamed again).
            await storage.update_thread(thread_id, status="completed")
        raise
    except Exception as e:
        logger.error(f"Execution failed: {thread_id}", exc_info=True)
        await storage.update_thread(thread_id, status="completed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "error": "execution_failed",
                "thread_id": thread_id,
                "message": str(e)[:200]
            }
        )
|
|
396
|
+
|
|
397
|
+
async def _execute_graph_with_redis_check(
    graph_instance,
    run_cfg: Dict[str, Any],
    inputs: Dict[str, Any],
    thread_id: str,
    storage: AsyncRedisThreadStorage
) -> Tuple[str, Any]:
    """Run the graph once, aborting early if the thread's stop flag is set.

    If the graph's current state has a pending interrupt, the run resumes it
    with ``inputs["resume"]`` (default ``""``). Otherwise ``inputs`` (or an
    empty dict) is the input. The graph runs in its own task while this
    coroutine polls ``storage.should_stop`` every 0.1 s.

    Returns:
        ``("completed", result)``, where ``result`` is the return value of
        ``graph_instance.ainvoke(..., stream_mode=["values"])``.

    Raises:
        TaskCancelledError: a stop was requested while the graph was running.
        Exception: errors from reading state or running the graph are logged
            and re-raised.

    On every early exit the graph task is cancelled and awaited. This includes
    the case where this coroutine is itself cancelled, e.g. by
    ``asyncio.wait_for`` timing out, so no graph run is left executing in the
    background.
    """
    execute_task: Optional[asyncio.Task] = None
    try:
        current_state = await graph_instance.aget_state(config=run_cfg)
        # Resume an interrupted run instead of starting a new one.
        has_pending_interrupt = any(
            getattr(task, "interrupts", None) for task in current_state.tasks
        )
        invoke_input = (
            Command(resume=inputs.get("resume", ""))
            if has_pending_interrupt
            else inputs or {}
        )

        execute_task = asyncio.create_task(
            graph_instance.ainvoke(
                invoke_input,
                config=run_cfg,
                stream_mode=["values"]
            )
        )

        while not execute_task.done():
            if await storage.should_stop(thread_id):
                # The finally clause cancels and reaps the graph task.
                raise TaskCancelledError(f"Thread {thread_id} was cancelled")
            await asyncio.sleep(0.1)

        return "completed", execute_task.result()

    except TaskCancelledError:
        raise
    except Exception as e:
        logger.error(f"Execution failed: {str(e)}")
        raise
    finally:
        # A ``finally`` (not ``except Exception``) is required: on timeout,
        # wait_for cancels this coroutine with asyncio.CancelledError, which is
        # not an Exception subclass, and the graph task would be orphaned.
        if execute_task is not None and not execute_task.done():
            execute_task.cancel()
            try:
                await execute_task
            except (asyncio.CancelledError, Exception):
                # The abandoned task's outcome is irrelevant here.
                pass
|
|
444
|
+
|
|
445
|
+
@asynccontextmanager
async def timeout_context(timeout: float, coro):
    """Await ``coro`` under a time limit and expose its result as the context value.

    Usage: ``async with timeout_context(10.0, some_coro()) as result: ...``

    If ``coro`` does not finish within ``timeout`` seconds, ``asyncio.wait_for``
    cancels it and an ``asyncio.TimeoutError`` naming the coroutine is raised.

    Only the awaiting of ``coro`` is guarded. The ``yield`` is deliberately
    outside the ``try``, so exceptions raised inside the ``async with`` body
    propagate unchanged. This includes a ``TimeoutError`` from a later, nested
    timeout, which must not be re-labelled with this context's name and
    timeout.
    """
    try:
        result = await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError as te:
        # Futures/other awaitables have no __name__; fall back to the type name.
        name = getattr(coro, "__name__", type(coro).__name__)
        raise asyncio.TimeoutError(f"{name} timeout after {timeout}s") from te
    yield result
|
|
453
|
+
|
|
454
|
+
@api_router.post(
    "/{thread_id}/stop",
    response_model=Dict[str, Any],
    summary="Stop a processing thread",
    responses={
        404: {"model": error_response, "description": "Thread not found"},
        500: {"model": error_response, "description": "Execution failed"},
    },
)
async def stop_thread(thread_id: str) -> Dict[str, Any]:
    """Stop a processing thread.
    \f
    Only a thread whose status is ``"running"`` can be stopped. This sets the
    thread's stop flag in storage and returns immediately. The run/stream
    helpers in this module poll that flag and wind the execution down.

    Raises:
        HTTPException: 400 if the thread is not running; 500 if the stop flag
            could not be set; plus the lookup's 404.
    """
    thread = await _get_thread_or_404(thread_id)
    current_status = thread.status
    if current_status != "running":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error": "invalid_status",
                "message": f"Thread must be in 'running' state to stop. Current status: {current_status}",
                "thread_id": thread_id,
            },
        )

    # NOTE(review): atomicity of set_stop_flag depends on the storage
    # implementation, which is not visible here.
    flag_set = await AsyncRedisThreadStorage.get_worker_instance().set_stop_flag(thread_id)
    if not flag_set:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "error": "stop_failed",
                "message": "Failed to set stop flag",
                "thread_id": thread_id,
            },
        )

    summary = {
        "thread_id": thread_id,
        "graph_name": thread.graph_name,
        "status": "stop",
    }
    return {"status": "success", "detail": summary}
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
@api_router.get(
    "/",
    response_model=Dict[str, ThreadInfo],
    summary="List active threads",
    responses={
        500: {"model": error_response, "description": "Internal server error"},
    },
)
async def list_threads() -> Dict[str, ThreadInfo]:
    """List all active processing threads.
    \f
    Returns a ``ThreadInfo`` for every entry reported by the storage's
    ``list_threads()``, keyed by the same keys. Any failure is logged and
    reported as a 500 ``list_failed`` error.
    """
    try:
        stored = await AsyncRedisThreadStorage.get_worker_instance().list_threads()
        summaries: Dict[str, ThreadInfo] = {}
        for key, thread in stored.items():
            summaries[key] = ThreadInfo(
                thread_id=thread.thread_id,
                graph_name=thread.graph_name,
                status=thread.status,
            )
    except Exception as err:
        logger.error(f"Failed to list threads: {str(err)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "error": "list_failed",
                "message": "Failed to retrieve thread list",
            },
        )
    return summaries
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
@api_router.delete(
    "/{thread_id}",
    response_model=Dict[str, Any],
    summary="Terminate processing thread",
    responses={
        404: {"model": error_response, "description": "Thread not found"},
        500: {"model": error_response, "description": "Deletion failed"},
    },
)
async def delete_thread(thread_id: str) -> Dict[str, Any]:
    """Delete a processing thread.
    \f
    Removes the thread from storage after confirming it exists.

    Raises:
        HTTPException: 404 from the lookup; 500 if storage reports the
            deletion failed.
    """
    thread = await _get_thread_or_404(thread_id)
    # NOTE(review): no stop flag is set first, so a run in progress is not
    # signalled to stop — confirm this is intended.
    deleted = await AsyncRedisThreadStorage.get_worker_instance().delete_thread(thread_id)
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "error": "deletion_failed",
                "message": "Thread deletion failed",
                "thread_id": thread_id,
            },
        )

    logger.info(f"Thread terminated: {thread_id}")
    summary = {
        "thread_id": thread_id,
        "graph_name": thread.graph_name,
        "status": "terminated",
    }
    return {"status": "success", "detail": summary}
|
|
File without changes
|