agent-api-server 2.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. agent_api_server/__init__.py +0 -0
  2. agent_api_server/api/__init__.py +0 -0
  3. agent_api_server/api/v1/__init__.py +0 -0
  4. agent_api_server/api/v1/api.py +25 -0
  5. agent_api_server/api/v1/config.py +57 -0
  6. agent_api_server/api/v1/graph.py +59 -0
  7. agent_api_server/api/v1/schema.py +57 -0
  8. agent_api_server/api/v1/thread.py +563 -0
  9. agent_api_server/cache/__init__.py +0 -0
  10. agent_api_server/cache/redis_cache.py +385 -0
  11. agent_api_server/callback_handler.py +18 -0
  12. agent_api_server/client/css/styles.css +1202 -0
  13. agent_api_server/client/favicon.ico +0 -0
  14. agent_api_server/client/index.html +102 -0
  15. agent_api_server/client/js/app.js +1499 -0
  16. agent_api_server/client/js/index.umd.js +824 -0
  17. agent_api_server/config_center/config_center.py +239 -0
  18. agent_api_server/configs/__init__.py +3 -0
  19. agent_api_server/configs/config.py +163 -0
  20. agent_api_server/dynamic_llm/__init__.py +0 -0
  21. agent_api_server/dynamic_llm/dynamic_llm.py +331 -0
  22. agent_api_server/listener.py +530 -0
  23. agent_api_server/log/__init__.py +0 -0
  24. agent_api_server/log/formatters.py +122 -0
  25. agent_api_server/log/logging.json +50 -0
  26. agent_api_server/mcp_convert/__init__.py +0 -0
  27. agent_api_server/mcp_convert/mcp_convert.py +375 -0
  28. agent_api_server/memeory/__init__.py +0 -0
  29. agent_api_server/memeory/postgres.py +233 -0
  30. agent_api_server/register/__init__.py +0 -0
  31. agent_api_server/register/register.py +65 -0
  32. agent_api_server/service.py +354 -0
  33. agent_api_server/service_hub/service_hub.py +233 -0
  34. agent_api_server/service_hub/service_hub_test.py +700 -0
  35. agent_api_server/shared/__init__.py +0 -0
  36. agent_api_server/shared/ase.py +54 -0
  37. agent_api_server/shared/base_model.py +103 -0
  38. agent_api_server/shared/common.py +110 -0
  39. agent_api_server/shared/decode_token.py +107 -0
  40. agent_api_server/shared/detect_message.py +410 -0
  41. agent_api_server/shared/get_model_info.py +491 -0
  42. agent_api_server/shared/message.py +419 -0
  43. agent_api_server/shared/util_func.py +372 -0
  44. agent_api_server/sso_service/__init__.py +1 -0
  45. agent_api_server/sso_service/sdk/__init__.py +1 -0
  46. agent_api_server/sso_service/sdk/client.py +224 -0
  47. agent_api_server/sso_service/sdk/credential.py +11 -0
  48. agent_api_server/sso_service/sdk/encoding.py +22 -0
  49. agent_api_server/sso_service/sso_service.py +177 -0
  50. agent_api_server-2.1.7.dist-info/METADATA +130 -0
  51. agent_api_server-2.1.7.dist-info/RECORD +52 -0
  52. agent_api_server-2.1.7.dist-info/WHEEL +4 -0
@@ -0,0 +1,419 @@
1
+ import os
2
+ import logging
3
+ import traceback
4
+ from datetime import datetime
5
+ import json
6
+ from typing import Dict, Any, AsyncGenerator, Union, List, Tuple
7
+ from agent_api_server.shared.util_func import load_graph_config, load_graph, get_env
8
+ from langgraph.types import StateSnapshot
9
+ from agent_api_server.cache.redis_cache import ThreadState
10
+ from langchain_core.messages import (
11
+ AIMessage,
12
+ BaseMessage,
13
+ HumanMessage,
14
+ ToolMessage, AIMessageChunk)
15
+ from .detect_message import detect_content_type
16
+ from langgraph.types import Interrupt, Command
17
+ from agent_api_server.shared.base_model import ChatMessage
18
+ from fastapi import Body, status, HTTPException
19
+ from langfuse.langchain import CallbackHandler
20
+ from agent_api_server.shared.decode_token import decode_jwt
21
+ from agent_api_server.cache.redis_cache import AsyncRedisThreadStorage
22
+ from langfuse import propagate_attributes
23
+
24
+
25
logger = logging.getLogger(__name__)

