agent-api-server 2.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. agent_api_server/__init__.py +0 -0
  2. agent_api_server/api/__init__.py +0 -0
  3. agent_api_server/api/v1/__init__.py +0 -0
  4. agent_api_server/api/v1/api.py +25 -0
  5. agent_api_server/api/v1/config.py +57 -0
  6. agent_api_server/api/v1/graph.py +59 -0
  7. agent_api_server/api/v1/schema.py +57 -0
  8. agent_api_server/api/v1/thread.py +563 -0
  9. agent_api_server/cache/__init__.py +0 -0
  10. agent_api_server/cache/redis_cache.py +385 -0
  11. agent_api_server/callback_handler.py +18 -0
  12. agent_api_server/client/css/styles.css +1202 -0
  13. agent_api_server/client/favicon.ico +0 -0
  14. agent_api_server/client/index.html +102 -0
  15. agent_api_server/client/js/app.js +1499 -0
  16. agent_api_server/client/js/index.umd.js +824 -0
  17. agent_api_server/config_center/config_center.py +239 -0
  18. agent_api_server/configs/__init__.py +3 -0
  19. agent_api_server/configs/config.py +163 -0
  20. agent_api_server/dynamic_llm/__init__.py +0 -0
  21. agent_api_server/dynamic_llm/dynamic_llm.py +331 -0
  22. agent_api_server/listener.py +530 -0
  23. agent_api_server/log/__init__.py +0 -0
  24. agent_api_server/log/formatters.py +122 -0
  25. agent_api_server/log/logging.json +50 -0
  26. agent_api_server/mcp_convert/__init__.py +0 -0
  27. agent_api_server/mcp_convert/mcp_convert.py +375 -0
  28. agent_api_server/memeory/__init__.py +0 -0
  29. agent_api_server/memeory/postgres.py +233 -0
  30. agent_api_server/register/__init__.py +0 -0
  31. agent_api_server/register/register.py +65 -0
  32. agent_api_server/service.py +354 -0
  33. agent_api_server/service_hub/service_hub.py +233 -0
  34. agent_api_server/service_hub/service_hub_test.py +700 -0
  35. agent_api_server/shared/__init__.py +0 -0
  36. agent_api_server/shared/ase.py +54 -0
  37. agent_api_server/shared/base_model.py +103 -0
  38. agent_api_server/shared/common.py +110 -0
  39. agent_api_server/shared/decode_token.py +107 -0
  40. agent_api_server/shared/detect_message.py +410 -0
  41. agent_api_server/shared/get_model_info.py +491 -0
  42. agent_api_server/shared/message.py +419 -0
  43. agent_api_server/shared/util_func.py +372 -0
  44. agent_api_server/sso_service/__init__.py +1 -0
  45. agent_api_server/sso_service/sdk/__init__.py +1 -0
  46. agent_api_server/sso_service/sdk/client.py +224 -0
  47. agent_api_server/sso_service/sdk/credential.py +11 -0
  48. agent_api_server/sso_service/sdk/encoding.py +22 -0
  49. agent_api_server/sso_service/sso_service.py +177 -0
  50. agent_api_server-2.1.7.dist-info/METADATA +130 -0
  51. agent_api_server-2.1.7.dist-info/RECORD +52 -0
  52. agent_api_server-2.1.7.dist-info/WHEEL +4 -0
File without changes
@@ -0,0 +1,375 @@
1
+ import os
2
+ import inspect
3
+ import logging
4
+ import json
5
+ from typing import Dict, Any, Optional, List, Callable, Awaitable
6
+ from functools import wraps, partial
7
+ from datetime import datetime
8
+ from langfuse import propagate_attributes
9
+ from langfuse.langchain import CallbackHandler
10
+ from agent_api_server.shared.decode_token import decode_jwt
11
+ from langgraph.config import get_stream_writer
12
+ from langgraph.graph.state import CompiledStateGraph
13
+ from fastmcp import FastMCP, Context
14
+ from mcp import types
15
+ from mcp.types import RequestParams, CallToolResult, TextContent
16
+ from agent_api_server.shared.message import handle_stream_event
17
+ from agent_api_server.shared.util_func import load_graph_config, load_graph, get_env
18
+
19
logger = logging.getLogger(__name__)


