agent-api-server 2.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_api_server/__init__.py +0 -0
- agent_api_server/api/__init__.py +0 -0
- agent_api_server/api/v1/__init__.py +0 -0
- agent_api_server/api/v1/api.py +25 -0
- agent_api_server/api/v1/config.py +57 -0
- agent_api_server/api/v1/graph.py +59 -0
- agent_api_server/api/v1/schema.py +57 -0
- agent_api_server/api/v1/thread.py +563 -0
- agent_api_server/cache/__init__.py +0 -0
- agent_api_server/cache/redis_cache.py +385 -0
- agent_api_server/callback_handler.py +18 -0
- agent_api_server/client/css/styles.css +1202 -0
- agent_api_server/client/favicon.ico +0 -0
- agent_api_server/client/index.html +102 -0
- agent_api_server/client/js/app.js +1499 -0
- agent_api_server/client/js/index.umd.js +824 -0
- agent_api_server/config_center/config_center.py +239 -0
- agent_api_server/configs/__init__.py +3 -0
- agent_api_server/configs/config.py +163 -0
- agent_api_server/dynamic_llm/__init__.py +0 -0
- agent_api_server/dynamic_llm/dynamic_llm.py +331 -0
- agent_api_server/listener.py +530 -0
- agent_api_server/log/__init__.py +0 -0
- agent_api_server/log/formatters.py +122 -0
- agent_api_server/log/logging.json +50 -0
- agent_api_server/mcp_convert/__init__.py +0 -0
- agent_api_server/mcp_convert/mcp_convert.py +375 -0
- agent_api_server/memeory/__init__.py +0 -0
- agent_api_server/memeory/postgres.py +233 -0
- agent_api_server/register/__init__.py +0 -0
- agent_api_server/register/register.py +65 -0
- agent_api_server/service.py +354 -0
- agent_api_server/service_hub/service_hub.py +233 -0
- agent_api_server/service_hub/service_hub_test.py +700 -0
- agent_api_server/shared/__init__.py +0 -0
- agent_api_server/shared/ase.py +54 -0
- agent_api_server/shared/base_model.py +103 -0
- agent_api_server/shared/common.py +110 -0
- agent_api_server/shared/decode_token.py +107 -0
- agent_api_server/shared/detect_message.py +410 -0
- agent_api_server/shared/get_model_info.py +491 -0
- agent_api_server/shared/message.py +419 -0
- agent_api_server/shared/util_func.py +372 -0
- agent_api_server/sso_service/__init__.py +1 -0
- agent_api_server/sso_service/sdk/__init__.py +1 -0
- agent_api_server/sso_service/sdk/client.py +224 -0
- agent_api_server/sso_service/sdk/credential.py +11 -0
- agent_api_server/sso_service/sdk/encoding.py +22 -0
- agent_api_server/sso_service/sso_service.py +177 -0
- agent_api_server-2.1.7.dist-info/METADATA +130 -0
- agent_api_server-2.1.7.dist-info/RECORD +52 -0
- agent_api_server-2.1.7.dist-info/WHEEL +4 -0
|
File without changes
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import json
|
|
5
|
+
from typing import Dict, Any, Optional, List, Callable, Awaitable
|
|
6
|
+
from functools import wraps, partial
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from langfuse import propagate_attributes
|
|
9
|
+
from langfuse.langchain import CallbackHandler
|
|
10
|
+
from agent_api_server.shared.decode_token import decode_jwt
|
|
11
|
+
from langgraph.config import get_stream_writer
|
|
12
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
13
|
+
from fastmcp import FastMCP, Context
|
|
14
|
+
from mcp import types
|
|
15
|
+
from mcp.types import RequestParams, CallToolResult, TextContent
|
|
16
|
+
from agent_api_server.shared.message import handle_stream_event
|
|
17
|
+
from agent_api_server.shared.util_func import load_graph_config, load_graph, get_env
|
|
18
|
+
|
|
19
|
+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
async def handle_messages(message):
    """Forward MCP progress notifications into the LangGraph custom stream.

    Only a ``types.ServerNotification`` whose root is a
    ``types.ProgressNotification`` is acted on: its ``params.message`` is
    logged and written to the current LangGraph stream writer. Every other
    message is ignored.

    NOTE(review): presumably installed as an MCP client message handler while a
    graph node is running — confirm against callers.

    Args:
        message: An incoming MCP message (any type; non-matching types are ignored).
    """
    if not isinstance(message, types.ServerNotification):
        return
    notification = message.root
    if not isinstance(notification, types.ProgressNotification):
        return

    params = notification.params
    logger.info("receive progress message: (%s)", params.message)
    # Look the writer up only when there is something to write: it depends on
    # running inside a LangGraph context, which unrelated messages should not need.
    writer = get_stream_writer()
    writer(params.message)