# --- Type aliases used to annotate the helpers in this module ---

# Scalar JSON-style values shared by the composite aliases below.
_JsonScalar = Union[str, int, float, bool, None]

# Message content as handled by convert_message_content: a string or a list of parts.
MessageContent = Union[str, List[Any]]

# A JSON-compatible payload: mapping, list, or scalar.
EventData = Union[Dict[str, Any], List[Any], _JsonScalar]

# One streamed item: a 2-tuple (mode, payload), or a 3-tuple whose leading
# element handle_stream_event discards.
StreamEvent = Union[Tuple[str, Any], Tuple[str, str, Any]]

# Return annotation of serialize_data (the alias also admits datetime).
SerializableData = Union[Dict[str, Any], List[Any], _JsonScalar, datetime]
32
+
33
+
34
def format_state_snapshot(snapshot: StateSnapshot) -> Dict[str, Any]:
    """Flatten a state snapshot into a plain dictionary.

    Interrupts are gathered from every task that exposes an ``interrupts``
    attribute. Only interrupt objects that carry a ``value`` attribute are
    kept, each reduced to ``{"value": ...}``.

    The remaining snapshot fields are read with ``getattr`` and default to
    ``None`` when absent: ``values``, ``next``, ``config``, ``parent_config``
    and ``metadata``.

    Args:
        snapshot: The snapshot object to flatten.

    Returns:
        A dict with keys ``values``, ``next``, ``config``, ``interrupts``,
        ``parent_config`` and ``metadata``, in that order.
    """
    pending = [
        {"value": intr.value}
        for task in getattr(snapshot, "tasks", ())
        for intr in getattr(task, "interrupts", ())
        if hasattr(intr, "value")
    ]

    result = {field: getattr(snapshot, field, None) for field in ("values", "next", "config")}
    result["interrupts"] = pending
    result["parent_config"] = getattr(snapshot, "parent_config", None)
    result["metadata"] = getattr(snapshot, "metadata", None)
    return result