async def handle_messages(message):
    """Forward MCP progress notifications to the LangGraph stream writer.

    Only ``types.ServerNotification`` messages whose ``root`` is a
    ``types.ProgressNotification`` are handled; every other message is
    ignored. For a progress notification, ``params.message`` is logged and
    passed to the stream writer returned by ``get_stream_writer()``.

    Args:
        message: An incoming MCP session message of any type.
    """
    if not isinstance(message, types.ServerNotification):
        return

    notification = message.root
    if not isinstance(notification, types.ProgressNotification):
        return

    params = notification.params
    logger.info("receive progress message: (%s)", params.message)
    # Fetch the writer only when there is something to write. It belongs to
    # the active LangGraph run, so looking it up for unrelated messages is
    # wasted work and may fail when no graph run is active.
    writer = get_stream_writer()
    writer(params.message)
30
+
31
async def create_mcp_tool_from_agent() -> FastMCP:
    """Create a FastMCP server exposing every configured LangGraph graph as a tool.

    The graph configuration is read with ``load_graph_config()``. For each key
    of its ``"graphs"`` mapping the graph is loaded, checked to be a
    ``CompiledStateGraph`` and registered as an MCP tool. The tool's
    parameters mirror the graph's input JSON schema. Its description comes
    from the optional ``"agent_description"`` mapping.

    Returns:
        FastMCP: Configured MCP instance with all graph tools registered.

    Raises:
        RuntimeError: If tool creation fails due to configuration or loading
            errors. The underlying exception is chained as ``__cause__``.
    """
    mcp = FastMCP(name="langgraph_mcp_tool")

    try:
        graph_cfg = await load_graph_config()

        # Validate configurations
        graphs = graph_cfg.get("graphs", {})
        if not isinstance(graphs, dict):
            raise RuntimeError(f"Invalid graphs config. Expected dict, got {type(graphs)}")

        agent_descriptions = graph_cfg.get("agent_description", {})
        if not isinstance(agent_descriptions, dict):
            logger.warning("agent_description should be dict, got %s", type(agent_descriptions))
            agent_descriptions = {}

        # Only the names are needed: load_graph receives the whole graph_cfg.
        for graph_name in graphs:
            _, graph_instance, _ = await load_graph(graph_name, graph_cfg, False)

            if not isinstance(graph_instance, CompiledStateGraph):
                raise RuntimeError(f"Graph {graph_name} is not CompiledStateGraph")

            # Bind the per-graph context; the tool-call arguments arrive later
            # as **kwargs of _execute_graph_tool.
            tool_impl = partial(
                _execute_graph_tool,
                mcp=mcp,
                graph_name=graph_name,
                graph_instance=graph_instance,
            )
            tool_impl.__name__ = graph_name

            tool = create_tool_from_schema(
                schema=graph_instance.get_input_jsonschema(),
                func_name=graph_name,
                func_doc=agent_descriptions.get(graph_name, ""),
                implementation=tool_impl,
            )

            mcp.tool()(_add_metrics_wrapper(tool, graph_name))
            logger.info("Registered tool: %s", graph_name)

        return mcp
    except Exception as e:
        logger.exception("MCP tool creation failed")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"MCP tool creation failed: {str(e)}") from e
83
+
84
+
85
# Lower-case substrings of configurable keys whose values must never be logged.
_SENSITIVE_KEY_MARKERS = ("token", "secret", "password", "key")