|
|
30
|
+
|
|
31
|
+
async def create_mcp_tool_from_agent() -> FastMCP:
    """Create a FastMCP server that exposes every configured LangGraph graph as a tool.

    For each key under ``graphs`` in the graph config, the graph is loaded,
    checked to be a ``CompiledStateGraph``, and registered as an MCP tool.
    The tool's parameters mirror the graph's input JSON schema. Its
    description is ``agent_description[graph_name]``, or an empty string when
    absent. Each tool runs ``_execute_graph_tool`` wrapped with timing logs.

    Returns:
        FastMCP: Configured MCP instance with all graph tools registered.

    Raises:
        RuntimeError: If tool creation fails due to configuration or loading
            errors. The underlying exception is chained as ``__cause__``.
    """
    mcp = FastMCP(name="langgraph_mcp_tool")

    try:
        graph_cfg = await load_graph_config()

        # Validate configurations
        graphs = graph_cfg.get("graphs", {})
        if not isinstance(graphs, dict):
            raise RuntimeError(f"Invalid graphs config. Expected dict, got {type(graphs)}")

        agent_descriptions = graph_cfg.get("agent_description", {})
        if not isinstance(agent_descriptions, dict):
            logger.warning("agent_description should be dict, got %s", type(agent_descriptions))
            agent_descriptions = {}

        # Only the names are needed: load_graph receives the whole graph_cfg.
        for graph_name in graphs:
            _, graph_instance, _ = await load_graph(graph_name, graph_cfg, False)

            if not isinstance(graph_instance, CompiledStateGraph):
                raise RuntimeError(f"Graph {graph_name} is not CompiledStateGraph")

            # Bind the per-graph context; the MCP call supplies the tool arguments as kwargs.
            tool_impl = partial(
                _execute_graph_tool,
                mcp=mcp,
                graph_name=graph_name,
                graph_instance=graph_instance,
            )
            tool_impl.__name__ = graph_name

            tool = create_tool_from_schema(
                schema=graph_instance.get_input_jsonschema(),
                func_name=graph_name,
                func_doc=agent_descriptions.get(graph_name, ""),
                implementation=tool_impl,
            )

            mcp.tool()(_add_metrics_wrapper(tool, graph_name))
            logger.info("Registered tool: %s", graph_name)

        return mcp
    except Exception as e:
        logger.error("MCP tool creation failed", exc_info=True)
        raise RuntimeError(f"MCP tool creation failed: {str(e)}") from e
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _mask_secret(value: str) -> str:
    """Return a log-safe stand-in for a credential (never the credential itself)."""
    if not value:
        return "<empty>"
    return f"<redacted, {len(value)} chars>"


