agent-os-server 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_os_server-0.0.4.dist-info/METADATA +168 -0
- agent_os_server-0.0.4.dist-info/RECORD +11 -0
- agent_os_server-0.0.4.dist-info/WHEEL +5 -0
- agent_os_server-0.0.4.dist-info/entry_points.txt +4 -0
- agent_os_server-0.0.4.dist-info/licenses/LICENSE +0 -0
- agent_os_server-0.0.4.dist-info/top_level.txt +1 -0
- agent_server/__init__.py +0 -0
- agent_server/adk_server.py +372 -0
- agent_server/agui_server.py +69 -0
- agent_server/cmd_server.py +340 -0
- agent_server/utils.py +154 -0
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: agent-os-server
|
|
3
|
+
Version: 0.0.4
|
|
4
|
+
Summary: 智能体服务
|
|
5
|
+
Author-email: fubo <fb_linux@163.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://codeup.aliyun.com/64650c96168f0a5963451dec/agi-next/agent-server/blob/master/README.md
|
|
8
|
+
Project-URL: Repository, https://codeup.aliyun.com/64650c96168f0a5963451dec/agi-next/agent-server.git
|
|
9
|
+
Keywords: agent-server
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Requires-Python: >=3.10
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: agent-os-base>=0.1.2
|
|
17
|
+
Provides-Extra: test
|
|
18
|
+
Requires-Dist: pytest>=6.0; extra == "test"
|
|
19
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
20
|
+
Dynamic: license-file
|
|
21
|
+
|
|
22
|
+
# Agent Server
|
|
23
|
+
|
|
24
|
+
一个基于 Google ADK (Agent Development Kit) 和 FastAPI 构建的智能体服务服务器,用于运行和管理 AI 代理。
|
|
25
|
+
|
|
26
|
+
## 项目概述
|
|
27
|
+
|
|
28
|
+
`agent-server` 提供了多种接口方式来与 AI 代理进行交互,包括 HTTP API、AGUI 接口和命令行界面。
|
|
29
|
+
|
|
30
|
+
## 核心功能
|
|
31
|
+
|
|
32
|
+
### 1. ADK Server
|
|
33
|
+
|
|
34
|
+
基于 FastAPI 的 HTTP API 服务器,使用 Google ADK 的 Runner 来执行代理。
|
|
35
|
+
|
|
36
|
+
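- **功能**: 提供 HTTP API 接口:`POST /session/create`、`POST /chat`、`POST /chat/stream`(SSE 流式输出)、`GET /session/{user_id}/{session_id}/state`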
|
|
37
|
+
- **会话管理**: 支持会话创建和管理
|
|
38
|
+
- **安全**: 包含用户白名单中间件
|
|
39
|
+
- **启动命令**: `agent-server-adk`
|
|
40
|
+
|
|
41
|
+
### 2. AGUI Server
|
|
42
|
+
|
|
43
|
+
提供 AG-UI (Agent-User Interaction Protocol) 接口,用于与前端界面进行交互。
|
|
44
|
+
|
|
45
|
+
- **主要端点**:
|
|
46
|
+
- `POST /agui/api/v1/chat/completions` - 聊天完成接口
|
|
47
|
+
- `GET /agui/api/v1/reload` - 热重载代理配置
|
|
48
|
+
- `GET /agui/api/v1/health` - 健康检查
|
|
49
|
+
- **启动命令**: `agent-server-agui`
|
|
50
|
+
|
|
51
|
+
### 3. CMD Server
|
|
52
|
+
|
|
53
|
+
命令行交互界面,使用 `rich` 库实现彩色输出和 Markdown 渲染。
|
|
54
|
+
|
|
55
|
+
- **功能**:
|
|
56
|
+
- 任务提交和执行
|
|
57
|
+
- 会话管理(列出、删除)
|
|
58
|
+
- 实时输出代理响应
|
|
59
|
+
- 支持工具调用和函数响应展示
|
|
60
|
+
- **启动命令**: `agent-server-cmd`
|
|
61
|
+
|
|
62
|
+
## 技术架构
|
|
63
|
+
|
|
64
|
+
### 依赖
|
|
65
|
+
- **Google ADK**: Agent Development Kit,用于代理运行
|
|
66
|
+
- **FastAPI**: Web 框架,提供 HTTP API
|
|
67
|
+
- **Uvicorn**: ASGI 服务器
|
|
68
|
+
- **Rich**: 命令行美化输出
|
|
69
|
+
- **Python Dotenv**: 环境变量管理
|
|
70
|
+
|
|
71
|
+
### 工具类 (utils.py)
|
|
72
|
+
|
|
73
|
+
- **Utils.Service**: 提供 session 和 memory 服务的创建
|
|
74
|
+
- `memory_service()`: 创建记忆服务
|
|
75
|
+
- `session_service()`: 创建会话服务
|
|
76
|
+
- **Utils.Agent**: 负责代理的加载和运行器创建
|
|
77
|
+
- `load_adk_agents()`: 从指定路径加载代理配置
|
|
78
|
+
- `create_adk_agent_agui_runner()`: 创建 AGUI 运行器
|
|
79
|
+
- `create_adk_agent_adk_runner()`: 创建 ADK 运行器
|
|
80
|
+
|
|
81
|
+
## 配置说明
|
|
82
|
+
|
|
83
|
+
配置文件:`.server.env`(从启动命令所在的当前工作目录读取)
|
|
84
|
+
|
|
85
|
+
```bash
|
|
86
|
+
# 代理配置路径
|
|
87
|
+
AGENTS_PATH=/path/to/agent-os/agents
|
|
88
|
+
|
|
89
|
+
# 会话服务数据库 URI(留空使用内存服务)
|
|
90
|
+
SESSION_SERVICE_URI=
|
|
91
|
+
|
|
92
|
+
# 记忆服务 URI(留空使用内存服务)
|
|
93
|
+
MEMORY_SERVICE_URI=
|
|
94
|
+
|
|
95
|
+
# 用户服务 URI
|
|
96
|
+
USER_SERVICE_URI=
|
|
97
|
+
|
|
98
|
+
# 追踪服务 URI
|
|
99
|
+
TRACE_SERVICE_URI=
|
|
100
|
+
|
|
101
|
+
# 服务器监听地址
|
|
102
|
+
HOST=0.0.0.0
|
|
103
|
+
|
|
104
|
+
# 服务器端口
|
|
105
|
+
PORT=8000
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
## 安装与使用
|
|
109
|
+
|
|
110
|
+
### 安装依赖
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
pip install -e .
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
### 启动服务
|
|
117
|
+
|
|
118
|
+
```bash
|
|
119
|
+
# 启动 AGUI 服务器
|
|
120
|
+
agent-server-agui
|
|
121
|
+
|
|
122
|
+
# 启动命令行界面
|
|
123
|
+
agent-server-cmd
|
|
124
|
+
|
|
125
|
+
# 启动 ADK 服务器
|
|
126
|
+
agent-server-adk
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
## 项目结构
|
|
130
|
+
|
|
131
|
+
```
|
|
132
|
+
agent-server/
|
|
133
|
+
├── src/
|
|
134
|
+
│ └── agent_server/
|
|
135
|
+
│ ├── __init__.py
|
|
136
|
+
│ ├── adk_server.py # ADK HTTP 服务器
|
|
137
|
+
│ ├── agui_server.py # AGUI 服务器
|
|
138
|
+
│ ├── cmd_server.py # 命令行服务器
|
|
139
|
+
│ ├── utils.py # 工具类
|
|
140
|
+
│ └── .server.env # 配置文件
|
|
141
|
+
├── tests/
|
|
142
|
+
├── pyproject.toml
|
|
143
|
+
├── Dockerfile
|
|
144
|
+
└── README.md
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
## 开发信息
|
|
148
|
+
|
|
149
|
+
- **版本**: 0.0.4
|
|
150
|
+
- **作者**: fubo <fb_linux@163.com>
|
|
151
|
+
- **许可证**: MIT
|
|
152
|
+
- **Python 版本**: >= 3.10
|
|
153
|
+
|
|
154
|
+
## 依赖项
|
|
155
|
+
|
|
156
|
+
- `agent-os-base>=0.1.2`
|
|
157
|
+
- `fastapi`
|
|
158
|
+
- `uvicorn`
|
|
159
|
+
- `rich`
|
|
160
|
+
- `python-dotenv`
|
|
161
|
+
- `google-adk`
|
|
162
|
+
|
|
163
|
+
## 测试
|
|
164
|
+
|
|
165
|
+
```bash
|
|
166
|
+
pip install -e ".[test]"
|
|
167
|
+
pytest
|
|
168
|
+
```
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
agent_os_server-0.0.4.dist-info/licenses/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
agent_server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
agent_server/adk_server.py,sha256=-i1gNbVkfiiqZsZ1zwScI6nuzK0UamqbNEU_zNQm2Jc,13488
|
|
4
|
+
agent_server/agui_server.py,sha256=rYcOpmrZII2v_ZrydaWY298dgpkOg3gVDjCv7VZZ9cI,2388
|
|
5
|
+
agent_server/cmd_server.py,sha256=Q1yMKHIiuhgj4MaVuZwE093H4DhTWEEFwB_9mWNBjhQ,12975
|
|
6
|
+
agent_server/utils.py,sha256=vbf3mGyycWzFk1XXE9GdYX_z2qzkWpHGgNqXLiquxhs,5350
|
|
7
|
+
agent_os_server-0.0.4.dist-info/METADATA,sha256=25Pu88H_htUlQvddfJR-2RGMSCd0vtQauzhxBW2YLLM,4054
|
|
8
|
+
agent_os_server-0.0.4.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
9
|
+
agent_os_server-0.0.4.dist-info/entry_points.txt,sha256=SNeJ2OiWVPKWKaJIk_r4fgjeTcb7j-sJGXp1d5bZCzM,164
|
|
10
|
+
agent_os_server-0.0.4.dist-info/top_level.txt,sha256=IgClwJDz-2vEwgH27kHdjOhpnyqvomdkfGgrXBV5Tc4,13
|
|
11
|
+
agent_os_server-0.0.4.dist-info/RECORD,,
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
agent_server
|
agent_server/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import contextlib
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
import uvicorn
|
|
6
|
+
|
|
7
|
+
# Configure the root logger before the remaining imports run so that anything
# logged while the package loads is already written to stderr at INFO level.
logging.basicConfig(
    stream=sys.stderr,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
|
|
12
|
+
|
|
13
|
+
from collections.abc import AsyncIterable
|
|
14
|
+
from dotenv import dotenv_values
|
|
15
|
+
|
|
16
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
17
|
+
from fastapi.sse import EventSourceResponse, ServerSentEvent
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
|
|
20
|
+
from google.adk.runners import Runner
|
|
21
|
+
from google.adk.sessions import DatabaseSessionService
|
|
22
|
+
from google.adk.artifacts import InMemoryArtifactService
|
|
23
|
+
from google.genai import types
|
|
24
|
+
from google.adk.apps import App
|
|
25
|
+
from .utils import Utils
|
|
26
|
+
|
|
27
|
+
# from agents.proxy_agent.agent import root_agent
|
|
28
|
+
|
|
29
|
+
#### ===========
|
|
30
|
+
# logging.basicConfig(
|
|
31
|
+
# level=logging.INFO,
|
|
32
|
+
# format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
|
|
33
|
+
# )
|
|
34
|
+
# logger = logging.getLogger("agent_os")
|
|
35
|
+
#
|
|
36
|
+
#
|
|
37
|
+
# app_name = "agent_os_app"
|
|
38
|
+
# session_service = DatabaseSessionService(
|
|
39
|
+
# db_url="sqlite+aiosqlite:///./.resources/sqlite.db/sessions.db"
|
|
40
|
+
# )
|
|
41
|
+
#
|
|
42
|
+
# Runner(app=App(name=, root_agent=root_agent,plugins=))
|
|
43
|
+
# runner = Runner(
|
|
44
|
+
# app_name=app_name,
|
|
45
|
+
# agent=root_agent,
|
|
46
|
+
# session_service=session_service,
|
|
47
|
+
# artifact_service=InMemoryArtifactService(),
|
|
48
|
+
# auto_create_session=False,
|
|
49
|
+
# )
|
|
50
|
+
|
|
51
|
+
#### ===========
|
|
52
|
+
|
|
53
|
+
# Server settings come from `.server.env`, resolved relative to the current
# working directory. A key that is present but empty (e.g. `PORT=`) or has no
# value at all falls back to the default; previously `PORT=` crashed int("")
# and `HOST=` made the server bind to an empty host string.
config = dotenv_values(".server.env")
server_config = Utils.Server.ServerConfig(
    agents_path=config.get("AGENTS_PATH") or "",
    session_service_uri=config.get("SESSION_SERVICE_URI") or "",
    memory_service_uri=config.get("MEMORY_SERVICE_URI") or "",
    user_service_uri=config.get("USER_SERVICE_URI") or "",
    trace_service_uri=config.get("TRACE_SERVICE_URI") or "",
    host=config.get("HOST") or "0.0.0.0",
    port=int(config.get("PORT") or "8000"),
)
|
|
63
|
+
|
|
64
|
+
class CreateSessionRequest(BaseModel):
    """JSON body accepted by ``POST /session/create``."""

    # Owner of the new session.
    user_id: str
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ChatRequest(BaseModel):
    """JSON body shared by ``POST /chat`` and ``POST /chat/stream``."""

    # Caller identity; sessions are looked up per user.
    user_id: str
    # Session to run in; it is created on demand when it does not exist yet.
    session_id: str
    # Plain-text user message forwarded to the agent.
    message: str
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class Application:
    """HTTP front-end that runs the configured ADK agent through a ``Runner``.

    All members are class attributes evaluated once at import time: the FastAPI
    app (``api``), the ADK runner loaded from ``server_config`` and the runner's
    session service. Route handlers are registered directly on ``api`` and reach
    the shared objects through ``Application.<name>``.
    """

    api = FastAPI(title=f"{Utils.Agent.agent_name} ADK Server")
    # Name under which this server creates and reads sessions.
    # NOTE(review): the runner executes under its own app_name; both must agree
    # or run_async will not find the session. Confirm against
    # Utils.Agent.create_adk_agent_adk_runner.
    agent_app_name = f"{Utils.Agent.agent_name} ADK APP"
    api.add_middleware(Utils.Server.UserWhitelistMiddleware)

    agent_adk_runner = Utils.Agent.create_adk_agent_adk_runner(
        agents_path=server_config.agents_path,
        session_service_uri=server_config.session_service_uri,
        memory_service_uri=server_config.memory_service_uri,
    )
    # Validate before dereferencing. Previously `.session_service` was read
    # first, so a failed load crashed with AttributeError on None instead of
    # reaching this explicit error.
    if agent_adk_runner is None:
        raise Exception("Agent load failed")
    session_service = agent_adk_runner.session_service

    @staticmethod
    def safe_jsonable(value):
        """Best-effort conversion of ``value`` into something JSON-serialisable.

        Tries, in order:
        - ``None`` passthrough.
        - ``model_dump(mode="json")``, e.g. for pydantic v2 models.
        - ``dict()``.
        - Returning JSON-native scalars and containers unchanged.
        - ``str(value)`` as a last resort.

        Containers are returned as-is; their elements are not converted
        recursively.
        """
        if value is None:
            return None

        if hasattr(value, "model_dump"):
            try:
                return value.model_dump(mode="json")
            except Exception:
                pass  # fall through to the next strategy

        if hasattr(value, "dict"):
            try:
                return value.dict()
            except Exception:
                pass  # fall through to the next strategy

        if isinstance(value, (str, int, float, bool, list, dict)):
            return value

        return str(value)

    @staticmethod
    async def ensure_session(user_id: str, session_id: str | None = None):
        """Return an existing session for ``user_id`` or create a new one.

        - ``session_id`` given and found: that session is returned.
        - ``session_id`` given but missing: a session with exactly that id is
          created.
        - No ``session_id`` (or an empty one): the session service chooses the
          id.
        """
        if session_id:
            session = await Application.session_service.get_session(
                app_name=Application.agent_app_name,
                user_id=user_id,
                session_id=session_id,
            )
            if session:
                return session

            return await Application.session_service.create_session(
                app_name=Application.agent_app_name,
                user_id=user_id,
                session_id=session_id,
                state={},
            )

        return await Application.session_service.create_session(
            app_name=Application.agent_app_name,
            user_id=user_id,
            state={},
        )

    @api.post("/session/create", tags=["ADK API"])
    async def create_session(request: CreateSessionRequest):
        """Create a fresh session and return ``{"user_id", "session_id"}``; HTTP 500 on failure."""
        try:
            session = await Application.ensure_session(user_id=request.user_id)
            return {"user_id": request.user_id, "session_id": session.id}
        except Exception as e:
            logging.exception("create_session failed")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @api.post("/chat/stream", response_class=EventSourceResponse, tags=["ADK API"])
    async def chat_stream(request: ChatRequest, http_request: Request) -> AsyncIterable[ServerSentEvent]:
        """Run one user message and stream the agent output as Server-Sent Events.

        Key points:
        1. Do not ``return EventSourceResponse(...)`` here.
        2. The route function itself yields events, following FastAPI's SSE usage.
        3. When there is nothing to send, a comment is emitted periodically so
           the connection never goes silent.

        How it works:
        - The agent runs in a background producer task that pushes events onto
          a queue.
        - This generator drains the queue and sends a ``keep-alive`` comment
          after 10 s without output.
        - If the client disconnects, the producer is cancelled.

        Events sent:
        - ``message`` events carry ``data["type"]`` of ``text``, ``thinking``,
          ``tool_call``, ``tool_response`` or ``state_update``.
        - ``final`` carries the session state.
        - ``error`` reports a failure.
        - ``done`` (raw ``[DONE]``) ends a successful run.
        """
        session = await Application.ensure_session(user_id=request.user_id, session_id=request.session_id)
        # Run against the session actually obtained. Its id equals the
        # requested one, except when the client sent an empty id and a new
        # session was created (running against "" would fail).
        session_id = session.id

        queue: asyncio.Queue[ServerSentEvent | None] = asyncio.Queue()

        def _message_event(data: dict) -> ServerSentEvent:
            # All incremental updates share the SSE event name "message".
            return ServerSentEvent(data=data, event="message")

        async def producer():
            try:
                content = types.Content(
                    role="user",
                    parts=[types.Part(text=request.message)]
                )

                logging.info(
                    "stream start user_id=%s session_id=%s",
                    request.user_id,
                    session_id,
                )

                async for event in Application.agent_adk_runner.run_async(
                    user_id=request.user_id,
                    session_id=session_id,
                    new_message=content,
                ):
                    if event.content and event.content.parts:
                        for part in event.content.parts:
                            if getattr(part, "text", None):
                                await queue.put(_message_event({
                                    "type": "thinking" if getattr(part, "thought", False) else "text",
                                    "content": part.text,
                                }))

                            if getattr(part, "function_call", None) is not None:
                                await queue.put(_message_event({
                                    "type": "tool_call",
                                    "delta": {
                                        "tool_name": part.function_call.name,
                                        "params": Application.safe_jsonable(part.function_call.args),
                                    },
                                }))

                            if getattr(part, "function_response", None) is not None:
                                await queue.put(_message_event({
                                    "type": "tool_response",
                                    "delta": {
                                        "tool_name": part.function_response.name,
                                        "response": Application.safe_jsonable(part.function_response.response),
                                    },
                                }))

                    if event.actions and event.actions.state_delta:
                        await queue.put(_message_event({
                            "type": "state_update",
                            "delta": Application.safe_jsonable(event.actions.state_delta),
                        }))

                    if event.is_final_response():
                        final_session = await Application.session_service.get_session(
                            app_name=Application.agent_app_name,
                            user_id=request.user_id,
                            session_id=session_id,
                        )
                        await queue.put(
                            ServerSentEvent(
                                data={
                                    "type": "final",
                                    "state": Application.safe_jsonable(final_session.state if final_session else None),
                                },
                                event="final",
                            )
                        )

                await queue.put(ServerSentEvent(raw_data="[DONE]", event="done"))
                logging.info(
                    "stream done user_id=%s session_id=%s",
                    request.user_id,
                    session_id,
                )

            except asyncio.CancelledError:
                logging.info(
                    "stream producer cancelled user_id=%s session_id=%s",
                    request.user_id,
                    session_id,
                )
                raise
            except Exception as e:
                logging.exception(
                    "stream producer failed user_id=%s session_id=%s",
                    request.user_id,
                    session_id,
                )
                await queue.put(
                    ServerSentEvent(
                        data={"type": "error", "message": str(e)},
                        event="error",
                    )
                )
            finally:
                # Sentinel telling the consumer loop below that the producer ended.
                await queue.put(None)

        producer_task = asyncio.create_task(producer())

        try:
            # First frame, so the client sees the stream open immediately.
            yield ServerSentEvent(comment="stream opened")

            while True:
                if await http_request.is_disconnected():
                    producer_task.cancel()
                    break

                try:
                    item = await asyncio.wait_for(queue.get(), timeout=10)
                except asyncio.TimeoutError:
                    # Idle keep-alive.
                    yield ServerSentEvent(comment="keep-alive")
                    continue

                if item is None:
                    break

                yield item

        finally:
            if not producer_task.done():
                producer_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await producer_task

    @api.post("/chat", tags=["ADK API"])
    async def chat(request: ChatRequest):
        """Run one user message to completion and return the collected text.

        Returns ``{"response": str, "final_state": ...}``:
        - ``response`` concatenates every text part, including parts flagged as
          thoughts.
        - ``final_state`` is the session state read after the final response.

        Any failure is logged and reported as HTTP 500.
        """
        try:
            session = await Application.ensure_session(user_id=request.user_id, session_id=request.session_id)
            # See chat_stream: run against the session that actually exists.
            session_id = session.id

            text_chunks: list[str] = []
            final_state = None

            content = types.Content(
                role="user",
                parts=[types.Part(text=request.message)]
            )

            async for event in Application.agent_adk_runner.run_async(
                user_id=request.user_id,
                session_id=session_id,
                new_message=content,
            ):
                if event.content and event.content.parts:
                    text_chunks.extend(
                        part.text for part in event.content.parts if getattr(part, "text", None)
                    )

                if event.is_final_response():
                    final_session = await Application.session_service.get_session(
                        app_name=Application.agent_app_name,
                        user_id=request.user_id,
                        session_id=session_id,
                    )
                    final_state = final_session.state if final_session else None

            return {
                "response": "".join(text_chunks),
                "final_state": Application.safe_jsonable(final_state),
            }

        except Exception as e:
            logging.exception("chat failed")
            raise HTTPException(status_code=500, detail=str(e)) from e

    @api.get("/session/{user_id}/{session_id}/state", tags=["ADK API"])
    async def get_session_state(user_id: str, session_id: str):
        """Return ``{"state": ...}`` for an existing session; HTTP 404 if it does not exist."""
        try:
            session = await Application.session_service.get_session(
                app_name=Application.agent_app_name,
                user_id=user_id,
                session_id=session_id,
            )
        except Exception as e:
            logging.exception("get_session_state failed")
            raise HTTPException(status_code=500, detail=str(e)) from e

        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        return {"state": Application.safe_jsonable(session.state)}
|
|
361
|
+
|
|
362
|
+
def main():
    """Configure logging and serve ``Application.api`` with uvicorn.

    The host and port come from the module-level ``server_config``.
    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, which is normally true here because this module configures
    logging at import time.
    """
    log_format = "%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s"
    logging.basicConfig(stream=sys.stderr, level=logging.INFO, format=log_format)
    uvicorn.run(Application.api, port=server_config.port, host=server_config.host)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
if __name__ == "__main__":
    # Start the server when run as a module (e.g. python -m agent_server.adk_server).
    main()
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import logging
|
|
3
|
+
# logging.basicConfig(
|
|
4
|
+
# level=logging.INFO,
|
|
5
|
+
# format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
|
|
6
|
+
# )
|
|
7
|
+
import uvicorn
|
|
8
|
+
from dotenv import dotenv_values
|
|
9
|
+
from fastapi import FastAPI
|
|
10
|
+
from ag_ui_adk import add_adk_fastapi_endpoint
|
|
11
|
+
from .utils import Utils
|
|
12
|
+
|
|
13
|
+
# Server settings come from `.server.env`, resolved relative to the current
# working directory. A key that is present but empty (e.g. `PORT=`) or has no
# value at all falls back to the default; previously `PORT=` crashed int("")
# and `HOST=` made the server bind to an empty host string.
config = dotenv_values(".server.env")
server_config = Utils.Server.ServerConfig(
    agents_path=config.get("AGENTS_PATH") or "",
    session_service_uri=config.get("SESSION_SERVICE_URI") or "",
    memory_service_uri=config.get("MEMORY_SERVICE_URI") or "",
    user_service_uri=config.get("USER_SERVICE_URI") or "",
    trace_service_uri=config.get("TRACE_SERVICE_URI") or "",
    host=config.get("HOST") or "0.0.0.0",
    port=int(config.get("PORT") or "8000"),
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Application:
    """FastAPI application serving the ADK agent through the AG-UI protocol.

    Built once at import time:
    - the FastAPI app;
    - the user-whitelist middleware;
    - the AG-UI runner, registered at ``chat_path`` via
      ``add_adk_fastapi_endpoint``.
    """

    api = FastAPI(title=f"{Utils.Agent.agent_name} AGUI Server")
    api.add_middleware(Utils.Server.UserWhitelistMiddleware)
    # Path of the AG-UI chat endpoint, kept here so reload() can re-register it.
    chat_path = "/agui/api/v1/chat/completions"

    agent_agui_runner = Utils.Agent.create_adk_agent_agui_runner(
        agents_path=server_config.agents_path,
        session_service_uri=server_config.session_service_uri,
        memory_service_uri=server_config.memory_service_uri,
    )
    if agent_agui_runner is None:
        raise Exception("Agent load failed")

    add_adk_fastapi_endpoint(api, agent_agui_runner, path=chat_path)

    @staticmethod
    @api.get("/agui/api/v1/reload")
    async def reload():
        """Reload agents from ``agents_path`` and point the chat endpoint at them.

        Returns ``{"code": 0, "message": "", "data": "OK"}`` on success.
        On failure it returns code 10000 and keeps serving the previously
        loaded runner; before this change, a failed reload overwrote it with
        ``None``.
        """
        runner = Utils.Agent.create_adk_agent_agui_runner(
            agents_path=server_config.agents_path,
            session_service_uri=server_config.session_service_uri,
            memory_service_uri=server_config.memory_service_uri,
        )
        if runner is None:
            logging.error("Failed to load agents")
            return {"code": 10000, "message": "Failed to load agents", "data": ""}

        Application.agent_agui_runner = runner

        # add_adk_fastapi_endpoint was handed the runner object itself, so the
        # route it registered presumably keeps using that object; rebinding the
        # class attribute alone would not change what the endpoint serves.
        # Drop the old route(s) at chat_path and register the endpoint again.
        # NOTE(review): assumes the helper registers routes only at
        # chat_path — confirm against the ag_ui_adk version in use.
        api = Application.api
        api.router.routes[:] = [
            route for route in api.router.routes
            if getattr(route, "path", None) != Application.chat_path
        ]
        add_adk_fastapi_endpoint(api, runner, path=Application.chat_path)
        # Let FastAPI rebuild its cached OpenAPI schema on the next request.
        api.openapi_schema = None

        return {"code": 0, "message": "", "data": "OK"}

    @staticmethod
    @api.get("/agui/api/v1/health")
    async def health_check():
        """Liveness probe; always returns ``{"code": 0, "message": "", "data": "OK"}``."""
        return {"code": 0, "message": "", "data": "OK"}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def main():
    """Configure logging and serve ``Application.api`` with uvicorn.

    The host and port come from the module-level ``server_config``.
    ``logging.basicConfig`` has no effect if the root logger already has
    handlers.
    """
    log_format = "%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s"
    logging.basicConfig(stream=sys.stderr, level=logging.INFO, format=log_format)
    uvicorn.run(Application.api, port=server_config.port, host=server_config.host)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
if __name__ == "__main__":
    # Start the server when run as a module (e.g. python -m agent_server.agui_server).
    main()
|
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import json
|
|
3
|
+
import asyncio
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Optional, Dict, List
|
|
6
|
+
from dotenv import dotenv_values
|
|
7
|
+
# import logging
|
|
8
|
+
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
|
9
|
+
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
from rich.markdown import Markdown
|
|
12
|
+
|
|
13
|
+
from google.adk.runners import Runner
|
|
14
|
+
from google.adk.sessions import BaseSessionService
|
|
15
|
+
from google.genai import types
|
|
16
|
+
|
|
17
|
+
from .utils import Utils
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class JobRecord:
    """Book-keeping for one prompt submitted from the CLI.

    Each job runs in its own ADK session. Every rendered chunk of output is
    appended to ``buffer`` so it can be replayed when the job is brought to
    the foreground.
    """

    job_id: str  # short id the user refers to in /fg, /cancel, /deljob
    session_id: str  # ADK session the job runs in
    prompt: str  # user text that started the job
    task: Optional[asyncio.Task] = None  # asyncio task executing the job
    done: bool = False  # True once the run finished, failed or was cancelled
    error: Optional[str] = None  # "cancelled", or the exception message
    buffer: List[str] = field(default_factory=list)  # rendered markdown chunks
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AgentRunner:
|
|
32
|
+
def __init__(self) -> None:
|
|
33
|
+
self.user_id = "cmd-user"
|
|
34
|
+
self.console = Console()
|
|
35
|
+
config = dotenv_values(".server.env")
|
|
36
|
+
server_config = Utils.Server.ServerConfig(
|
|
37
|
+
agents_path=config.get("AGENTS_PATH", ""),
|
|
38
|
+
session_service_uri=config.get("SESSION_SERVICE_URI", ""),
|
|
39
|
+
memory_service_uri=config.get("MEMORY_SERVICE_URI", ""),
|
|
40
|
+
user_service_uri=config.get("USER_SERVICE_URI", ""),
|
|
41
|
+
trace_service_uri=config.get("TRACE_SERVICE_URI", ""),
|
|
42
|
+
host=config.get("HOST", "0.0.0.0"),
|
|
43
|
+
port=int(config.get("PORT", "8000"))
|
|
44
|
+
)
|
|
45
|
+
self.runner: Runner = Utils.Agent.create_adk_agent_adk_runner(
|
|
46
|
+
agents_path=server_config.agents_path,
|
|
47
|
+
session_service_uri=server_config.session_service_uri,
|
|
48
|
+
memory_service_uri=server_config.memory_service_uri,
|
|
49
|
+
)
|
|
50
|
+
self.runner.auto_create_session = True
|
|
51
|
+
self.session_service: BaseSessionService = self.runner.session_service
|
|
52
|
+
|
|
53
|
+
self.jobs: Dict[str, JobRecord] = {}
|
|
54
|
+
self.foreground_job_id: Optional[str] = None
|
|
55
|
+
|
|
56
|
+
@staticmethod
|
|
57
|
+
def print_banner():
|
|
58
|
+
banner = r"""
|
|
59
|
+
█████╗ ██████╗ ███████╗███╗ ██╗████████╗ ██████╗ ███████╗
|
|
60
|
+
██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝ ██╔═══██╗██╔════╝
|
|
61
|
+
███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║ ██║ ██║███████╗
|
|
62
|
+
██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║ ██║ ██║╚════██║
|
|
63
|
+
██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║ ╚██████╔╝███████║
|
|
64
|
+
╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝ ╚═════╝ ╚══════╝
|
|
65
|
+
|
|
66
|
+
██████╗ ███████╗
|
|
67
|
+
██╔═══██╗██╔════╝
|
|
68
|
+
██║ ██║███████╗
|
|
69
|
+
██║ ██║╚════██║
|
|
70
|
+
╚██████╔╝███████║
|
|
71
|
+
╚═════╝ ╚══════╝
|
|
72
|
+
"""
|
|
73
|
+
print(banner)
|
|
74
|
+
|
|
75
|
+
def render_output(self, text: str):
|
|
76
|
+
self.console.print(Markdown(text), justify="left")
|
|
77
|
+
|
|
78
|
+
def format_part(self, part: types.Part) -> Optional[str]:
|
|
79
|
+
if part.text is not None and part.text != "" and part.thought is True:
|
|
80
|
+
return f"===\n* **分析**\n{part.text}"
|
|
81
|
+
|
|
82
|
+
if part.text is not None and part.text != "" and (part.thought is None or part.thought is False):
|
|
83
|
+
return f"===\n* **答案**\n{part.text}"
|
|
84
|
+
|
|
85
|
+
if part.function_call is not None:
|
|
86
|
+
try:
|
|
87
|
+
params_json = json.dumps(part.function_call.args, indent=2, ensure_ascii=False)
|
|
88
|
+
except Exception:
|
|
89
|
+
params_json = str(part.function_call.args)
|
|
90
|
+
|
|
91
|
+
return (
|
|
92
|
+
f"===\n* **工具调用**\n"
|
|
93
|
+
f"- 工具名:{part.function_call.name}\n"
|
|
94
|
+
f"- 工具参数:\n```python\n{params_json}\n```"
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
if part.function_response is not None:
|
|
98
|
+
try:
|
|
99
|
+
response_json = json.dumps(part.function_response.response, indent=2, ensure_ascii=False)
|
|
100
|
+
except Exception:
|
|
101
|
+
response_json = str(part.function_response.response)
|
|
102
|
+
|
|
103
|
+
return (
|
|
104
|
+
f"===\n* **工具调用返回结果**\n"
|
|
105
|
+
f"- 工具名:{part.function_response.name}\n"
|
|
106
|
+
f"- 工具返回结果:\n```json\n{response_json}\n```"
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
return None
|
|
110
|
+
|
|
111
|
+
def _emit_to_job(self, job: JobRecord, text: str):
|
|
112
|
+
job.buffer.append(text)
|
|
113
|
+
|
|
114
|
+
# 只有当前前台任务才实时显示
|
|
115
|
+
if self.foreground_job_id == job.job_id:
|
|
116
|
+
self.render_output(text)
|
|
117
|
+
|
|
118
|
+
async def submit_job(self, input_text: str) -> str:
|
|
119
|
+
job_id = uuid.uuid4().hex[:8]
|
|
120
|
+
session_id = uuid.uuid4().hex
|
|
121
|
+
|
|
122
|
+
job = JobRecord(
|
|
123
|
+
job_id=job_id,
|
|
124
|
+
session_id=session_id,
|
|
125
|
+
prompt=input_text,
|
|
126
|
+
)
|
|
127
|
+
self.jobs[job_id] = job
|
|
128
|
+
|
|
129
|
+
job.task = asyncio.create_task(self._run_job(job))
|
|
130
|
+
return job_id
|
|
131
|
+
|
|
132
|
+
async def _run_job(self, job: JobRecord):
|
|
133
|
+
content = types.Content(role="user", parts=[types.Part(text=job.prompt)])
|
|
134
|
+
|
|
135
|
+
try:
|
|
136
|
+
async for chunk in self.runner.run_async(
|
|
137
|
+
user_id=self.user_id,
|
|
138
|
+
session_id=job.session_id,
|
|
139
|
+
new_message=content,
|
|
140
|
+
):
|
|
141
|
+
for part in chunk.content.parts:
|
|
142
|
+
rendered = self.format_part(part)
|
|
143
|
+
if rendered:
|
|
144
|
+
self._emit_to_job(job, rendered)
|
|
145
|
+
|
|
146
|
+
job.done = True
|
|
147
|
+
if self.foreground_job_id == job.job_id:
|
|
148
|
+
self.render_output(f"===\n* **任务完成**\n- job_id:`{job.job_id}`")
|
|
149
|
+
except asyncio.CancelledError:
|
|
150
|
+
job.done = True
|
|
151
|
+
job.error = "cancelled"
|
|
152
|
+
raise
|
|
153
|
+
except Exception as e:
|
|
154
|
+
job.done = True
|
|
155
|
+
job.error = str(e)
|
|
156
|
+
err_text = f"===\n* **任务异常**\n- job_id:`{job.job_id}`\n- 错误:```text\n{type(e).__name__}: {e}\n```"
|
|
157
|
+
self._emit_to_job(job, err_text)
|
|
158
|
+
|
|
159
|
+
async def _list_sessions_raw(self):
|
|
160
|
+
resp = await self.session_service.list_sessions(
|
|
161
|
+
app_name=self.app_name,
|
|
162
|
+
user_id=self.user_id,
|
|
163
|
+
)
|
|
164
|
+
if hasattr(resp, "sessions"):
|
|
165
|
+
return resp.sessions
|
|
166
|
+
return resp or []
|
|
167
|
+
|
|
168
|
+
async def show_sessions(self):
|
|
169
|
+
sessions = await self._list_sessions_raw()
|
|
170
|
+
|
|
171
|
+
if not sessions:
|
|
172
|
+
self.render_output("当前没有 session。")
|
|
173
|
+
return
|
|
174
|
+
|
|
175
|
+
lines = ["===\n* **Session 列表**"]
|
|
176
|
+
for s in sessions:
|
|
177
|
+
sid = getattr(s, "session_id", None) or getattr(s, "id", None) or str(s)
|
|
178
|
+
uid = getattr(s, "user_id", self.user_id)
|
|
179
|
+
app = getattr(s, "app_name", self.app_name)
|
|
180
|
+
lines.append(f"- session_id=`{sid}` | user=`{uid}` | app=`{app}`")
|
|
181
|
+
self.render_output("\n".join(lines))
|
|
182
|
+
|
|
183
|
+
async def delete_session(self, session_id: str):
|
|
184
|
+
# 如果该 session 对应的 job 还在运行,先取消任务并从内存中移除
|
|
185
|
+
for job_id, job in list(self.jobs.items()):
|
|
186
|
+
if job.session_id == session_id:
|
|
187
|
+
if job.task and not job.task.done():
|
|
188
|
+
job.task.cancel()
|
|
189
|
+
if self.foreground_job_id == job_id:
|
|
190
|
+
self.foreground_job_id = None
|
|
191
|
+
self.jobs.pop(job_id, None)
|
|
192
|
+
break
|
|
193
|
+
|
|
194
|
+
await self.session_service.delete_session(
|
|
195
|
+
app_name=self.app_name,
|
|
196
|
+
user_id=self.user_id,
|
|
197
|
+
session_id=session_id,
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
self.render_output(f"===\n* **已删除 Session**\n- session_id:`{session_id}`")
|
|
201
|
+
|
|
202
|
+
async def delete_job(self, job_id: str):
|
|
203
|
+
job = self.jobs.get(job_id)
|
|
204
|
+
if not job:
|
|
205
|
+
self.render_output(f"===\n* **错误**\n未找到 job:`{job_id}`")
|
|
206
|
+
return
|
|
207
|
+
|
|
208
|
+
if job.task and not job.task.done():
|
|
209
|
+
job.task.cancel()
|
|
210
|
+
|
|
211
|
+
await self.session_service.delete_session(
|
|
212
|
+
app_name=self.app_name,
|
|
213
|
+
user_id=self.user_id,
|
|
214
|
+
session_id=job.session_id,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
if self.foreground_job_id == job_id:
|
|
218
|
+
self.foreground_job_id = None
|
|
219
|
+
|
|
220
|
+
self.jobs.pop(job_id, None)
|
|
221
|
+
self.render_output(
|
|
222
|
+
f"===\n* **已删除 Job 对应 Session**\n- job_id:`{job_id}`\n- session_id:`{job.session_id}`"
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
async def cancel_job(self, job_id: str):
|
|
226
|
+
job = self.jobs.get(job_id)
|
|
227
|
+
if not job:
|
|
228
|
+
self.render_output(f"===\n* **错误**\n未找到 job:`{job_id}`")
|
|
229
|
+
return
|
|
230
|
+
|
|
231
|
+
if job.task and not job.task.done():
|
|
232
|
+
job.task.cancel()
|
|
233
|
+
self.render_output(f"===\n* **已取消任务**\n- job_id:`{job_id}`")
|
|
234
|
+
else:
|
|
235
|
+
self.render_output(f"===\n* **任务已结束**\n- job_id:`{job_id}`")
|
|
236
|
+
|
|
237
|
+
def show_jobs(self):
    """Render one Markdown line per tracked job with its status and session.

    Status is ``cancelled`` when ``job.error == "cancelled"``, ``error`` for
    any other truthy error, otherwise ``done``/``running`` from ``job.done``.
    The current foreground job is marked with a trailing arrow.
    """
    if not self.jobs:
        self.render_output("当前没有任务。")
        return

    def _status_of(job):
        base = "done" if job.done else "running"
        if job.error == "cancelled":
            return "cancelled"
        if job.error:
            return "error"
        return base

    rows = ["===\n* **任务列表**"]
    rows.extend(
        f"- `{jid}` | `{_status_of(job)}` | session=`{job.session_id}`"
        + (" ← 前台" if self.foreground_job_id == jid else "")
        for jid, job in self.jobs.items()
    )
    self.render_output("\n".join(rows))
|
|
254
|
+
|
|
255
|
+
async def bring_to_foreground(self, job_id: str):
    """Make ``job_id`` the foreground job and replay its buffered output.

    Unknown ids render an error and leave the foreground selection unchanged.
    Otherwise the selection is switched, a confirmation is rendered, and, if
    the job's ``buffer`` is non-empty, a header followed by every buffered
    item is rendered in order.

    :param job_id: id of the job to bring to the foreground.
    """
    target = self.jobs.get(job_id)
    if not target:
        self.render_output(f"===\n* **错误**\n未找到 job:`{job_id}`")
        return

    self.foreground_job_id = job_id
    self.render_output(f"===\n* **切换前台**\n- 当前前台 job:`{job_id}`")

    history = target.buffer
    if not history:
        return

    # Replay previously captured output for this job.
    self.render_output(f"===\n* **历史输出重放**\n- job:`{job_id}`")
    for entry in history:
        self.render_output(entry)
|
|
269
|
+
|
|
270
|
+
def send_to_background(self):
    """Clear the foreground job selection and confirm the switch to the user."""
    notice = "===\n* **已切回后台模式**"
    self.foreground_job_id = None
    self.render_output(notice)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
async def _dispatch_input(agent_runner, input_text: str) -> None:
    """Execute one stripped, non-empty, non-exit line typed at the prompt.

    ``/jobs``, ``/sessions`` and ``/bg`` must match exactly. ``/fg``,
    ``/deljob``, ``/del`` and ``/cancel`` take one argument and render a usage
    hint when it is missing. Any other text is submitted as a new job.
    """
    if input_text == "/jobs":
        agent_runner.show_jobs()
        return
    if input_text == "/sessions":
        await agent_runner.show_sessions()
        return
    if input_text == "/bg":
        agent_runner.send_to_background()
        return

    # Commands taking one argument: command -> (handler, argument placeholder).
    commands_with_argument = {
        "/fg": (agent_runner.bring_to_foreground, "<job_id>"),
        "/deljob": (agent_runner.delete_job, "<job_id>"),
        "/del": (agent_runner.delete_session, "<session_id>"),
        "/cancel": (agent_runner.cancel_job, "<job_id>"),
    }
    command, _, argument = input_text.partition(" ")
    if command in commands_with_argument:
        handler, placeholder = commands_with_argument[command]
        argument = argument.strip()
        if not argument:
            # Without this check a bare "/fg" (etc.) would be sent to the agent
            # as a natural-language task.
            agent_runner.render_output(f"===\n* **错误**\n用法:`{command} {placeholder}`")
            return
        await handler(argument)
        return

    job_id = await agent_runner.submit_job(input_text)
    agent_runner.render_output(f"===\n* **已提交任务**\n- job_id:`{job_id}`")


async def run_main():
    """Run the interactive command loop for an ``AgentRunner``.

    Prompts are read on a worker thread via ``asyncio.to_thread`` so the event
    loop keeps running already-submitted jobs while waiting for input. The
    loop ends on ``bye``/``exit``/``quit`` (case-insensitive) or EOF.
    """
    agent_runner = AgentRunner()
    agent_runner.print_banner()
    agent_runner.render_output(
        "输入自然语言即提交一个任务;\n"
        "`/jobs` 查看任务;`/sessions` 查看 session;\n"
        "`/fg <job_id>` 切前台;`/bg` 回后台;\n"
        "`/del <session_id>` 删除 session;`/deljob <job_id>` 删除 job 对应 session;\n"
        "`/cancel <job_id>` 取消任务;`exit` `bye` `quit` 退出。"
    )

    while True:
        try:
            # Don't block the event loop: background jobs keep running concurrently.
            input_text = (await asyncio.to_thread(input, "\n> ")).strip()
            if not input_text:
                continue
            if input_text.lower() in {"bye", "exit", "quit"}:
                break
            await _dispatch_input(agent_runner, input_text)
        except (EOFError, KeyboardInterrupt):
            print("\n已退出。")
            break
        except Exception as exc:
            # A failing command (e.g. a session-service error) must not tear down
            # the REPL and every background job with it; report and re-prompt.
            agent_runner.render_output(f"===\n* **错误**\n{type(exc).__name__}: {exc}")
|
|
335
|
+
|
|
336
|
+
def main():
    """Entry point: run the interactive loop until the user exits."""
    try:
        asyncio.run(run_main())
    except KeyboardInterrupt:
        # On Python 3.11+ asyncio.run() turns Ctrl-C into cancellation of the
        # main task and then re-raises KeyboardInterrupt here, bypassing the
        # loop's own handler; exit quietly instead of printing a traceback.
        print("\n已退出。")


if __name__ == "__main__":
    main()
|
agent_server/utils.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import logging
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from fastapi import Request
|
|
5
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
6
|
+
from starlette.responses import Response
|
|
7
|
+
from google.adk.agents import BaseAgent
|
|
8
|
+
from google.adk.memory import InMemoryMemoryService, BaseMemoryService
|
|
9
|
+
from google.adk.sessions import DatabaseSessionService, InMemorySessionService, BaseSessionService
|
|
10
|
+
from google.adk.runners import Runner
|
|
11
|
+
from ag_ui_adk import ADKAgent
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Utils:
    """Namespace grouping service factories, agent loading and HTTP-server helpers."""

    class Service:
        """Factories for the ADK memory and session services."""

        @staticmethod
        def memory_service(db_uri: str | None) -> BaseMemoryService | None:
            """Create the memory service.

            :param db_uri: ``None`` disables the memory service; any other value
                (including ``""``) yields an ``InMemoryMemoryService``.
            :return: the memory service, or ``None`` when ``db_uri`` is ``None``.
            """
            if db_uri is None:
                return None

            if db_uri:
                # Only the in-memory backend is implemented; make the ignored URI
                # visible instead of silently dropping the configuration.
                logging.warning(
                    "memory service URI %r is not supported yet, "
                    "falling back to InMemoryMemoryService",
                    db_uri,
                )
            return InMemoryMemoryService()

        @staticmethod
        def session_service(db_url: str | None) -> BaseSessionService | None:
            """Create the session service.

            :param db_url: ``None`` disables the session service, ``""`` selects an
                ``InMemorySessionService``, and any other value is passed to
                ``DatabaseSessionService`` as its database URL.
            :return: the session service, or ``None`` when ``db_url`` is ``None``.
            """
            if db_url is None:
                return None

            if db_url == "":
                return InMemorySessionService()

            return DatabaseSessionService(db_url=db_url)

    class Agent:
        """Load the user's ADK root agent and wrap it in a runner."""

        # Prefix of the ``app_name`` given to the runners built below.
        agent_name = "AgentOS"

        @staticmethod
        def load_adk_agents(agents_path: str) -> BaseAgent | None:
            """Import ``root_agent`` from the ``agents`` package found in ``agents_path``.

            :param agents_path: directory that contains the ``agents`` package; it is
                added to ``sys.path`` (once) so the package can be imported.
            :return: the ``root_agent`` object, or ``None`` when the import fails
                (the error is logged with its traceback).
            """
            import sys

            try:
                path_entry = str(agents_path)
                # Repeated loads must not keep growing sys.path with duplicates.
                if path_entry not in sys.path:
                    sys.path.append(path_entry)
                # NOTE(review): an already-imported ``agents`` module is served from
                # sys.modules, and any other top-level ``agents`` package earlier on
                # sys.path takes precedence over this directory — confirm intended.
                from agents import root_agent
                return root_agent
            except Exception:
                # Keep the traceback: it is the only clue when user agent code fails.
                logging.exception("Failed to load agents from %s", agents_path)
                return None

        @staticmethod
        def create_adk_agent_agui_runner(
                agents_path: str,
                session_service_uri: str,
                memory_service_uri: str
        ) -> ADKAgent | None:
            """Load the root agent and wrap it in an AG-UI ``ADKAgent``.

            :return: the wrapped agent, or ``None`` if the root agent failed to load.
            """
            agent: BaseAgent | None = Utils.Agent.load_adk_agents(agents_path)
            if agent is None:
                return None

            return ADKAgent(
                adk_agent=agent,
                app_name=f"{Utils.Agent.agent_name} AGUI APP",
                session_service=Utils.Service.session_service(db_url=session_service_uri),
                memory_service=Utils.Service.memory_service(db_uri=memory_service_uri),
            )

        @staticmethod
        def create_adk_agent_adk_runner(
                agents_path: str,
                session_service_uri: str,
                memory_service_uri: str
        ) -> Runner | None:
            """Load the root agent and wrap it in an ADK ``Runner``.

            Sessions are not auto-created (``auto_create_session=False``).

            :return: the runner, or ``None`` if the root agent failed to load.
            """
            agent: BaseAgent | None = Utils.Agent.load_adk_agents(agents_path)
            if agent is None:
                return None

            return Runner(
                agent=agent,
                app_name=f"{Utils.Agent.agent_name} ADK APP",
                session_service=Utils.Service.session_service(db_url=session_service_uri),
                memory_service=Utils.Service.memory_service(db_uri=memory_service_uri),
                auto_create_session=False,
            )

    class Server:
        """HTTP-server helpers: user-whitelist middleware and server configuration."""

        class UserWhitelistMiddleware(BaseHTTPMiddleware):
            """Reject requests whose ``X-User-Id`` header is missing or not whitelisted."""

            # User ids allowed through (hard-coded).
            ALLOWED_USERS = {"alice", "bob", "charlie@example.com"}

            # Paths that skip user validation entirely.
            # NOTE(review): this includes the reload and chat/completions endpoints,
            # so the whitelist does not guard them — confirm that is intended.
            PUBLIC_PATHS = frozenset({
                "/agui/api/v1/health",
                "/agui/api/v1/reload",
                "/agui/api/v1/chat/completions",
                "/adk/api/v1/health",
                "/adk/api/v1/reload",
                "/adk/api/v1/chat/completions",
                "/docs",
                "/openapi.json",
            })

            async def dispatch(self, request: Request, call_next):
                """Pass public paths through; otherwise require a whitelisted ``X-User-Id``.

                :return: 401 when the header is missing, 403 when the user is not in
                    ``ALLOWED_USERS``, otherwise the downstream response.
                """
                # 1. Public endpoints need no user validation.
                if request.url.path in self.PUBLIC_PATHS:
                    return await call_next(request)

                # ``request.client`` may be None (peer address unavailable); logging a
                # rejection must not turn into a 500.
                client_host = request.client.host if request.client is not None else "unknown"

                # 2. Extract and validate the user id.
                user_id = request.headers.get("X-User-Id")
                if not user_id:
                    logging.warning("拒绝访问: 未提供用户ID,来自 %s", client_host)
                    return Response(
                        "未授权: 缺少用户身份标识", status_code=401, media_type="text/plain"
                    )

                if user_id not in self.ALLOWED_USERS:
                    logging.warning(
                        "拒绝访问: 用户 '%s' 不在白名单内,来自 %s", user_id, client_host
                    )
                    return Response(
                        f"禁止访问: 用户 '{user_id}' 无权限",
                        status_code=403,
                        media_type="text/plain",
                    )

                # 3. Whitelisted user: log and continue with the request.
                logging.info("用户 '%s' 通过白名单验证", user_id)
                return await call_next(request)

        class ServerConfig(BaseModel):
            """Settings needed to start an agent server."""

            # Folder holding the agent configuration
            agents_path: str

            # Session service URI
            session_service_uri: str

            # Memory service URI
            memory_service_uri: str

            # User service URI
            user_service_uri: str

            # Trace service URI
            trace_service_uri: str

            # Host to bind
            host: str

            # Port to bind
            port: int
|