def _redact_config(params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of *params* with credential-like values masked.

    A value is replaced by ``"***"`` when its key contains any entry of
    ``_SENSITIVE_KEY_MARKERS`` (case-insensitive), e.g. ``EIToken``.
    """
    redacted = {}
    for name, value in params.items():
        lowered = str(name).lower()
        is_sensitive = any(marker in lowered for marker in _SENSITIVE_KEY_MARKERS)
        redacted[name] = "***" if is_sensitive else value
    return redacted


async def _execute_graph_tool(
    mcp: FastMCP,
    graph_name: str,
    graph_instance: CompiledStateGraph,
    **kwargs: Any
) -> Dict[str, Any]:
    """Execute a LangGraph agent for one MCP tool call and stream its progress.

    The ``UseSysLLM``, ``TSTenant``, ``Authorization`` and ``thread_id``
    headers of the current HTTP request are forwarded to the graph through
    ``config["configurable"]`` (merged with ``get_env(ts_tenant=...)`` when a
    tenant is given). Every chunk produced by ``handle_stream_event`` is
    reported to the client as a progress notification. The
    ``update_content.content`` of the last chunk becomes the tool result.

    Args:
        mcp: FastMCP instance
        graph_name: Name of the graph to execute
        graph_instance: Compiled graph instance
        **kwargs: Tool arguments, validated against the graph's input schema.

    Returns:
        A ``CallToolResult`` (despite the ``Dict`` annotation). On success it
        holds the final content; on validation or execution failure it has
        ``isError=True`` and a message.
    """
    ctx = Context(fastmcp=mcp)
    from fastmcp.server.dependencies import get_http_request
    request = get_http_request()

    use_sys_llm = request.headers.get("UseSysLLM", "")
    ts_tenant = request.headers.get("TSTenant", "")
    ei_token = request.headers.get("Authorization", "")
    thread_id = request.headers.get('thread_id', '')
    start_time = datetime.now()

    # Never log the raw Authorization header, only whether one was supplied.
    logger.info(
        "Executing graph '%s' for tenant '%s', ei_token present: %s, use_sys_llm is %s",
        graph_name, ts_tenant, bool(ei_token), use_sys_llm,
    )

    # Validate input against the graph's input JSON schema.
    schema = graph_instance.get_input_jsonschema()
    required_fields = schema.get("required", [])
    input_dict = {}
    validation_errors = []

    for field, prop in schema.get("properties", {}).items():
        if field in kwargs:
            try:
                input_dict[field] = _validate_field(field, kwargs[field], prop)
            except ValueError as e:
                validation_errors.append(str(e))
        elif field in required_fields:
            validation_errors.append(f"Missing required field: {field}")

    if validation_errors:
        error_msg = "Validation errors:\n" + "\n".join(validation_errors)
        logger.error(error_msg)
        return CallToolResult(
            content=[TextContent(type="text", text=error_msg)],
            isError=True
        )

    # Prefer an explicit thread id header; otherwise fall back to the MCP session.
    if thread_id:
        logger.info("execute graph %s with thread id %s which get from context headers", graph_name, thread_id)
    else:
        thread_id = ctx.session_id
        logger.info("execute graph %s with session id %s", graph_name, ctx.session_id)

    user_id = None
    if ei_token:
        try:
            _, _, user_info = decode_jwt(ei_token)
            if user_info and 'id' in user_info:
                user_id = user_info['id']
                logger.info("Extracted user_id from token: %s", user_id)
            else:
                logger.warning("No user_id found in JWT token")
        except Exception as e:
            # Best effort: tracing proceeds without a user id.
            logger.warning("Failed to decode JWT token: %s", str(e))

    # Build configurable parameters
    configurable_params = {
        "use_sys_llm": use_sys_llm,
        "thread_id": thread_id,
        "TSTenant": ts_tenant,
        "EIToken": ei_token,
        "graph_name": graph_name
    }

    if ts_tenant:
        # Tenant environment goes first so request-derived values win on key clashes.
        configurable_params = {
            **dict(get_env(ts_tenant=ts_tenant)),
            **configurable_params
        }

    # EIToken and other token/secret/key-like entries are masked in the log.
    logger.info("finally configurable parameter is %s", _redact_config(configurable_params))

    # Configure Langfuse callbacks only when all credentials are present.
    langfuse_keys = [
        os.getenv("LANGFUSE_SECRET_KEY"),
        os.getenv("LANGFUSE_PUBLIC_KEY"),
        os.getenv("LANGFUSE_BASE_URL")
    ]
    trace_name = f"{graph_name}_MCP_Call"

    callbacks_config = {}
    if all(langfuse_keys):
        callbacks_config = {"callbacks": [CallbackHandler()], "run_name": trace_name}

    # Trace attributes; user_id is attached only when the token yielded one.
    trace_attributes = {"session_id": thread_id, "trace_name": trace_name}
    if user_id:
        trace_attributes["user_id"] = user_id

    chunks = []
    config = {"configurable": configurable_params, **callbacks_config}
    try:
        async def process_stream():
            if user_id:
                logger.info("execute graph %s with user_id: %s", graph_name, user_id)
            else:
                logger.info("execute graph %s without user_id", graph_name)
            with propagate_attributes(**trace_attributes):
                async for stream_event in graph_instance.astream(
                    input_dict,
                    config=config,
                    stream_mode=["updates"],
                    subgraphs=True
                ):
                    async for chunk in handle_stream_event(stream_event):
                        yield chunk

        async for chunk in process_stream():
            # Presumably attaches a progress token so report_progress has one
            # to report against even if the client sent none -- confirm.
            ctx.request_context.meta = RequestParams.Meta(progressToken=ctx.request_id)
            await ctx.report_progress(message=chunk, progress=len(chunks))
            chunks.append(chunk)

        if not chunks:
            raise ValueError("No response from graph execution")

        # Chunks are "data: <json>" lines; the last one carries the final answer.
        last_chunk = chunks[-1].strip()
        if not last_chunk.startswith("data: "):
            raise ValueError("Invalid response format")

        data = json.loads(last_chunk[6:])
        update_content = data.get("update_content") if isinstance(data, dict) else None
        content = update_content.get("content") if isinstance(update_content, dict) else None
        if not content:
            raise ValueError("Missing content in response")

        logger.info(
            "Graph '%s' executed in %.2fs (%d chunks)",
            graph_name,
            (datetime.now() - start_time).total_seconds(),
            len(chunks)
        )

        return CallToolResult(
            content=[TextContent(type="text", text=content)],
            isError=False
        )

    except Exception as err:
        logger.error("Graph execution failed: %s", str(err), exc_info=True)
        return CallToolResult(
            content=[TextContent(type="text", text=f"Execution failed: {str(err)}")],
            isError=True
        )
246
+
247
+
248
def create_tool_from_schema(
    schema: Dict[str, Any],
    func_name: str,
    func_doc: str = "",
    implementation: Optional[Callable[..., Awaitable[Dict[str, Any]]]] = None
) -> Callable[..., Awaitable[Dict[str, Any]]]:
    """Give an async callable a signature derived from a JSON schema.

    Each entry of ``schema["properties"]`` becomes a parameter annotated with
    the matching Python type (``Any`` for unknown or missing types). Names in
    ``schema["required"]`` get no default; all others default to ``None``.
    Required parameters are placed before optional ones, so the signature is
    valid whatever order the schema lists them in.

    *implementation* is modified in place: its ``__name__``, ``__qualname__``,
    ``__doc__``, ``__signature__`` and ``__annotations__`` are overwritten.
    It is then returned.

    Args:
        schema: JSON schema with optional ``properties``, ``required`` and
            ``title`` keys.
        func_name: Name given to the callable.
        func_doc: Docstring; defaults to ``"Tool from schema: <title>"``.
        implementation: Async callable that receives the tool arguments as
            keyword arguments. When ``None``, a placeholder that only checks
            for missing required fields is used.

    Returns:
        The (mutated) implementation callable.
    """
    properties = schema.get("properties", {})
    required = schema.get("required", [])
    type_map = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": List,
        "object": Dict
    }

    def _python_type(json_type: Any) -> Any:
        """Map a JSON-schema ``type`` (a string or a list of strings) to a Python type."""
        if isinstance(json_type, list):
            # e.g. ["integer", "null"]: use the first non-null entry.
            json_type = next((t for t in json_type if t != "null"), None)
        if not isinstance(json_type, str):
            return Any
        return type_map.get(json_type, Any)

    params = []
    type_hints = {}
    for param_name, param_info in properties.items():
        param_type = _python_type(param_info.get("type"))

        # Optional fields are often expressed as anyOf [<type>, null];
        # annotate them with the first non-null alternative.
        if param_name not in required and 'anyOf' in param_info:
            for option in param_info['anyOf']:
                option_type = option.get('type')
                if option_type and option_type != 'null':
                    param_type = _python_type(option_type)
                    break

        type_hints[param_name] = param_type
        params.append(
            inspect.Parameter(
                param_name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=None if param_name not in required else inspect.Parameter.empty,
                annotation=param_type,
            )
        )

    # inspect.Signature rejects a parameter without a default after one with
    # a default. Move required parameters first; the stable sort keeps the
    # schema order within each group.
    params.sort(key=lambda p: p.default is not inspect.Parameter.empty)

    if implementation is None:
        async def default_impl(**kwargs: Any) -> Dict[str, Any]:
            missing = [f for f in required if f not in kwargs]
            if missing:
                raise ValueError(f"Missing fields: {missing}")
            return {"content": [f"Default response for {func_name}"]}

        implementation = default_impl

    implementation.__name__ = func_name
    implementation.__qualname__ = func_name
    implementation.__doc__ = func_doc or f"Tool from schema: {schema.get('title', '')}"
    implementation.__signature__ = inspect.Signature(
        parameters=params,
        return_annotation=Dict[str, Any],
    )
    implementation.__annotations__ = type_hints

    return implementation
312
+
313
+
314
def _add_metrics_wrapper(
    func: Callable[..., Awaitable[Dict[str, Any]]],
    tool_name: str
) -> Callable[..., Awaitable[Dict[str, Any]]]:
    """Wrap an async tool so each call logs its start, its duration and any failure.

    Metadata of *func* is copied onto the wrapper with ``functools.wraps``.
    Exceptions are logged with their traceback and then re-raised unchanged.

    Args:
        func: The async callable to instrument.
        tool_name: Name used in the log messages.

    Returns:
        An async callable with the same call interface as *func*.
    """

    def _seconds_since(began: datetime) -> float:
        return (datetime.now() - began).total_seconds()

    @wraps(func)
    async def timed_call(*args: Any, **kwargs: Any) -> Dict[str, Any]:
        began = datetime.now()
        logger.info("Tool '%s' started", tool_name)

        try:
            outcome = await func(*args, **kwargs)
            logger.info("Tool '%s' completed in %.2fs", tool_name, _seconds_since(began))
            return outcome
        except Exception as exc:
            logger.error(
                "Tool '%s' failed after %.2fs: %s",
                tool_name,
                _seconds_since(began),
                str(exc),
                exc_info=True,
            )
            raise

    return timed_call
343
+
344
+
345
+ def _validate_field(
346
+ field_name: str,
347
+ value: Any,
348
+ schema: Dict[str, Any]
349
+ ) -> Any:
350
+ field_type = schema.get("type")
351
+
352
+ try:
353
+ if field_type == "string":
354
+ if not isinstance(value, str):
355
+ value = str(value)
356
+ if "enum" in schema and value not in schema["enum"]:
357
+ raise ValueError(f"Value not in {schema['enum']}")
358
+ elif field_type == "number":
359
+ value = float(value)
360
+ elif field_type == "integer":
361
+ value = int(value)
362
+ elif field_type == "boolean":
363
+ if isinstance(value, str):
364
+ value = value.lower() in ("true", "1", "yes")
365
+ value = bool(value)
366
+ elif field_type == "array":
367
+ if not isinstance(value, list):
368
+ raise ValueError("Expected list")
369
+ elif field_type == "object":
370
+ if not isinstance(value, dict):
371
+ raise ValueError("Expected dict")
372
+ except (ValueError, TypeError) as e:
373
+ raise ValueError(f"Invalid value for '{field_name}': {str(e)}")
374
+
375
+ return value
File without changes
@@ -0,0 +1,233 @@
1
+ import os
2
+ import asyncio
3
+ import logging
4
+ import weakref
5
+ from typing import Optional, cast, Dict, Any
6
+ from psycopg_pool import AsyncConnectionPool
7
+ from psycopg.rows import dict_row
8
+ from psycopg import OperationalError, AsyncConnection
9
+ from agent_api_server.configs import global_config
10
+ from tenacity import (
11
+ retry,
12
+ stop_after_attempt,
13
+ wait_exponential,
14
+ retry_if_exception_type,
15
+ )
16
+ from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
17
+ from langgraph.store.postgres import AsyncPostgresStore
18
+
19
logger = logging.getLogger(__name__)


class AsyncPostgresCheckpointer:
    """Own a psycopg ``AsyncConnectionPool`` and the LangGraph
    ``AsyncPostgresSaver`` built on it, one instance per worker process.

    Class-level registries:

    * ``_instances``: weak set of live instances, closed by ``close_all()``.
    * ``_worker_instances``: the instance for each worker PID, returned by
      ``get_worker_instance()``. It holds strong references on purpose. The
      saver handed out keeps only the pool alive, so with weak values the
      instance was collected as soon as the caller dropped it. Every later
      lookup then built a brand-new pool while the previous one stayed open.

    The instance can also be used as an async context manager that yields
    the saver and closes the pool on exit.
    """

    _instances = weakref.WeakSet()
    _shutdown_lock = asyncio.Lock()
    _worker_instances: Dict[int, "AsyncPostgresCheckpointer"] = {}

    def __init__(self, max_retries: int = 3, retry_delay: float = 1.0):
        """Read pool settings from ``global_config`` and register the instance.

        No connection is opened here; call ``initialize()`` or use the
        instance as an async context manager.

        Args:
            max_retries: Stored with a minimum of 1. Not read anywhere in this
                class; the retry decorators use fixed attempt counts.
            retry_delay: Stored with a minimum of 0.1 s. Not read anywhere in
                this class.
        """
        self.worker_pid = os.getpid()
        logger.info(
            f"Initializing checkpointer for worker {self.worker_pid} "
            f"(max_retries={max_retries}, retry_delay={retry_delay:.1f}s)"
        )

        self.conn_str = self._get_conn_str()
        self.pool: Optional[AsyncConnectionPool] = None
        self.saver: Optional[AsyncPostgresSaver] = None
        self.max_retries = max(max_retries, 1)
        self.retry_delay = max(retry_delay, 0.1)
        self._lock = asyncio.Lock()
        self._is_initialized = False

        # Connection pool configuration with enforced lower bounds.
        self.pool_min_size = max(1, int(global_config.POSTGRES_POOL_MIN_SIZE))
        self.pool_max_size = max(self.pool_min_size, int(global_config.POSTGRES_POOL_MAX_SIZE))
        self.pool_timeout = max(1.0, float(global_config.POSTGRES_POOL_TIMEOUT))
        # Maximum connection age in seconds, passed to the pool as max_lifetime.
        self.pool_recycle = max(60, int(global_config.POSTGRES_POOL_RECYCLE))

        logger.debug(
            f"Worker {self.worker_pid} pool config: "
            f"min_size={self.pool_min_size}, max_size={self.pool_max_size}, "
            f"timeout={self.pool_timeout:.1f}, recycle={self.pool_recycle}"
        )

        AsyncPostgresCheckpointer._instances.add(self)
        AsyncPostgresCheckpointer._worker_instances[self.worker_pid] = self
        logger.debug(f"New checkpointer instance registered for worker {self.worker_pid}")

    @classmethod
    def get_worker_instance(cls) -> 'AsyncPostgresCheckpointer':
        """Get or create the instance for the current worker process.

        The instance stays registered (and alive) until
        ``close_worker_instance()`` removes it or a newer instance is
        constructed in the same process.
        """
        worker_pid = os.getpid()
        instance = cls._worker_instances.get(worker_pid)
        if instance is None:
            # __init__ already registers the new instance under this PID.
            instance = cls()
            cls._worker_instances[worker_pid] = instance
            logger.debug(f"Created new checkpointer instance for worker {worker_pid}")
        return instance

    @classmethod
    async def close_worker_instance(cls):
        """Unregister and close the instance for the current worker process, if any."""
        instance = cls._worker_instances.pop(os.getpid(), None)
        if instance is not None:
            await instance.close()

    @staticmethod
    def _get_conn_str() -> str:
        """Return ``global_config.POSTGRES_URL``, logging it with credentials redacted."""
        import re

        conn_str = global_config.POSTGRES_URL
        if '@' in conn_str:
            # URL form scheme://user:password@host/db: keep the scheme and the
            # part after the last '@' (host/db), hide the credentials.
            credentials, host_part = conn_str.rsplit('@', 1)
            scheme, sep, _ = credentials.partition('://')
            prefix = f"{scheme}{sep}" if sep else ""
            safe_conn_str = f"{prefix}[REDACTED]@{host_part}"
        else:
            # Key/value DSN form: mask any password=... entry.
            safe_conn_str = re.sub(
                r"(password\s*=\s*)\S+", r"\1[REDACTED]", conn_str, flags=re.IGNORECASE
            )
        logger.debug("Using connection string: %s", safe_conn_str)
        return conn_str

    @classmethod
    async def close_all(cls):
        """Close every live instance; errors are logged, not raised."""
        logger.info("Initiating shutdown of all checkpointer instances")

        async with cls._shutdown_lock:
            instances = list(cls._instances)
            if not instances:
                logger.debug("No active instances to close")
                return

            logger.debug("Closing %d active instance(s)", len(instances))
            for instance in instances:
                try:
                    await instance.close()
                    logger.debug("Instance closed successfully")
                except Exception as e:
                    logger.warning(
                        "Error closing instance: %s",
                        str(e),
                        exc_info=logger.isEnabledFor(logging.DEBUG)
                    )

    async def initialize(self) -> None:
        """Open the connection pool and create the saver (idempotent).

        Raises:
            RuntimeError: If the pool cannot be created or opened; resources
                are cleaned up first and the original error is chained.
        """
        async with self._lock:
            if self._is_initialized:
                logger.debug("Already initialized, skipping")
                return

            logger.info("Starting connection pool initialization")

            try:
                self.pool = AsyncConnectionPool(
                    self.conn_str,
                    min_size=self.pool_min_size,
                    max_size=self.pool_max_size,
                    timeout=self.pool_timeout,
                    # Previously computed but never applied: recycle
                    # connections older than POSTGRES_POOL_RECYCLE seconds.
                    max_lifetime=float(self.pool_recycle),
                    open=False,
                    kwargs={"row_factory": dict_row},
                )
                logger.debug("Connection pool created")

                await self.pool.open()
                logger.debug("Connection pool opened")

                pool_with_dict_conn = cast(AsyncConnectionPool[AsyncConnection[Dict[str, Any]]], self.pool)
                self.saver = AsyncPostgresSaver(pool_with_dict_conn)
                self._is_initialized = True
                logger.info("Checkpointer initialized successfully")

            except Exception as e:
                logger.error("Initialization failed", exc_info=True)
                await self._safe_close()
                raise RuntimeError("Checkpointer initialization failed") from e

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(OperationalError),
        before_sleep=lambda _: logger.warning("Retrying schema initialization..."),
    )
    async def initialize_schema_with_retry(self):
        """Run the store and checkpointer ``setup()`` migrations on one pooled connection.

        Retried up to 3 attempts on ``OperationalError``. The pool is
        initialized first if that has not happened yet.
        """
        logger.debug("Starting schema initialization")

        if self.pool is None:
            await self.initialize()

        async with self.pool.connection() as conn:
            try:
                # Setup statements run outside an explicit transaction.
                await conn.set_autocommit(True)
                logger.debug("Autocommit enabled for schema setup")

                await AsyncPostgresStore(conn).setup()
                await AsyncPostgresSaver(conn).setup()
                logger.debug("Schema setup completed")

            finally:
                await conn.set_autocommit(False)
                logger.debug("Autocommit restored to default")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type(OperationalError),
    )
    async def checkpointer(self) -> AsyncPostgresSaver | None:
        """Return the saver, initializing lazily.

        When already initialized, a ``SELECT 1`` heartbeat is run first. If it
        fails with ``OperationalError`` the pool is torn down and rebuilt.
        """
        if not self._is_initialized or self.saver is None:
            logger.debug("Checkpointer not ready, initializing...")
            await self.initialize()
            return self.saver

        try:
            async with self.pool.connection() as conn:
                await conn.execute("SELECT 1")  # simple heartbeat check
        except OperationalError:
            logger.warning("Connection pool invalid, reinitializing...")
            await self._safe_close()
            await self.initialize()

        return self.saver

    async def _safe_close(self) -> None:
        """Close the pool (errors are logged, not raised) and reset all state.

        Does not acquire ``self._lock`` itself.
        """
        logger.debug("Starting safe cleanup")

        if self.pool:
            try:
                if not self.pool.closed:
                    await self.pool.close()
            except Exception as e:
                logger.warning(f"Error closing pool with an error {str(e)}", exc_info=True)

        self.pool = None
        self.saver = None
        self._is_initialized = False
        logger.info("Resources cleaned up")

    async def close(self) -> None:
        """Close all connections, clean up resources and unregister from ``_instances``."""
        async with self._lock:
            if not self._is_initialized:
                logger.debug("Already closed, skipping")
                return

            logger.info("Starting graceful shutdown")

            try:
                if self.pool and not self.pool.closed:
                    logger.debug("Closing connection pool gracefully")
                    await self.pool.close()
                    logger.info("Connection pool closed gracefully")

            except Exception as e:
                logger.error(
                    "Error during shutdown: %s",
                    str(e),
                    exc_info=logger.isEnabledFor(logging.DEBUG)
                )
            finally:
                await self._safe_close()
                AsyncPostgresCheckpointer._instances.discard(self)
                logger.debug("Instance unregistered")

    async def __aenter__(self) -> AsyncPostgresSaver:
        """Initialize and return the saver."""
        logger.debug("Entering context manager")
        await self.initialize()
        return await self.checkpointer()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the pool; exceptions from the ``async with`` body are not suppressed."""
        logger.debug("Exiting context manager")
        await self.close()

    @property
    def is_initialized(self):
        """Whether the pool and saver are currently set up."""
        return self._is_initialized
File without changes
@@ -0,0 +1,65 @@
1
+ import logging
2
+ from typing import List, Dict
3
+ from agent_api_server.shared.common import process_model_from_config_dict
4
+ from agent_api_server.shared.util_func import load_graph_config, load_graph
5
+ from model_manage_client import ModelManageClient
6
+
7
logger = logging.getLogger(__name__)


class AgentRegistry:
    """Register or update agents with the model-management service.

    For each agent definition, the supported LLM models are derived from the
    agent graph's context JSON schema and sent along with the agent metadata.
    """

    def __init__(self, base_url: str, client_token: str):
        """Create a ``ModelManageClient`` for the service.

        Args:
            base_url: Base URL passed to ``ModelManageClient``.
            client_token: Token passed to ``ModelManageClient``.
        """
        self.client = ModelManageClient(
            base_url=base_url,
            client_token=client_token
        )

    @staticmethod
    async def get_llm_model_from_agent_cfg(agent: Dict) -> Dict:
        """Return the model info for the graph named ``agent["agent_name"]``.

        The graph's ``get_context_jsonschema()`` result is passed through
        ``process_model_from_config_dict``.
        """
        graph_cfg = await load_graph_config()
        _, graph_instance, _ = await load_graph(agent['agent_name'], graph_cfg, False)

        cfg = graph_instance.get_context_jsonschema()

        logger.info("Get agent config %s from %s", cfg, agent['agent_name'])

        return process_model_from_config_dict(cfg)

    async def register_all(self, agents: List[Dict]):
        """Register every agent, or update it when the service already knows it.

        Processing stops at the first failure. The failing agent is logged
        and the original exception is re-raised unchanged.

        NOTE(review): the ``self.client`` calls are not awaited, so they look
        synchronous and would block the event loop while running. Confirm.
        """
        for agent in agents:
            try:
                model_info = await self.get_llm_model_from_agent_cfg(agent)

                extra_params = {
                    "agent_description": agent["agent_description"],
                    "agent_icon_url": agent["agent_icon_url"],
                    "agent_api_version": agent["agent_api_version"],
                    "agent_features": agent["agent_features"],
                    "agent_labels": agent["agent_labels"],
                    "support_models": model_info,
                    "has_site": agent["has_site"],
                    "is_system_agent": agent["is_system_agent"],
                    "multilangs": agent.get("multilangs", {})
                }

                agent_find = self.client.get_agent(agent_name=agent["agent_name"])
                if agent_find:
                    logger.info(
                        "agent: %s already register successfully, update agent with extra_params %s",
                        agent['agent_name'], extra_params,
                    )

                    self.client.update_agent(
                        agent_name=agent["agent_name"],
                        agent_url=agent["agent_url"],
                        **extra_params
                    )
                    continue

                self.client.register_agent(
                    agent_name=agent["agent_name"],
                    agent_id=agent["agent_name"],
                    agent_url=agent["agent_url"],
                    **extra_params
                )

                logger.info(
                    "agent: %s register successfully, register agent with extra_params %s",
                    agent['agent_name'], extra_params,
                )
            except Exception:
                # Record which agent failed, then propagate the original
                # exception (bare raise keeps the traceback intact).
                logger.exception("Failed to register agent %s", agent.get("agent_name"))
                raise