52
+
53
+
54
def convert_message_content(content: MessageContent) -> str:
    """Flatten message content into a single string.

    * ``str`` content is returned unchanged.
    * ``list`` content is concatenated item by item:
      - dict items with ``"type": "text"`` contribute their ``"text"`` value,
        treated as ``""`` when the key is missing or ``None``;
      - every other item contributes ``str(item)``.
    * Any other value is converted with ``str()``.

    Args:
        content: The ``content`` attribute of a message.

    Returns:
        The flattened string.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    pieces = []
    for item in content:
        if isinstance(item, dict) and item.get("type") == "text":
            text = item.get("text")
            # A text block without usable text must not break str.join
            # (previously a None here raised TypeError).
            pieces.append("" if text is None else str(text))
        else:
            pieces.append(str(item))
    return "".join(pieces)
65
+
66
+
67
def langchain_to_chat_message(message: BaseMessage) -> ChatMessage:
    """Convert a LangChain message into the API's ``ChatMessage`` model.

    The content is flattened with ``convert_message_content``, and
    ``content_type`` is whatever ``detect_content_type`` reports for that
    string.

    Type mapping:
    * ``HumanMessage`` -> ``"human"``, with empty tool_calls,
      response_metadata and references.
    * ``ToolMessage`` -> ``"tool"``, with the same empty fields.
    * Any other message class -> ``"ai"``. Its ``tool_calls``,
      ``response_metadata`` and ``additional_kwargs["references"]`` are
      copied over.

    Missing or ``None`` values are normalized to ``[]`` / ``{}``, so
    ``tool_calls`` and ``references`` are always lists and
    ``response_metadata`` is always a dict.

    Args:
        message: The message to convert.

    Returns:
        The populated ``ChatMessage``.
    """
    content = convert_message_content(message.content)
    content_type = detect_content_type(content)

    tool_calls: List[Any] = []
    response_metadata: Dict[str, Any] = {}
    references: List[Any] = []

    if isinstance(message, HumanMessage):
        message_type = "human"
    elif isinstance(message, ToolMessage):
        message_type = "tool"
    else:
        message_type = "ai"
        # `or` guards against attributes/keys that exist but hold None.
        tool_calls = getattr(message, "tool_calls", None) or []
        response_metadata = getattr(message, "response_metadata", None) or {}
        additional_kwargs = getattr(message, "additional_kwargs", None) or {}
        references = additional_kwargs.get("references") or []

    return ChatMessage(
        type=message_type,
        content_type=content_type,
        content=content,
        tool_calls=tool_calls,  # always a list
        response_metadata=response_metadata,  # always a dict
        references=references  # always a list
    )
97
+
98
+
99
def serialize_data(data: Any) -> SerializableData:
    """Recursively convert ``data`` into JSON-compatible Python values.

    Rules are checked in this order:

    * ``str``, ``int``, ``float``, ``bool`` and ``None`` are returned as-is.
      Enum members that subclass one of these types (e.g. ``IntEnum``) are
      therefore returned unchanged too.
    * Other ``Enum`` members are replaced by their serialized ``value``.
    * ``dict`` values are serialized recursively; keys are kept unchanged.
    * ``list``, ``tuple`` and ``set`` become lists of serialized items.
    * ``datetime`` becomes its ISO-8601 string.
    * Class objects (types) become ``str(data)``.
    * Objects exposing ``model_dump()``, then ``dict()``, then ``__dict__``
      are converted through that mapping.
    * Anything else becomes ``str(data)``.

    There is no cycle detection: an object graph that refers back to itself
    recurses until ``RecursionError``.
    """
    from enum import Enum  # function-local so this edit needs no header change

    if isinstance(data, (str, int, float, bool)) or data is None:
        return data
    if isinstance(data, Enum):
        # A plain Enum member would otherwise reach the ``__dict__`` branch
        # and leak internals such as _name_, _value_ and __objclass__.
        return serialize_data(data.value)
    if isinstance(data, dict):
        return {k: serialize_data(v) for k, v in data.items()}
    if isinstance(data, (list, tuple, set)):
        return [serialize_data(item) for item in data]
    if isinstance(data, datetime):
        return data.isoformat()
    if isinstance(data, type):
        # On a class, model_dump/dict are unbound methods; calling them
        # without an instance would typically raise TypeError.
        return str(data)

    # Objects that know how to dump themselves.
    if hasattr(data, 'model_dump'):
        return serialize_data(data.model_dump())
    if hasattr(data, 'dict'):
        return serialize_data(data.dict())
    if hasattr(data, '__dict__'):
        return serialize_data(data.__dict__)

    return str(data)
119
+
120
+
121
def format_sse_event(event_type: str, data: Any) -> str:
    """Encode ``data`` as one Server-Sent Events ``data:`` frame.

    The frame payload is a JSON object whose ``"event"`` key holds
    ``event_type``:

    * A dict ``data`` is merged into that object. Its own keys win on
      collision, including ``"event"``.
    * Any other value is stored as ``str(data)`` under ``"data"``.

    Non-ASCII text is emitted verbatim (``ensure_ascii=False``).

    If JSON encoding fails with ``TypeError`` or ``ValueError``, an
    ``"error"`` frame is returned instead. It carries the exception text and
    the first 200 characters of ``str(data)``.

    Returns:
        ``"data: <json>\\n\\n"``.
    """
    if isinstance(data, dict):
        body = {"event": event_type, **data}
    else:
        body = {"event": event_type, "data": str(data)}

    try:
        encoded = json.dumps(body, ensure_ascii=False)
    except (TypeError, ValueError) as exc:
        fallback = {
            "event": "error",
            "exception": str(exc),
            "original_data": str(data)[:200],  # truncated to keep the frame short
        }
        encoded = json.dumps(fallback, ensure_ascii=False)

    return "data: " + encoded + "\n\n"
144
+
145
+
146
# SSE event type emitted for the updates of specific graph nodes. Any node not
# listed here (other than "__interrupt__") is reported as "node_message".
_NODE_EVENT_TYPES: Dict[str, str] = {
    "agent": "agent_message",
    "tools": "tools_message",
    "supervisor": "supervisor_message",
}


def _serialize_node_update(updates: Any) -> Any:
    """Serialize the most relevant message of one node update.

    * If ``updates`` is a dict with a non-empty ``"messages"`` entry, its
      last message is converted. A single non-list entry is treated as a
      one-element list.
    * Otherwise a placeholder AI message is used, with content
      ``str(updates)`` (``""`` when ``updates`` is ``None``), so every node
      update still produces an event.
    """
    messages = updates.get("messages") if isinstance(updates, dict) else None
    if messages and not isinstance(messages, (list, tuple)):
        messages = [messages]
    if messages:
        message = messages[-1]
    else:
        message = AIMessage(content="" if updates is None else str(updates))
    return serialize_data(langchain_to_chat_message(message))


async def process_node_updates(node: str, updates: Dict[str, Any]) -> AsyncGenerator[str, None]:
    """Translate one node's entry from an ``updates`` stream into SSE events.

    Args:
        node: Name of the node that produced the update.
        updates: The node's update.
            - For ``"__interrupt__"`` it is iterated, and each item's
              ``.value`` is stringified into an ``"interrupt_message"``
              event.
            - For any other node it may be a dict (optionally holding
              ``"messages"``), ``None``, or another value; see
              ``_serialize_node_update``.

    Yields:
        * ``"__interrupt__"``: one SSE string per interrupt.
        * Any other node: one SSE string whose event type comes from
          ``_NODE_EVENT_TYPES`` (``"node_message"`` for unlisted nodes).

        Any exception is logged and reported as an ``"error"`` event instead
        of being raised.
    """
    try:
        if node == "__interrupt__":
            interrupt: Interrupt
            for interrupt in updates:
                chat_msg = langchain_to_chat_message(AIMessage(content=str(interrupt.value)))
                yield format_sse_event("interrupt_message", {
                    "node": node,
                    "update_content": serialize_data(chat_msg),
                    "update_type": "complete",
                    "timestamp": datetime.now().isoformat()
                })
        else:
            event_type = _NODE_EVENT_TYPES.get(node, "node_message")
            yield format_sse_event(event_type, {
                "node": node,
                "update_content": _serialize_node_update(updates),
                "update_type": "complete",
                "timestamp": datetime.now().isoformat()
            })
    except Exception as e:
        logger.error(f"Failed to process {node} updates: {str(e)}", exc_info=True)
        yield format_sse_event("error", {
            "node": node,
            "error": f"Message processing error: {str(e)}",
            "timestamp": datetime.now().isoformat()
        })
238
+
239
+
240
def _message_chunk_text(content: Any) -> str:
    """Return the displayable text of a streamed message chunk's ``content``.

    * ``str`` content is returned unchanged.
    * ``list`` content is reduced to the concatenation of its bare string
      items and the ``"text"`` of dict items with ``"type": "text"``. Other
      block types are skipped so they are not streamed as tokens.
    * ``None`` yields ``""``; any other value goes through ``str()``.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                text = block.get("text")
                if text is not None:
                    parts.append(str(text))
        return "".join(parts)
    return str(content)