def _redact_for_log(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``params`` with credential-looking values masked.

    A key counts as sensitive when its lower-cased name contains one of the
    markers below (e.g. ``EIToken``, ``*_API_KEY``, ``*_SECRET_*``). This copy
    is only for logging; the real values are still passed to the graph.
    """
    markers = (
        "token", "secret", "password", "passwd",
        "api_key", "apikey", "authorization", "credential",
    )
    return {
        key: "<redacted>" if any(m in str(key).lower() for m in markers) else value
        for key, value in params.items()
    }


def _extract_user_id(ei_token: str) -> Optional[Any]:
    """Best-effort lookup of the ``id`` entry from ``decode_jwt(ei_token)``.

    Returns ``None`` for an empty token. Also returns ``None``, after logging a
    warning, when decoding fails or the decoded user info has no ``id``.
    """
    if not ei_token:
        return None
    try:
        _, _, user_info = decode_jwt(ei_token)
        if user_info and 'id' in user_info:
            user_id = user_info['id']
            logger.info("Extracted user_id from token: %s", user_id)
            return user_id
        logger.warning("No user_id found in JWT token")
    except Exception as e:
        logger.warning("Failed to decode JWT token: %s", str(e))
    return None


async def _execute_graph_tool(
    mcp: FastMCP,
    graph_name: str,
    graph_instance: CompiledStateGraph,
    **kwargs: Any
) -> CallToolResult:
    """Execute a LangGraph agent for one MCP tool call and return its final answer.

    Per-call context comes from the current HTTP request headers
    (``UseSysLLM``, ``TSTenant``, ``Authorization``, ``thread_id``). Each chunk
    produced by ``handle_stream_event`` is reported to the client through
    ``ctx.report_progress``. The last chunk must be a ``data: <json>`` line
    whose ``update_content.content`` becomes the tool result.

    Args:
        mcp: FastMCP instance used to build the request ``Context``.
        graph_name: Name of the graph to execute (also used in trace/run names).
        graph_instance: Compiled graph instance.
        **kwargs: Tool arguments, validated against the graph's input schema.
            ``None`` for an optional field that declares a concrete ``type`` is
            treated as "not provided".

    Returns:
        A ``CallToolResult``. On success it holds the final content; on a
        validation or execution failure it has ``isError=True`` and a message.
        Errors while obtaining the HTTP request happen before any handling and
        propagate to the caller.
    """
    ctx = Context(fastmcp=mcp)
    from fastmcp.server.dependencies import get_http_request
    request = get_http_request()
    headers = request.headers

    use_sys_llm = headers.get("UseSysLLM", "")
    ts_tenant = headers.get("TSTenant", "")
    ei_token = headers.get("Authorization", "")
    thread_id = headers.get('thread_id', '')
    start_time = datetime.now()

    # The Authorization header is a bearer credential: never write it to logs.
    logger.info(
        "Executing graph '%s' for tenant '%s', ei_token is '%s', use_sys_llm is %s",
        graph_name, ts_tenant, _mask_secret(ei_token), use_sys_llm,
    )

    # Validate input against the graph's input JSON schema.
    schema = graph_instance.get_input_jsonschema()
    required = set(schema.get("required", []))
    input_dict = {}
    validation_errors = []

    for field, prop in schema.get("properties", {}).items():
        if field not in kwargs:
            if field in required:
                validation_errors.append(f"Missing required field: {field}")
            continue
        value = kwargs[field]
        # Optional parameters are advertised with default None, and a client may
        # also send an explicit null. Coercing None would give "None" for
        # strings and fail for numbers, so leave the field unset instead.
        if value is None and field not in required and prop.get("type") not in (None, "null"):
            continue
        try:
            input_dict[field] = _validate_field(field, value, prop)
        except ValueError as e:
            validation_errors.append(str(e))

    if validation_errors:
        error_msg = "Validation errors:\n" + "\n".join(validation_errors)
        logger.error(error_msg)
        return CallToolResult(
            content=[TextContent(type="text", text=error_msg)],
            isError=True
        )

    if thread_id:
        logger.info(
            "execute graph %s with thread id %s which get from context headers",
            graph_name, thread_id,
        )
    else:
        thread_id = ctx.session_id
        logger.info("execute graph %s with session id %s", graph_name, ctx.session_id)

    user_id = _extract_user_id(ei_token)

    # Build configurable parameters
    configurable_params = {
        "use_sys_llm": use_sys_llm,
        "thread_id": thread_id,
        "TSTenant": ts_tenant,
        "EIToken": ei_token,
        "graph_name": graph_name
    }
    if ts_tenant:
        # Tenant env values first, so the explicit per-call keys above win.
        configurable_params = {
            **dict(get_env(ts_tenant=ts_tenant)),
            **configurable_params
        }

    logger.info("finally configurable parameter is %s", _redact_for_log(configurable_params))

    trace_name = f"{graph_name}_MCP_Call"

    # Configure Langfuse callbacks only when all three settings are present.
    callbacks_config = {}
    if all((
        os.getenv("LANGFUSE_SECRET_KEY"),
        os.getenv("LANGFUSE_PUBLIC_KEY"),
        os.getenv("LANGFUSE_BASE_URL"),
    )):
        callbacks_config = {"callbacks": [CallbackHandler()], "run_name": trace_name}

    config = {"configurable": configurable_params, **callbacks_config}

    async def process_stream():
        """Yield formatted chunks from the graph's update stream under Langfuse attributes."""
        trace_attributes: Dict[str, Any] = {"session_id": thread_id, "trace_name": trace_name}
        if user_id:
            logger.info("execute graph %s with user_id: %s", graph_name, user_id)
            trace_attributes["user_id"] = user_id
        else:
            logger.info("execute graph %s without user_id", graph_name)

        with propagate_attributes(**trace_attributes):
            async for stream_event in graph_instance.astream(
                input_dict,
                config=config,
                stream_mode=["updates"],
                subgraphs=True
            ):
                async for chunk in handle_stream_event(stream_event):
                    yield chunk

    try:
        chunk_count = 0
        last_chunk = None
        async for chunk in process_stream():
            # Use the request id as the progress token. Presumably this makes
            # report_progress emit notifications even when the client sent no
            # token — confirm.
            ctx.request_context.meta = RequestParams.Meta(progressToken=ctx.request_id)
            await ctx.report_progress(message=chunk, progress=chunk_count)
            chunk_count += 1
            last_chunk = chunk

        if chunk_count == 0:
            raise ValueError("No response from graph execution")

        final_line = last_chunk.strip()
        prefix = "data: "
        if not final_line.startswith(prefix):
            raise ValueError("Invalid response format")

        data = json.loads(final_line[len(prefix):])
        content = (data.get("update_content") or {}).get("content")
        if not content:
            raise ValueError("Missing content in response")

        logger.info(
            "Graph '%s' executed in %.2fs (%d chunks)",
            graph_name,
            (datetime.now() - start_time).total_seconds(),
            chunk_count
        )

        return CallToolResult(
            content=[TextContent(type="text", text=content)],
            isError=False
        )

    except Exception as err:
        logger.error("Graph execution failed: %s", str(err), exc_info=True)
        return CallToolResult(
            content=[TextContent(type="text", text=f"Execution failed: {str(err)}")],
            isError=True
        )
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def create_tool_from_schema(
    schema: Dict[str, Any],
    func_name: str,
    func_doc: str = "",
    implementation: Optional[Callable[..., Awaitable[Dict[str, Any]]]] = None
) -> Callable[..., Awaitable[Dict[str, Any]]]:
    """Give a callable an introspectable signature that mirrors a JSON object schema.

    Each entry of ``schema["properties"]`` becomes a keyword-only parameter,
    in the original order. Properties listed in ``schema["required"]`` have no
    default; all others default to ``None``. The annotation is mapped from the
    JSON ``type``; ``Any`` is used when it is missing or unknown. For optional
    properties using ``anyOf``, the first non-null member type is used. A
    list-valued ``type`` (e.g. ``["integer", "null"]``) uses its first non-null
    entry.

    Parameters are keyword-only so that an optional property may appear before
    a required one. With positional parameters, ``inspect.Signature`` rejects a
    non-default parameter after a defaulted one.

    Args:
        schema: JSON schema with ``properties`` and optional ``required``/``title``.
        func_name: Value for ``__name__`` and ``__qualname__``.
        func_doc: Docstring. When empty, ``"Tool from schema: <title>"`` is used.
        implementation: Callable to decorate in place. When ``None``, a default
            async implementation is used that checks required fields and
            returns a placeholder payload.

    Returns:
        ``implementation`` (mutated in place), or the default implementation.
    """
    properties = schema.get("properties", {})
    required = set(schema.get("required", []))
    type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": List,
        "object": Dict
    }

    def _python_type(json_type: Any) -> Any:
        """Map a JSON ``type`` (string or list of strings) to a Python annotation."""
        if isinstance(json_type, list):
            json_type = next((t for t in json_type if t != "null"), None)
        return type_map.get(json_type, Any) if isinstance(json_type, str) else Any

    params = []
    type_hints = {}
    for param_name, param_info in properties.items():
        param_type = _python_type(param_info.get("type"))
        is_required = param_name in required

        if not is_required and 'anyOf' in param_info:
            first_non_null_type = next(
                (
                    item.get('type')
                    for item in param_info['anyOf']
                    if item.get('type') and item.get('type') != 'null'
                ),
                None,
            )
            if first_non_null_type:
                param_type = _python_type(first_non_null_type)

        type_hints[param_name] = param_type
        params.append(
            inspect.Parameter(
                param_name,
                inspect.Parameter.KEYWORD_ONLY,
                default=inspect.Parameter.empty if is_required else None,
                annotation=param_type,
            )
        )

    if implementation is None:
        async def default_impl(**kwargs: Any) -> Dict[str, Any]:
            missing = [f for f in schema.get("required", []) if f not in kwargs]
            if missing:
                raise ValueError(f"Missing fields: {missing}")
            return {"content": [f"Default response for {func_name}"]}

        implementation = default_impl

    implementation.__name__ = func_name
    implementation.__qualname__ = func_name
    implementation.__doc__ = func_doc or f"Tool from schema: {schema.get('title', '')}"
    implementation.__signature__ = inspect.Signature(
        parameters=params,
        return_annotation=Dict[str, Any],
    )
    implementation.__annotations__ = type_hints

    return implementation
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _add_metrics_wrapper(
    func: Callable[..., Awaitable[Dict[str, Any]]],
    tool_name: str
) -> Callable[..., Awaitable[Dict[str, Any]]]:
    """Instrument an async tool so every call logs its start, duration and outcome.

    ``functools.wraps`` copies the wrapped callable's metadata onto the wrapper,
    so introspection still sees the original tool.

    Args:
        func: Async callable to instrument.
        tool_name: Label used in the log lines.

    Returns:
        An async wrapper that returns ``func``'s result unchanged. If ``func``
        raises, the wrapper logs the failure with elapsed time and re-raises.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Dict[str, Any]:
        began_at = datetime.now()

        def seconds_elapsed() -> float:
            return (datetime.now() - began_at).total_seconds()

        logger.info("Tool '%s' started", tool_name)

        try:
            outcome = await func(*args, **kwargs)
            logger.info("Tool '%s' completed in %.2fs", tool_name, seconds_elapsed())
            return outcome
        except Exception as exc:
            logger.error(
                "Tool '%s' failed after %.2fs: %s",
                tool_name,
                seconds_elapsed(),
                str(exc),
                exc_info=True,
            )
            raise

    return wrapper
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def _validate_field(
    field_name: str,
    value: Any,
    schema: Dict[str, Any]
) -> Any:
    """Coerce and validate one tool argument against its JSON-schema property.

    Handling by ``schema["type"]``:
      * ``string``: non-strings are converted with ``str()``.
      * ``number``: converted with ``float()``.
      * ``integer``: converted with ``int()``. Floats with a fractional part
        (or inf/nan) are rejected instead of being truncated.
      * ``boolean``: strings are True iff lower-cased they are "true"/"1"/"yes";
        other values go through ``bool()``.
      * ``array`` / ``object``: must already be a ``list`` / ``dict``.
      * anything else (including a missing type): passed through unchanged.
    If ``schema`` has an ``enum``, the coerced value must be one of its members.

    Args:
        field_name: Used in the error message.
        value: Raw argument value.
        schema: JSON-schema property for the field.

    Returns:
        The coerced value.

    Raises:
        ValueError: ``"Invalid value for '<field_name>': <reason>"``, chained to
            the underlying error.
    """
    field_type = schema.get("type")

    try:
        if field_type == "string":
            if not isinstance(value, str):
                value = str(value)
        elif field_type == "number":
            value = float(value)
        elif field_type == "integer":
            # int(3.7) silently truncates to 3; a non-integral float is invalid input.
            if isinstance(value, float) and not value.is_integer():
                raise ValueError(f"Expected integer, got {value!r}")
            value = int(value)
        elif field_type == "boolean":
            if isinstance(value, str):
                value = value.lower() in ("true", "1", "yes")
            value = bool(value)
        elif field_type == "array":
            if not isinstance(value, list):
                raise ValueError("Expected list")
        elif field_type == "object":
            if not isinstance(value, dict):
                raise ValueError("Expected dict")

        if "enum" in schema and value not in schema["enum"]:
            raise ValueError(f"Value not in {schema['enum']}")
    except (ValueError, TypeError, OverflowError) as e:
        # OverflowError: e.g. float() of an int too large for a double.
        raise ValueError(f"Invalid value for '{field_name}': {str(e)}") from e

    return value
|
|
File without changes
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import asyncio
|
|
3
|
+
import logging
|
|
4
|
+
import weakref
|
|
5
|
+
from typing import Optional, cast, Dict, Any
|
|
6
|
+
from psycopg_pool import AsyncConnectionPool
|
|
7
|
+
from psycopg.rows import dict_row
|
|
8
|
+
from psycopg import OperationalError, AsyncConnection
|
|
9
|
+
from agent_api_server.configs import global_config
|
|
10
|
+
from tenacity import (
|
|
11
|
+
retry,
|
|
12
|
+
stop_after_attempt,
|
|
13
|
+
wait_exponential,
|
|
14
|
+
retry_if_exception_type,
|
|
15
|
+
)
|
|
16
|
+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
17
|
+
from langgraph.store.postgres import AsyncPostgresStore
|
|
18
|
+
|
|
19
|
+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
class AsyncPostgresCheckpointer:
    """Owns a psycopg ``AsyncConnectionPool`` and the LangGraph ``AsyncPostgresSaver`` built on it.

    Instances register themselves in two class-level registries that hold weak
    references:
    one holds every live instance (used by :meth:`close_all`), the other maps the
    worker process id to an instance (:meth:`get_worker_instance`,
    :meth:`close_worker_instance`). Because the references are weak, a caller
    must keep the instance alive; otherwise it silently drops out of both.

    Pool sizing comes from ``global_config.POSTGRES_POOL_*`` with lower bounds
    applied in ``__init__``. The pool is opened lazily by :meth:`initialize`.
    """

    # Every live instance (weakly referenced), so close_all() can reach them.
    _instances = weakref.WeakSet()
    # Serializes close_all() calls; created once at class definition.
    _shutdown_lock = asyncio.Lock()
    # worker pid -> instance (weakly referenced values).
    _worker_instances = weakref.WeakValueDictionary()

    def __init__(self, max_retries: int = 3, retry_delay: float = 1.0):
        """Read pool settings and register the instance; no connection is opened here.

        Args:
            max_retries: Stored, clamped to >= 1. NOTE(review): not consulted by
                the ``@retry`` decorators in this class, which hard-code 3 attempts.
            retry_delay: Stored, clamped to >= 0.1 seconds; likewise unused by
                the decorators.
        """
        self.worker_pid = os.getpid()
        logger.info(
            f"Initializing checkpointer for worker {self.worker_pid} "
            f"(max_retries={max_retries}, retry_delay={retry_delay:.1f}s)"
        )

        self.conn_str = self._get_conn_str()
        self.pool: Optional[AsyncConnectionPool] = None
        self.saver: Optional[AsyncPostgresSaver] = None
        self.max_retries = max(max_retries, 1)
        self.retry_delay = max(retry_delay, 0.1)
        self._lock = asyncio.Lock()
        self._is_initialized = False

        # Connection pool configuration, clamped to sane minimums.
        self.pool_min_size = max(1, int(global_config.POSTGRES_POOL_MIN_SIZE))
        self.pool_max_size = max(self.pool_min_size, int(global_config.POSTGRES_POOL_MAX_SIZE))
        self.pool_timeout = max(1.0, float(global_config.POSTGRES_POOL_TIMEOUT))
        # Seconds a pooled connection may live before it is replaced (max_lifetime).
        self.pool_recycle = max(60, int(global_config.POSTGRES_POOL_RECYCLE))

        logger.debug(
            "Worker %s pool config: min_size=%s, max_size=%s, timeout=%.1f, recycle=%s",
            self.worker_pid, self.pool_min_size, self.pool_max_size,
            self.pool_timeout, self.pool_recycle,
        )

        AsyncPostgresCheckpointer._instances.add(self)
        AsyncPostgresCheckpointer._worker_instances[self.worker_pid] = self
        logger.debug("New checkpointer instance registered for worker %s", self.worker_pid)

    @classmethod
    def get_worker_instance(cls) -> 'AsyncPostgresCheckpointer':
        """Get or create the instance for the current worker process."""
        worker_pid = os.getpid()
        # A single .get() plus a local strong reference: with weak values, an
        # entry can vanish between an `in` check and a later `[]` lookup.
        instance = cls._worker_instances.get(worker_pid)
        if instance is None:
            instance = cls()  # __init__ registers itself under this pid
            logger.debug("Created new checkpointer instance for worker %s", worker_pid)
        return instance

    @classmethod
    async def close_worker_instance(cls):
        """Close the instance for the current worker process, if one is registered."""
        instance = cls._worker_instances.pop(os.getpid(), None)
        if instance is not None:
            await instance.close()

    @staticmethod
    def _redact_conn_str(conn_str: str) -> str:
        """Return ``conn_str`` with credentials masked, for logging only.

        URI form ``scheme://user:password@host/db`` becomes
        ``scheme://[REDACTED]@host/db``. Otherwise any ``password=...`` setting
        (key/value DSN or query string) is masked.
        """
        import re

        if "@" in conn_str:
            credentials, _, location = conn_str.rpartition("@")
            scheme, sep, _ = credentials.partition("://")
            prefix = f"{scheme}{sep}" if sep else ""
            return f"{prefix}[REDACTED]@{location}"
        return re.sub(r"(password\s*=\s*)(\S+)", r"\1[REDACTED]", conn_str, flags=re.IGNORECASE)

    @staticmethod
    def _get_conn_str() -> str:
        """Return the configured connection string; only a redacted form is logged."""
        conn_str = global_config.POSTGRES_URL
        logger.debug("Using connection string: %s", AsyncPostgresCheckpointer._redact_conn_str(conn_str))
        return conn_str

    @classmethod
    async def close_all(cls):
        """Close every live instance; per-instance errors are logged, not raised."""
        logger.info("Initiating shutdown of all checkpointer instances")

        async with cls._shutdown_lock:
            instances = list(cls._instances)
            if not instances:
                logger.debug("No active instances to close")
                return

            logger.debug("Closing %d active instance(s)", len(instances))
            for instance in instances:
                try:
                    await instance.close()
                    logger.debug("Instance closed successfully")
                except Exception as e:
                    logger.warning(
                        "Error closing instance: %s",
                        str(e),
                        exc_info=logger.isEnabledFor(logging.DEBUG)
                    )

    async def initialize(self) -> None:
        """Create and open the connection pool and build the saver (idempotent).

        Does not create database tables; see :meth:`initialize_schema_with_retry`.

        Raises:
            RuntimeError: If pool creation/opening fails (resources are cleaned up first).
        """
        async with self._lock:
            if self._is_initialized:
                logger.debug("Already initialized, skipping")
                return

            logger.info("Starting connection pool initialization")

            try:
                self.pool = AsyncConnectionPool(
                    self.conn_str,
                    min_size=self.pool_min_size,
                    max_size=self.pool_max_size,
                    timeout=self.pool_timeout,
                    max_lifetime=float(self.pool_recycle),
                    open=False,
                    kwargs={"row_factory": dict_row},
                )
                logger.debug("Connection pool created")

                await self.pool.open()
                logger.debug("Connection pool opened")

                pool_with_dict_conn = cast(AsyncConnectionPool[AsyncConnection[Dict[str, Any]]], self.pool)
                self.saver = AsyncPostgresSaver(pool_with_dict_conn)
                self._is_initialized = True
                logger.info("Checkpointer initialized successfully")

            except Exception as e:
                logger.error("Initialization failed", exc_info=True)
                await self._safe_close()
                raise RuntimeError("Checkpointer initialization failed") from e

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(OperationalError),
        before_sleep=lambda _: logger.warning("Retrying schema initialization..."),
    )
    async def initialize_schema_with_retry(self):
        """Run the store and saver ``setup()`` on one autocommit connection.

        The pool is initialized first if needed. The connection's autocommit
        flag is restored before it goes back to the pool. Retried on
        ``OperationalError``.
        """
        logger.debug("Starting schema initialization")

        if self.pool is None:
            # Create the pool on demand instead of failing with AttributeError on None.
            await self.initialize()

        async with self.pool.connection() as conn:
            try:
                await conn.set_autocommit(True)
                logger.debug("Autocommit enabled for schema setup")

                await AsyncPostgresStore(conn).setup()
                await AsyncPostgresSaver(conn).setup()
                logger.debug("Schema setup completed")

            finally:
                await conn.set_autocommit(False)
                logger.debug("Autocommit restored to default")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type(OperationalError),
    )
    async def checkpointer(self) -> AsyncPostgresSaver | None:
        """Return the saver, initializing or rebuilding the pool when needed.

        When already initialized, a ``SELECT 1`` heartbeat is run first. An
        ``OperationalError`` there causes the pool to be closed and recreated.
        """
        if not self._is_initialized or self.saver is None:
            logger.debug("Checkpointer not ready, initializing...")
            await self.initialize()
            return self.saver

        try:
            async with self.pool.connection() as conn:
                await conn.execute("SELECT 1")  # simple heartbeat check
        except OperationalError:
            logger.warning("Connection pool invalid, reinitializing...")
            await self._safe_close()
            await self.initialize()

        return self.saver

    async def _safe_close(self) -> None:
        """Close the pool if open and reset state; pool-close errors are only logged."""
        logger.debug("Starting safe cleanup")

        if self.pool:
            try:
                if not self.pool.closed:
                    await self.pool.close()
            except Exception as e:
                logger.warning("Error closing pool with an error %s", str(e), exc_info=True)

        self.pool = None
        self.saver = None
        self._is_initialized = False
        logger.info("Resources cleaned up")

    async def close(self) -> None:
        """Close all connections, release resources and unregister from close_all()."""
        async with self._lock:
            if not self._is_initialized:
                logger.debug("Already closed, skipping")
                return

            logger.info("Starting graceful shutdown")

            try:
                if self.pool and not self.pool.closed:
                    logger.debug("Closing connection pool gracefully")
                    await self.pool.close()
                    logger.info("Connection pool closed gracefully")

            except Exception as e:
                logger.error(
                    "Error during shutdown: %s",
                    str(e),
                    exc_info=logger.isEnabledFor(logging.DEBUG)
                )
            finally:
                await self._safe_close()
                AsyncPostgresCheckpointer._instances.discard(self)
                logger.debug("Instance unregistered")

    async def __aenter__(self) -> AsyncPostgresSaver:
        """Initialize and return the saver for use in ``async with``."""
        logger.debug("Entering context manager")
        await self.initialize()
        return await self.checkpointer()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the instance on context exit; exceptions are not suppressed."""
        logger.debug("Exiting context manager")
        await self.close()

    @property
    def is_initialized(self):
        """Whether the pool and saver are currently set up."""
        return self._is_initialized