async def handle_stream_event(stream_event: StreamEvent) -> AsyncGenerator[str, None]:
    """Turn one streamed item into zero or more SSE strings.

    Accepted item shapes:
    * ``(mode, payload)``.
    * ``(x, mode, payload)``: the leading element is discarded (presumably
      the subgraph namespace, since message_generator streams with
      ``subgraphs=True``).

    Handling by mode:
    * ``"updates"``: each ``node -> update`` pair of a dict payload is
      passed to ``process_node_updates``.
    * ``"messages"``: the payload is ``(chunk, metadata)``; only
      ``AIMessageChunk`` chunks are handled.
      - A ``token_stream`` "start" event is sent when the chunk has
        non-blank text.
      - A ``token_stream`` "complete" event is sent when
        ``response_metadata`` has a ``finish_reason``.
      - List (content-block) chunks are reduced to their text.
    * ``"custom"``: ``str``/``bytes`` values are forwarded verbatim (assumed
      to be pre-formatted SSE frames; confirm with the writers). Other values
      are serialized and sent as ``"custom_message"`` events; a dict's own
      ``"event"`` key overrides that name, because format_sse_event merges
      dict keys.
    * Other modes produce nothing.

    A malformed item produces an ``"error"`` event. Any exception is logged
    and reported as an ``"error"`` event.
    """
    try:
        if isinstance(stream_event, tuple) and len(stream_event) == 3:
            _, stream_mode, event = stream_event
            async for chunk in handle_stream_event((stream_mode, event)):
                yield chunk
            return

        if not (isinstance(stream_event, tuple) and len(stream_event) == 2):
            logger.warning(f"Unexpected event format: {type(stream_event)}")
            yield format_sse_event("error", {
                "error": f"Unexpected event format: {type(stream_event)}",
                "event_data": str(stream_event)[:200]
            })
            return

        stream_mode, event = stream_event

        if stream_mode == "updates":
            if isinstance(event, dict):
                for node, updates in event.items():
                    async for chunk in process_node_updates(node, updates):
                        yield chunk

        elif stream_mode == "messages":
            if not (isinstance(event, tuple) and len(event) == 2):
                logger.warning(f"Unexpected messages event format: {type(event)}")
                return
            message_chunk, metadata = event
            if not isinstance(message_chunk, AIMessageChunk):
                return

            node_name = metadata.get('langgraph_node') if isinstance(metadata, dict) else None
            # List content (content blocks) has no .strip(); reduce it to text first.
            text = _message_chunk_text(message_chunk.content)

            if text.strip():
                yield format_sse_event("token_stream", {
                    "node": node_name,
                    "update_content": text,
                    "update_type": "start",
                    "timestamp": datetime.now().isoformat()
                })

            response_metadata = getattr(message_chunk, 'response_metadata', None) or {}
            if response_metadata.get('finish_reason'):
                yield format_sse_event("token_stream", {
                    "node": node_name,
                    "update_type": "complete",
                    "update_content": text,
                    "timestamp": datetime.now().isoformat()
                })

        elif stream_mode == "custom":
            logger.info(f"receive custom stream event, event message is {event}")
            if isinstance(event, (str, bytes)):
                yield event
            else:
                # A raw object cannot be written to the response stream; wrap it as SSE.
                yield format_sse_event("custom_message", serialize_data(event))

    except Exception as e:
        logger.error(f"Error processing event: {str(e)}", exc_info=True)
        yield format_sse_event("error", {
            "error": f"Event processing error: {str(e)}",
            "event_type": type(stream_event).__name__
        })