|
|
File without changes
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Dict
|
|
3
|
+
from agent_api_server.shared.common import process_model_from_config_dict
|
|
4
|
+
from agent_api_server.shared.util_func import load_graph_config, load_graph
|
|
5
|
+
from model_manage_client import ModelManageClient
|
|
6
|
+
|
|
7
|
+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
class AgentRegistry:
    """Registers this server's agents with the model-management service.

    Wraps a ``ModelManageClient``. :meth:`register_all` updates agents the
    service already knows (looked up by ``agent_name``) and registers the rest.

    NOTE(review): the client's ``get_agent`` / ``update_agent`` /
    ``register_agent`` are called without ``await`` from async code. If they
    perform blocking I/O they block the event loop — confirm.
    """

    def __init__(self, base_url: str, client_token: str):
        """Create the underlying ``ModelManageClient``.

        Args:
            base_url: Base URL of the model-management service.
            client_token: Token handed to the client.
        """
        self.client = ModelManageClient(
            base_url=base_url,
            client_token=client_token
        )

    @staticmethod
    async def get_llm_model_from_agent_cfg(agent: Dict) -> Dict:
        """Derive supported-model info for an agent from its graph's context schema.

        Loads the graph config and the graph named ``agent['agent_name']``, then
        passes ``graph_instance.get_context_jsonschema()`` to
        ``process_model_from_config_dict``.

        Args:
            agent: Agent descriptor; only ``agent_name`` is read here.

        Returns:
            The result of ``process_model_from_config_dict`` for that schema.
        """
        graph_cfg = await load_graph_config()
        _, graph_instance, _ = await load_graph(agent['agent_name'], graph_cfg, False)

        cfg = graph_instance.get_context_jsonschema()

        logger.info("Get agent config %s from %s", cfg, agent['agent_name'])

        return process_model_from_config_dict(cfg)

    async def register_all(self, agents: List[Dict]):
        """Register or update every agent in ``agents``, stopping at the first failure.

        Args:
            agents: Agent descriptors. Keys read: ``agent_name``, ``agent_url``,
                ``agent_description``, ``agent_icon_url``, ``agent_api_version``,
                ``agent_features``, ``agent_labels``, ``has_site``,
                ``is_system_agent`` (all required) and ``multilangs`` (optional).

        Raises:
            Exception: Whatever the model lookup or client call raised for the
                failing agent. It is re-raised unchanged after the agent's name
                is logged.
        """
        for agent in agents:
            try:
                await self._register_one(agent)
            except Exception:
                logger.error("Failed to register agent %s", agent.get("agent_name"), exc_info=True)
                raise

    async def _register_one(self, agent: Dict) -> None:
        """Update ``agent`` if the service already has it, otherwise register it."""
        model_info = await self.get_llm_model_from_agent_cfg(agent)

        extra_params = {
            "agent_description": agent["agent_description"],
            "agent_icon_url": agent["agent_icon_url"],
            "agent_api_version": agent["agent_api_version"],
            "agent_features": agent["agent_features"],
            "agent_labels": agent["agent_labels"],
            "support_models": model_info,
            "has_site": agent["has_site"],
            "is_system_agent": agent["is_system_agent"],
            "multilangs": agent.get("multilangs", {})
        }

        agent_find = self.client.get_agent(agent_name=agent["agent_name"])
        if agent_find:
            logger.info(
                "agent: %s already register successfully, update agent with extra_params %s",
                agent['agent_name'], extra_params,
            )
            self.client.update_agent(
                agent_name=agent["agent_name"],
                agent_url=agent["agent_url"],
                **extra_params
            )
            return

        self.client.register_agent(
            agent_name=agent["agent_name"],
            agent_id=agent["agent_name"],
            agent_url=agent["agent_url"],
            **extra_params
        )

        logger.info(
            "agent: %s register successfully, register agent with extra_params %s",
            agent['agent_name'], extra_params,
        )
|