307
+
308
+
309
async def message_generator(
    state: ThreadState,
    ts_tenant: str,
    ei_token: str,
    inputs: Dict[str, Any] = Body(..., embed=True),
    files: List[Dict[str, Any]] = Body(default=[], embed=True)
) -> AsyncGenerator[str, None]:
    """Stream one graph run for ``state.thread_id`` as Server-Sent Events.

    Steps:
    1. Best-effort extraction of ``user_id`` from ``ei_token`` via
       ``decode_jwt``. It is used only as a trace attribute; failures are
       logged and ignored.
    2. Load the graph named ``state.graph_name`` and mark the thread
       ``"running"`` in storage.
    3. Read the saved state.
       - If any task has pending interrupts, resume with
         ``Command(resume=inputs["resume"])``.
       - Otherwise start a fresh run with ``inputs``.
    4. Pass every streamed item through ``handle_stream_event``.

    A Langfuse ``CallbackHandler`` is attached only when
    ``LANGFUSE_SECRET_KEY``, ``LANGFUSE_PUBLIC_KEY`` and ``LANGFUSE_BASE_URL``
    are all set and non-empty.

    Exceptions raised while loading or streaming are reported as ``error`` SSE
    events. The thread status is set to ``"complete"`` when the generator
    finishes, fails, or is closed early.

    The trailing status error (if any) and the ``[DONE]`` sentinel are sent
    only when the stream ends normally. Yielding while the generator is being
    closed (client disconnect) would raise ``RuntimeError``.

    Args:
        state: Thread record; ``graph_name`` and ``thread_id`` are read.
        ts_tenant: Passed to ``get_env(ts_tenant=...)`` and exposed to the
            graph as ``configurable["TSTenant"]``.
        ei_token: JWT forwarded as ``configurable["EIToken"]``.
        inputs: Graph input for a fresh run; its ``"resume"`` key supplies
            the resume value when the thread is interrupted.
        files: Forwarded as ``configurable["files"]`` (``[]`` when falsy).
    """
    storage = AsyncRedisThreadStorage.get_worker_instance()

    user_id = None
    if ei_token:
        try:
            _header, _payload, user_info = decode_jwt(ei_token)
            if user_info and 'id' in user_info:
                user_id = user_info['id']
                logger.info(f"Extracted user_id from token: {user_id}")
            else:
                logger.warning("No user_id found in JWT token")
        except Exception as e:
            logger.warning(f"Failed to decode JWT token: {str(e)}")

    status_saved = None
    try:
        _, graph_instance, _checkpointer = await load_graph(state.graph_name, await load_graph_config(), True)
        data = await storage.update_thread(state.thread_id, status="running")
        if not data:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={
                    "error": "update_failed",
                    "message": "Failed to update thread status",
                    "thread_id": state.thread_id
                }
            )

        resume_value = inputs.get("resume", "") if inputs else ""
        subgraph_names = {name for name, _ in graph_instance.get_subgraphs()}
        logger.info(f"Found graphs: {subgraph_names}, resume message is: {resume_value}")

        configurable_params = {
            **dict(get_env(ts_tenant=ts_tenant)),
            "graph_name": state.graph_name,
            "thread_id": state.thread_id,
            "TSTenant": ts_tenant,
            "EIToken": ei_token,
            "files": files or []
        }

        langfuse_keys = [
            os.getenv("LANGFUSE_SECRET_KEY"),
            os.getenv("LANGFUSE_PUBLIC_KEY"),
            os.getenv("LANGFUSE_BASE_URL")
        ]
        callbacks_config = {}
        if all(langfuse_keys):
            callbacks_config = {"callbacks": [CallbackHandler()], "run_name": state.graph_name}

        config = {"configurable": configurable_params, **callbacks_config}

        state_get = await graph_instance.aget_state(
            config={"configurable": {**dict(get_env(ts_tenant=ts_tenant)), "thread_id": state.thread_id}}
        )
        interrupted_tasks = [task for task in state_get.tasks if getattr(task, "interrupts", None)]

        if interrupted_tasks:
            logger.info(f"find resume message, so run {state.thread_id} for graph {state.graph_name} with resume")
            stream_input = Command(resume=resume_value)
        else:
            logger.info(f"do not find resume message, so first run {state.thread_id} for graph {state.graph_name}")
            stream_input = inputs or {}

        # user_id is attached to the trace only when one was extracted above.
        trace_attributes = {"session_id": state.thread_id, "trace_name": f"{state.graph_name}"}
        if user_id:
            trace_attributes["user_id"] = user_id

        with propagate_attributes(**trace_attributes):
            async for stream_event in graph_instance.astream(
                stream_input,
                config=config,
                stream_mode=["updates", "messages", "custom"],
                subgraphs=True
            ):
                async for chunk in handle_stream_event(stream_event):
                    yield chunk

    except HTTPException as e:
        yield format_sse_event("error", f"Stream processing failed: {e.detail}")
    except (RuntimeError, ValueError) as e:
        yield format_sse_event("error", f"Stream processing failed: {e}")
    except Exception as e:
        logger.error("Stream processing failed:\n%s", traceback.format_exc())
        yield format_sse_event("error", f"Stream processing failed: {str(e)}")
    finally:
        # Runs on normal completion, after handled errors, and on aclose()
        # (client disconnect). No yield here: yielding while the generator is
        # being closed raises "async generator ignored GeneratorExit".
        status_saved = await storage.update_thread(state.thread_id, status="complete")

    # Only reached when the stream was not closed early.
    if not status_saved:
        yield format_sse_event("error", "Failed to update thread status")
    yield "data: [DONE]\n\n"