bohr-agent-sdk 0.1.101__py3-none-any.whl → 0.1.102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bohr_agent_sdk-0.1.101.dist-info → bohr_agent_sdk-0.1.102.dist-info}/METADATA +6 -2
- bohr_agent_sdk-0.1.102.dist-info/RECORD +80 -0
- dp/agent/cli/cli.py +126 -25
- dp/agent/cli/templates/__init__.py +1 -0
- dp/agent/cli/templates/calculation/simple.py.template +15 -0
- dp/agent/cli/templates/device/tescan_device.py.template +158 -0
- dp/agent/cli/templates/main.py.template +67 -0
- dp/agent/cli/templates/ui/__init__.py +1 -0
- dp/agent/cli/templates/ui/api/__init__.py +1 -0
- dp/agent/cli/templates/ui/api/config.py +32 -0
- dp/agent/cli/templates/ui/api/constants.py +61 -0
- dp/agent/cli/templates/ui/api/debug.py +257 -0
- dp/agent/cli/templates/ui/api/files.py +469 -0
- dp/agent/cli/templates/ui/api/files_upload.py +115 -0
- dp/agent/cli/templates/ui/api/files_user.py +50 -0
- dp/agent/cli/templates/ui/api/messages.py +161 -0
- dp/agent/cli/templates/ui/api/projects.py +146 -0
- dp/agent/cli/templates/ui/api/sessions.py +93 -0
- dp/agent/cli/templates/ui/api/utils.py +161 -0
- dp/agent/cli/templates/ui/api/websocket.py +184 -0
- dp/agent/cli/templates/ui/config/__init__.py +1 -0
- dp/agent/cli/templates/ui/config/agent_config.py +257 -0
- dp/agent/cli/templates/ui/frontend/index.html +13 -0
- dp/agent/cli/templates/ui/frontend/package.json +46 -0
- dp/agent/cli/templates/ui/frontend/tsconfig.json +26 -0
- dp/agent/cli/templates/ui/frontend/tsconfig.node.json +10 -0
- dp/agent/cli/templates/ui/frontend/ui-static/assets/index-DdAmKhul.js +105 -0
- dp/agent/cli/templates/ui/frontend/ui-static/assets/index-DfN2raU9.css +1 -0
- dp/agent/cli/templates/ui/frontend/ui-static/index.html +14 -0
- dp/agent/cli/templates/ui/frontend/vite.config.ts +37 -0
- dp/agent/cli/templates/ui/scripts/build_ui.py +56 -0
- dp/agent/cli/templates/ui/server/__init__.py +0 -0
- dp/agent/cli/templates/ui/server/app.py +98 -0
- dp/agent/cli/templates/ui/server/connection.py +210 -0
- dp/agent/cli/templates/ui/server/file_watcher.py +85 -0
- dp/agent/cli/templates/ui/server/middleware.py +43 -0
- dp/agent/cli/templates/ui/server/models.py +53 -0
- dp/agent/cli/templates/ui/server/session_manager.py +1158 -0
- dp/agent/cli/templates/ui/server/user_files.py +85 -0
- dp/agent/cli/templates/ui/server/utils.py +50 -0
- dp/agent/cli/templates/ui/test_download.py +98 -0
- dp/agent/cli/templates/ui/ui_utils.py +260 -0
- dp/agent/cli/templates/ui/websocket-server.py +87 -0
- dp/agent/server/storage/http_storage.py +1 -1
- bohr_agent_sdk-0.1.101.dist-info/RECORD +0 -40
- {bohr_agent_sdk-0.1.101.dist-info → bohr_agent_sdk-0.1.102.dist-info}/WHEEL +0 -0
- {bohr_agent_sdk-0.1.101.dist-info → bohr_agent_sdk-0.1.102.dist-info}/entry_points.txt +0 -0
- {bohr_agent_sdk-0.1.101.dist-info → bohr_agent_sdk-0.1.102.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1158 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Session manager - using ADK native DatabaseSessionService implementation
|
|
3
|
+
"""
|
|
4
|
+
import os
|
|
5
|
+
import json
|
|
6
|
+
import uuid
|
|
7
|
+
import asyncio
|
|
8
|
+
import traceback
|
|
9
|
+
import logging
|
|
10
|
+
from typing import Dict, Optional, Any
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from fastapi import WebSocket
|
|
14
|
+
from google.adk import Runner
|
|
15
|
+
from google.genai import types
|
|
16
|
+
from google.adk.sessions import DatabaseSessionService, InMemorySessionService, Session
|
|
17
|
+
|
|
18
|
+
from server.connection import ConnectionContext
|
|
19
|
+
from server.user_files import UserFileManager
|
|
20
|
+
from config.agent_config import agentconfig
|
|
21
|
+
|
|
22
|
+
# Logging: every record goes both to a log file and to the console.
# The file defaults to ./websocket.log (relative to the process working directory);
# set WEBSOCKET_LOG_PATH to put it elsewhere.
# NOTE(review): basicConfig configures the *root* logger at import time (DEBUG level),
# which affects every library in the process; it is a no-op if the root logger already
# has handlers.
log_file_path = os.environ.get('WEBSOCKET_LOG_PATH', './websocket.log')
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[
        logging.FileHandler(log_file_path, encoding='utf-8'),
        logging.StreamHandler(),  # console echo
    ],
    format=_LOG_FORMAT,
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SessionManager:
|
|
37
|
+
"""
|
|
38
|
+
Session manager
|
|
39
|
+
Based on ADK native DatabaseSessionService, providing session management, persistence, user isolation
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
# Constants
# Polling step, in seconds, used while waiting for a session's Runner.
WAIT_INTERVAL = 0.1
# Upper bound, in seconds, on that wait.
MAX_WAIT_TIME = 5
# Maximum number of messages kept in context (not referenced in the visible part of this module).
MAX_CONTEXT_MESSAGES = 8
|
|
46
|
+
|
|
47
|
+
def __init__(self):
    """Initialize the session manager.

    Resolves the sessions directory from the ``sessionsDir`` files setting (default
    ``.agent_sessions``). A relative setting is interpreted against ``USER_WORKING_DIR``,
    or the current directory when that variable is unset. The directory is created if
    needed. The method also sets up the per-connection, per-user and per-session caches.
    """
    # Active websocket -> its connection context
    self.active_connections: Dict[WebSocket, ConnectionContext] = {}

    # Application name passed as ``app_name`` to every session-service call
    self.app_name = agentconfig.config.get("agent", {}).get("name", "Agent")

    user_working_dir = os.environ.get('USER_WORKING_DIR', os.getcwd())
    files_config = agentconfig.get_files_config()
    sessions_dir = files_config.get('sessionsDir', '.agent_sessions')

    # pathlib ignores the left operand when the right one is absolute, so one join
    # handles both an absolute and a relative ``sessionsDir`` setting.
    self.sessions_dir = Path(user_working_dir) / sessions_dir
    self.sessions_dir.mkdir(parents=True, exist_ok=True)

    # User identifier -> SessionService (one independent instance per user)
    self.session_services: Dict[str, Any] = {}

    self.user_file_manager = UserFileManager(user_working_dir, str(self.sessions_dir))

    # "<user>_<session_id>" -> Runner
    self.runners: Dict[str, Runner] = {}

    # "<user>_<session_id>" -> error text of the last failed Runner initialization
    self._runner_errors: Dict[str, str] = {}

    # Context of the most recent process_message call (assigned there "for URL
    # generation"). Initialized here so that reading it before the first message sees
    # None instead of raising AttributeError.
    self.current_context: Optional[ConnectionContext] = None
|
|
81
|
+
|
|
82
|
+
def _create_session_service(self, user_identifier: str, is_registered: bool):
    """
    Build the SessionService that stores one user's sessions.

    Args:
        user_identifier: User identifier; also used as the per-user directory name
        is_registered: True for registered users (persistent storage), False for
            temporary users (memory-only storage)

    Returns:
        A DatabaseSessionService on ``<sessions_dir>/users/<user_identifier>/sessions.db``
        for registered users, otherwise an InMemorySessionService.
    """
    # Temporary users: nothing is persisted.
    if not is_registered:
        return InMemorySessionService()

    # Registered users: one SQLite file per user.
    # NOTE(review): user_identifier is used verbatim as a path component; confirm it can
    # never contain path separators or "..".
    db_dir = self.sessions_dir / "users" / user_identifier
    db_dir.mkdir(parents=True, exist_ok=True)
    return DatabaseSessionService(db_url="sqlite:///" + str(db_dir / "sessions.db"))
|
|
105
|
+
|
|
106
|
+
async def connect_client(self, websocket: WebSocket, access_key: str = "", app_key: str = ""):
    """
    Accept a websocket, register its connection context and send the initial state.

    If another connection of the same user already created a SessionService, that
    service is reused. All of the user's connections, and the Runners built on the
    service, therefore share one session store.

    Args:
        websocket: WebSocket connection
        access_key: Bohrium access key
        app_key: Bohrium app key
    """
    await websocket.accept()

    # Register the connection context
    context = ConnectionContext(websocket, access_key, app_key)
    self.active_connections[websocket] = context

    # Asynchronously resolve the user identity
    await context.init_bohrium_user_id()
    user_identifier = context.get_user_identifier()

    # Reuse the service of an already-connected tab of the same user. Replacing it would
    # leave that tab's Runners bound to a service no longer tracked here (for in-memory
    # users, a different session store) and create a second DatabaseSessionService for
    # the same SQLite file.
    session_service = self.session_services.get(user_identifier)
    if session_service is None:
        session_service = self._create_session_service(
            user_identifier, context.is_registered_user()
        )
        self.session_services[user_identifier] = session_service

    await self._load_or_create_sessions(context, session_service)
    await self._send_initial_data(context, session_service)
|
|
135
|
+
|
|
136
|
+
async def disconnect_client(self, websocket: WebSocket):
    """
    Unregister a websocket and release the resources tied to it.

    The connection context is marked disconnected; pending Runner initializations check
    this flag. The context is then cleaned up. Per-user resources are released only when
    no other active connection belongs to the same user. These are the user's
    SessionService, Runners and cached Runner errors. Closing one tab therefore does not
    break the user's other open tabs.

    Args:
        websocket: WebSocket connection
    """
    # Unregister first so a failing cleanup() cannot leave a stale entry behind
    context = self.active_connections.pop(websocket, None)
    if context is None:
        return

    user_identifier = context.get_user_identifier()

    # Mark connection as disconnected to prevent further operations
    context.is_connected = False

    user_still_connected = any(
        other.get_user_identifier() == user_identifier
        for other in self.active_connections.values()
    )
    if not user_still_connected:
        self.session_services.pop(user_identifier, None)

        # Cache keys are "<user_identifier>_<session_id>". Session ids created in this
        # module are uuid4 strings, which never contain "_", so the text before the last
        # "_" is exactly the owning user. A startswith("<user>_") test would also hit
        # users whose identifier merely begins with "<user>_" (e.g. "a" vs "a_b").
        for cache in (self.runners, self._runner_errors):
            for key in [k for k in cache if k.rpartition("_")[0] == user_identifier]:
                del cache[key]

    # Clean up connection context
    context.cleanup()
|
|
169
|
+
|
|
170
|
+
async def create_session(self, context: ConnectionContext) -> Optional[Session]:
    """
    Create a new empty session for the connection's user and make it current.

    Runner initialization for the new session is started in the background.

    Args:
        context: Connection context

    Returns:
        The created Session, or None when the user has no SessionService registered
        (e.g. the connection was already cleaned up).
    """
    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)
    if not session_service:
        return None

    # One timestamp for both fields: a fresh session's last activity is its creation.
    now = datetime.now().isoformat()
    metadata = {
        "created_at": now,
        "last_message_at": now,
        "message_count": 0,
        "title": "Untitled",
        "project_id": context.project_id
    }

    session = await session_service.create_session(
        app_name=self.app_name,
        user_id=user_identifier,
        session_id=str(uuid.uuid4()),
        state={"metadata": metadata}
    )

    # NOTE(review): the task reference is not kept; the asyncio docs recommend holding
    # one so the task cannot be garbage-collected before it finishes.
    asyncio.create_task(self._init_runner(context, session.id))

    context.current_session_id = session.id
    return session
|
|
212
|
+
|
|
213
|
+
async def delete_session(self, context: ConnectionContext, session_id: str) -> bool:
    """
    Delete one of the connection user's sessions and drop its cached Runner state.

    Args:
        context: Connection context
        session_id: Session ID

    Returns:
        True if the session service accepted the deletion; False if the user has no
        SessionService or the deletion raised (the error is logged).
    """
    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)
    if not session_service:
        return False

    try:
        # delete_session returns None; success is signalled by not raising
        await session_service.delete_session(
            app_name=self.app_name,
            user_id=user_identifier,
            session_id=session_id
        )
    except Exception:
        logger.exception("Failed to delete session %s for user %s", session_id, user_identifier)
        return False

    # Drop the Runner and any cached initialization error so neither outlives the session.
    # NOTE(review): context.current_session_id is left unchanged even when it was the
    # deleted session; confirm callers switch sessions afterwards.
    runner_key = f"{user_identifier}_{session_id}"
    self.runners.pop(runner_key, None)
    self._runner_errors.pop(runner_key, None)
    return True
|
|
248
|
+
|
|
249
|
+
async def switch_session(self, context: ConnectionContext, session_id: str) -> bool:
    """
    Make an existing session the connection's current session.

    If no Runner is cached for the session yet, its initialization is started in the
    background.

    Args:
        context: Connection context
        session_id: Session ID

    Returns:
        False when the user has no SessionService or the session does not exist,
        True once the session has become current.
    """
    user_identifier = context.get_user_identifier()
    service = self.session_services.get(user_identifier)

    if service:
        existing = await service.get_session(
            app_name=self.app_name, user_id=user_identifier, session_id=session_id
        )
    else:
        existing = None

    if not existing:
        return False

    context.current_session_id = session_id

    needs_runner = f"{user_identifier}_{session_id}" not in self.runners
    if needs_runner:
        asyncio.create_task(self._init_runner(context, session_id))
    return True
|
|
285
|
+
|
|
286
|
+
async def process_message(self, context: ConnectionContext, message: str, attachments: Optional[list] = None):
    """
    Process one user message in the connection's current session.

    The following checks run in order, and the first failure is reported to the client:
    a project ID is configured, there is a current session, the user's SessionService
    exists, the session exists, and the session's Runner is ready. The method then
    records session metadata and streams the agent's response to the client. While the
    agent runs, the user's file directory is the process working directory.

    Args:
        context: Connection context
        message: User message
        attachments: Optional attachments, passed through to ``_build_message_content``
    """
    # Save context reference for URL generation.
    # NOTE(review): this single slot is shared by all connections, so concurrent
    # messages overwrite each other's value.
    self.current_context = context

    # Check project_id
    if not context.project_id and not os.environ.get('BOHR_PROJECT_ID'):
        await self._send_error(context, "🔒 请先设置项目 ID")
        return

    # Capture the session once so the whole message uses one session, even if the
    # client switches sessions while this coroutine is suspended.
    session_id = context.current_session_id
    if not session_id:
        await self._send_error(context, "没有活动的会话")
        return

    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)
    if not session_service:
        await self._send_error(context, "会话服务未初始化")
        return

    # Fetch the session
    session = await session_service.get_session(
        app_name=self.app_name,
        user_id=user_identifier,
        session_id=session_id
    )
    if not session:
        await self._send_error(context, "会话不存在")
        return

    # Wait for Runner initialization
    runner = await self._get_or_wait_runner(context, session_id)
    if not runner:
        runner_key = f"{user_identifier}_{session_id}"
        # _get_or_wait_runner already reports timeouts and failed recoveries to the client
        # and clears the cached error when it does. An error still cached here has not
        # been reported yet. Report only that case, so one failure yields one message.
        if runner_key in self._runner_errors:
            error_details = self._runner_errors.get(runner_key, "未知错误")
            await self._send_error(
                context,
                f"会话初始化失败\n\n可能的原因:\n"
                f"1. Agent 配置文件路径错误\n"
                f"2. Agent 模块导入失败\n"
                f"3. Project ID 无效\n\n"
                f"错误详情:{error_details}\n\n"
                f"请检查 config/agent-config.json 中的配置"
            )
        return

    # Update session metadata (before processing the message)
    await self._update_session_metadata(context, session, message)

    # NOTE(review): os.chdir changes the working directory of the whole process. Messages
    # from other connections processed concurrently (interleaving at any await below) see
    # and change the same directory. Kept because the agent presumably resolves user files
    # relative to the cwd; confirm before replacing.
    user_files_dir = self.user_file_manager.get_user_files_dir(user_id=user_identifier)
    original_cwd = os.getcwd()

    try:
        os.chdir(user_files_dir)
        content = self._build_message_content(session, message, attachments)
        await self._process_message_stream(
            context,
            runner,
            content,
            user_identifier,
            session_id
        )
    except Exception as e:
        logger.exception("Failed to process message for user %s", user_identifier)
        await self._send_error(context, f"处理消息失败: {str(e)}")
    finally:
        # Restore the working directory; a failure here must not mask the outcome above.
        try:
            os.chdir(original_cwd)
        except Exception:
            logger.exception("Failed to restore working directory to %s", original_cwd)
|
|
370
|
+
|
|
371
|
+
async def _load_or_create_sessions(self, context: ConnectionContext, session_service):
    """
    Choose the connection's current session from the user's stored sessions.

    The most recently updated session becomes current. Runner initialization is started
    in the background for every stored session. If the user has no sessions, or loading
    them fails, a new default session is created instead (a failure is logged).
    """
    user_identifier = context.get_user_identifier()

    try:
        response = await session_service.list_sessions(
            app_name=self.app_name,
            user_id=user_identifier
        )
        # ListSessionsResponse carries the sessions in `.sessions`
        sessions = getattr(response, 'sessions', None) or []

        if not sessions:
            await self._create_default_session(context, session_service)
            return

        # Most recently updated first; that one becomes current
        sessions.sort(key=self._get_session_last_update_time, reverse=True)
        context.current_session_id = sessions[0].id

        # Start Runner initialization for every stored session
        for session in sessions:
            asyncio.create_task(self._init_runner(context, session.id))

    except Exception:
        logger.exception(
            "Failed to load sessions for user %s; creating a default session", user_identifier
        )
        await self._create_default_session(context, session_service)
|
|
404
|
+
|
|
405
|
+
async def _create_default_session(self, context: ConnectionContext, session_service):
    """
    Create an empty "Untitled" session in ``session_service`` and make it current.

    Runner initialization for the new session is started in the background.
    """
    user_identifier = context.get_user_identifier()

    # One timestamp for both fields: a fresh session's last activity is its creation.
    now = datetime.now().isoformat()
    session = await session_service.create_session(
        app_name=self.app_name,
        user_id=user_identifier,
        session_id=str(uuid.uuid4()),
        state={"metadata": {
            "created_at": now,
            "last_message_at": now,
            "message_count": 0,
            "title": "Untitled",
            "project_id": context.project_id
        }}
    )

    context.current_session_id = session.id

    # NOTE(review): the task reference is not kept; the asyncio docs recommend holding
    # one so the task cannot be garbage-collected before it finishes.
    asyncio.create_task(self._init_runner(context, session.id))
|
|
430
|
+
|
|
431
|
+
async def _init_runner(self, context: ConnectionContext, session_id: str, retry_count: int = 0):
    """
    Build and cache the Runner for one session, retrying failed attempts.

    Each attempt creates the agent via ``agentconfig.get_agent`` in a worker thread and
    wraps it in a Runner bound to the user's SessionService. On success the Runner is
    cached under ``"<user>_<session_id>"`` and any cached error for that key is cleared.

    A failed attempt is retried after one second. Attempts are numbered from
    ``retry_count``, and once attempt ``max_retries - 1`` fails the error text is stored
    in ``self._runner_errors`` for ``_get_or_wait_runner`` to report.

    No Runner is cached, and no further attempts are made, once the connection is closed
    or when the user's SessionService is gone by the time the Runner would be created.

    Args:
        context: Connection context that requested the Runner
        session_id: Session ID
        retry_count: Number of attempts already made (0 for a fresh start)
    """
    user_identifier = context.get_user_identifier()
    runner_key = f"{user_identifier}_{session_id}"
    max_retries = 3
    attempt = retry_count

    while True:
        # A closed connection needs no Runner. Without this check a failing agent factory
        # would keep being retried and would re-populate _runner_errors after
        # disconnect_client has cleared it.
        if not context.is_connected:
            logger.warning(f"⚠️ 连接已断开,跳过 Runner 初始化: {runner_key}")
            return

        logger.info(f"🚀 开始初始化 Runner: {runner_key} (尝试 {attempt + 1}/{max_retries})")
        logger.debug(f" 用户标识: {user_identifier}")
        logger.debug(f" 会话ID: {session_id}")
        logger.debug(f" Access Key: {'有' if context.access_key else '无'}")
        logger.debug(f" App Key: {'有' if context.app_key else '无'}")

        try:
            # Project ID from the connection, falling back to the environment
            project_id = context.project_id or os.environ.get('BOHR_PROJECT_ID')
            if project_id:
                project_id = int(project_id) if isinstance(project_id, str) else project_id
                logger.debug(f" Project ID: {project_id}")

            logger.info(f"📦 创建 Agent...")
            logger.debug(f" 配置模块: {agentconfig.config.get('agent', {}).get('module')}")
            logger.debug(f" Agent名称: {agentconfig.config.get('agent', {}).get('name')}")

            # get_agent is synchronous; run it in the default executor to keep the loop free
            loop = asyncio.get_running_loop()
            user_agent = await loop.run_in_executor(
                None,
                agentconfig.get_agent,
                context.access_key or "",
                context.app_key or "",
                project_id
            )
            if not user_agent:
                raise ValueError("Agent 创建失败: 返回 None")

            logger.info(f"✅ Agent 创建成功: {type(user_agent).__name__}")
            logger.info(f"🏃 创建 Runner...")

            # The connection may have closed while the agent was being built
            if not context.is_connected:
                logger.warning(f"⚠️ 连接已断开,跳过 Runner 初始化: {runner_key}")
                return

            session_service = self.session_services.get(user_identifier)
            if not session_service:
                logger.warning(f"⚠️ SessionService 已被清理,跳过 Runner 初始化: {runner_key}")
                return

            self.runners[runner_key] = Runner(
                agent=user_agent,
                session_service=session_service,
                app_name=self.app_name
            )
            logger.info(f"✅ Runner 初始化成功: {runner_key}")
            logger.debug(f" 当前 Runner 数量: {len(self.runners)}")

            # Clear any error recorded by an earlier failed attempt
            self._runner_errors.pop(runner_key, None)
            return

        except Exception as e:  # ImportError is a subclass of Exception
            error_type = "导入错误" if isinstance(e, ImportError) else "Runner 初始化失败"
            error_msg = f"❌ {error_type}: {str(e)}\n类型: {type(e).__name__}\n{traceback.format_exc()}"
            logger.error(error_msg)

            if attempt >= max_retries - 1:
                # Every attempt failed: keep the error for _get_or_wait_runner to report
                self._runner_errors[runner_key] = f"{error_msg}\n\n已尝试 {max_retries} 次初始化,全部失败。"
                logger.error(f"❌ Runner 初始化彻底失败: {runner_key},已尝试 {max_retries} 次")
                return

            # The Runner is only stored as the last step of a successful attempt, so a
            # cached Runner at this point belongs to a concurrent task and is kept.
            logger.info(f"🔄 准备重试 Runner 初始化: {runner_key}")

        # Brief pause before the next attempt
        await asyncio.sleep(1)
        attempt += 1
|
|
517
|
+
|
|
518
|
+
async def _get_or_wait_runner(self, context: ConnectionContext, session_id: str) -> Optional[Runner]:
    """
    Return the session's Runner, polling until it becomes available.

    Polls every ``WAIT_INTERVAL`` seconds, for at most ``MAX_WAIT_TIME`` seconds.

    - First cached initialization error seen: the error is cleared and one new
      initialization is started in the background.
    - Second error: it is reported to the client and removed from the cache, and None
      is returned.
    - Timeout: the client is told unless an error is cached, and None is returned.
    """
    runner_key = f"{context.get_user_identifier()}_{session_id}"
    logger.debug(f"⏳ 等待 Runner 初始化: {runner_key}")

    poll_limit = int(self.MAX_WAIT_TIME / self.WAIT_INTERVAL)
    polls = 0
    recovery_used = False

    while polls < poll_limit and runner_key not in self.runners:
        saw_error = runner_key in self._runner_errors
        if saw_error:
            logger.error(f"Runner 初始化已失败: {self._runner_errors[runner_key]}")

            if recovery_used:
                # Automatic recovery already spent: tell the client and give up
                await self._send_error(
                    context,
                    f"会话初始化失败\n\n错误详情:\n{self._runner_errors.get(runner_key, '未知错误')}\n\n"
                    f"建议:\n"
                    f"1. 请尝试新建一个会话\n"
                    f"2. 检查 Agent 配置是否正确\n"
                    f"3. 确认 Project ID 是否有效"
                )
                self._runner_errors.pop(runner_key, None)
                return None

            # First failure: forget it and start exactly one fresh initialization
            recovery_used = True
            logger.info(f"🔧 尝试自动恢复 Runner: {runner_key}")
            del self._runner_errors[runner_key]
            asyncio.create_task(self._init_runner(context, session_id))

        await asyncio.sleep(self.WAIT_INTERVAL)
        polls += 1

        # Progress log roughly once per second (skipped on recovery polls)
        if not saw_error and polls % 10 == 0:
            logger.debug(f" 仍在等待... ({polls * self.WAIT_INTERVAL:.1f}秒)")

    runner = self.runners.get(runner_key)
    if not runner:
        logger.error(f"❌ 超时等待 Runner: {runner_key} (等待了 {self.MAX_WAIT_TIME} 秒)")
        # No cached error: initialization is just slow, so report a timeout
        if runner_key not in self._runner_errors:
            await self._send_error(
                context,
                f"会话初始化超时\n\n"
                f"可能的原因:\n"
                f"1. Agent 初始化时间过长\n"
                f"2. 系统资源不足\n\n"
                f"建议尝试新建会话"
            )
        return runner

    logger.info(f"✅ 获取 Runner 成功: {runner_key}")
    return runner
|
|
589
|
+
|
|
590
|
+
async def _update_session_metadata(self, context: ConnectionContext, session: Session, message: str):
    """
    Record a new user message in ``session.state['metadata']``.

    Increments ``message_count`` and refreshes ``last_message_at``. For the first message,
    ``title`` is set to the first 50 characters of a non-empty message. The new metadata
    is persisted by appending a ``system`` event whose ``state_delta`` carries it, rather
    than by mutating ``session.state`` directly. Nothing is written if the user's
    SessionService is gone.
    """
    from google.adk.events import Event, EventActions

    existing = session.state.get('metadata') if session.state else None
    # Work on a copy; the change reaches the session only through the appended event.
    # A missing or non-dict value (e.g. None) starts from empty metadata.
    new_metadata = dict(existing) if isinstance(existing, dict) else {}

    now = datetime.now()
    new_metadata['last_message_at'] = now.isoformat()
    new_metadata['message_count'] = new_metadata.get('message_count', 0) + 1

    # The first message names the session. An empty message (e.g. attachments only)
    # keeps the existing title instead of blanking it.
    if new_metadata['message_count'] == 1 and message:
        new_metadata['title'] = message[:50]

    update_event = Event(
        invocation_id=f"metadata_update_{now.timestamp()}",
        author="system",
        actions=EventActions(state_delta={'metadata': new_metadata}),
        timestamp=now.timestamp()
    )

    session_service = self.session_services.get(context.get_user_identifier())
    if session_service:
        await session_service.append_event(session, update_event)
|
|
624
|
+
|
|
625
|
+
async def _process_message_stream(
    self,
    context: ConnectionContext,
    runner: Runner,
    content: types.Content,
    user_identifier: str,
    session_id: str
):
    """
    Run the agent on one user message and forward its events to the client.

    For every event yielded by ``runner.run_async``:

    * Each function call is sent as a ``tool`` message with status ``executing``,
      followed by a 0.2 s pause.
    * Each function response is sent as a ``tool`` message with status ``completed``
      and the formatted result.
    * Text parts of partial (streaming) events are accumulated. A non-partial text part
      is prefixed with any accumulated text and sent as an ``assistant`` message when
      the content role is ``model``; text with any other role is dropped.
    * On the final response, text still accumulated is sent as ``assistant``.

    Afterwards a ``complete`` marker and the refreshed session list are sent.
    """
    streaming_text = ""  # text accumulated from partial events

    async for event in runner.run_async(
        new_message=content,
        user_id=user_identifier,
        session_id=session_id
    ):
        is_partial = getattr(event, 'partial', False)

        # Tool call requests
        function_calls = event.get_function_calls() if hasattr(event, 'get_function_calls') else []
        if function_calls:
            for call in function_calls:
                await self._send_message(context, {
                    "type": "tool",
                    "tool_name": call.name,
                    "args": call.args,  # tool call parameters
                    "status": "executing",
                    "timestamp": datetime.now().isoformat()
                })
                await asyncio.sleep(0.2)

        # Tool execution results
        function_responses = event.get_function_responses() if hasattr(event, 'get_function_responses') else []
        if function_responses:
            for response in function_responses:
                await self._send_message(context, {
                    "type": "tool",
                    "tool_name": response.name,
                    "result": self._format_response_data(response.response),
                    "status": "completed",
                    "timestamp": datetime.now().isoformat()
                })

        # Text content
        if hasattr(event, 'content') and event.content and hasattr(event.content, 'parts') and event.content.parts:
            for part in event.content.parts:
                if not (hasattr(part, 'text') and part.text):
                    continue
                if is_partial:
                    streaming_text += part.text
                    continue

                # Complete text part: send it together with anything accumulated so far.
                # NOTE(review): if the runner streams, confirm whether the final
                # non-partial event repeats the full text; if so this prefix duplicates it.
                text_to_send = streaming_text + part.text
                streaming_text = ""
                if getattr(event.content, 'role', 'model') == 'model':
                    await self._send_message(context, {
                        "type": "assistant",
                        "content": text_to_send,
                        "session_id": session_id
                    })

        # Final response: flush text still accumulated from partial events
        if hasattr(event, 'is_final_response') and event.is_final_response():
            if streaming_text:
                await self._send_message(context, {
                    "type": "assistant",
                    "content": streaming_text,
                    "session_id": session_id
                })
                streaming_text = ""

    # Send completion marker
    await self._send_message(context, {
        "type": "complete",
        "content": ""
    })

    # Send updated session list
    await self.send_sessions_list(context)
|
|
729
|
+
|
|
730
|
+
# _handle_tool_events method removed, functionality integrated into _process_message_stream
|
|
731
|
+
|
|
732
|
+
async def _get_session_metadata(self, session_service, user_identifier: str, session_id: str) -> dict:
    """
    Re-read a session and return the ``metadata`` dict from its state.

    Returns an empty dict in any of these cases:

    - the session does not exist;
    - its state, or the metadata inside it, is not a dict;
    - the lookup raises (the error is logged).
    """
    try:
        fresh_session = await session_service.get_session(
            app_name=self.app_name,
            user_id=user_identifier,
            session_id=session_id
        )
    except Exception:
        logger.warning(
            "Could not reload session %s for user %s", session_id, user_identifier, exc_info=True
        )
        return {}

    state = getattr(fresh_session, 'state', None) if fresh_session else None
    if not isinstance(state, dict):
        return {}
    metadata = state.get('metadata', {})
    # Honor the ``-> dict`` contract even if a non-dict value was stored
    return metadata if isinstance(metadata, dict) else {}
|
|
745
|
+
|
|
746
|
+
async def _send_initial_data(self, context: ConnectionContext, session_service):
    """Push the bootstrap payload to a newly connected client.

    First sends a ``sessions_list`` message summarising every session of the
    connection's user. If the connection already has a current session, that
    session's message history is sent next.

    Args:
        context: Connection to send to; supplies the user id and current session id.
        session_service: Service exposing async ``list_sessions``/``get_session``.
    """
    user_identifier = context.get_user_identifier()

    listing = await session_service.list_sessions(
        app_name=self.app_name,
        user_id=user_identifier
    )
    # The listing response carries the sessions in ``.sessions``; tolerate its absence.
    known_sessions = getattr(listing, 'sessions', [])

    summaries = []
    for sess in known_sessions:
        meta = await self._get_session_metadata(session_service, user_identifier, sess.id)
        if not meta:
            # Fresh lookup failed or was empty: fall back to the listed state.
            meta = sess.state.get('metadata', {}) if sess.state else {}

        summaries.append({
            "id": sess.id,
            "title": meta.get("title", "Untitled"),
            "created_at": meta.get("created_at", datetime.now().isoformat()),
            "last_message_at": meta.get("last_message_at", datetime.now().isoformat()),
            "message_count": meta.get("message_count", 0)
        })

    await self._send_message(context, {
        "type": "sessions_list",
        "sessions": summaries,
        "current_session_id": context.current_session_id
    })

    active_session_id = context.current_session_id
    if not active_session_id:
        return
    await self._send_session_messages(context, session_service, active_session_id)
|
|
783
|
+
|
|
784
|
+
async def _send_session_messages(self, context: "ConnectionContext", session_service, session_id: str):
    """Send the stored message history of one session to the client.

    The session is fetched from ``session_service``, and each event is
    converted into the flat message format the frontend renders:

    * non-empty text parts become ``user`` or ``assistant`` messages, for
      content roles ``'user'`` and ``'model'`` respectively (other roles are
      not shown);
    * ``function_call`` parts are skipped, so history only shows finished
      tool runs;
    * ``function_response`` parts become ``tool`` messages with status
      ``completed``;
    * content objects without a ``parts`` attribute are shown via ``str()``.

    Events whose content is missing or ``None`` are skipped, as is content
    whose ``parts`` is ``None``; neither raises. Nothing is sent when the
    session is not found or has no ``events`` attribute.

    Args:
        context: Connection the ``session_messages`` payload is sent to.
        session_service: Service exposing an async ``get_session`` method.
        session_id: Id of the session whose history is sent.
    """
    import uuid

    user_identifier = context.get_user_identifier()

    session = await session_service.get_session(
        app_name=self.app_name,
        user_id=user_identifier,
        session_id=session_id
    )

    if not session or not hasattr(session, 'events'):
        return

    # Content role -> frontend role/type; roles not listed here are not displayed.
    display_roles = {'user': 'user', 'model': 'assistant'}

    def chat_message(display_role, text, timestamp):
        """Build one user/assistant history entry."""
        return {
            "id": str(uuid.uuid4()),
            "role": display_role,
            "type": display_role,
            "content": text,
            "timestamp": timestamp
        }

    messages_data = []

    for event in session.events:
        content = getattr(event, 'content', None)
        if content is None:
            # Events without content have nothing to display.
            continue

        role = getattr(content, 'role', None)
        display_role = display_roles.get(role) if isinstance(role, str) else None
        timestamp = self._format_timestamp(getattr(event, "timestamp", None))

        if not hasattr(content, 'parts'):
            # Plain (non-Content) payload: fall back to its string form.
            if display_role:
                messages_data.append(chat_message(display_role, str(content), timestamp))
            continue

        # ``parts`` may be None on Content objects; treat that as "no parts".
        for part in content.parts or ():
            text = getattr(part, 'text', None)
            if text:
                if display_role:
                    messages_data.append(chat_message(display_role, text, timestamp))
            elif getattr(part, 'function_call', None):
                # Tool invocations are not replayed in history; only their results are.
                continue
            elif getattr(part, 'function_response', None):
                func_resp = part.function_response
                messages_data.append({
                    "id": str(uuid.uuid4()),
                    "role": "tool",
                    "type": "tool",
                    "tool_name": getattr(func_resp, 'name', 'unknown'),
                    "tool_status": "completed",
                    "content": self._format_response_data(getattr(func_resp, 'response', {})),
                    "timestamp": timestamp
                })

    await self._send_message(context, {
        "type": "session_messages",
        "session_id": session_id,
        "messages": messages_data
    })
|
|
874
|
+
|
|
875
|
+
async def _send_message(self, context: "ConnectionContext", message: dict):
    """Send one JSON message to the client over its websocket.

    A message without an ``id`` is given one built from its ``type`` and the
    current time. Note that this mutates ``message`` in place.

    If sending fails, the connection is treated as broken and
    ``disconnect_client`` is scheduled as a background task instead of being
    awaited here. The event loop only keeps weak references to tasks, so a
    strong reference is held in ``self._background_tasks`` until the task
    finishes. Without it, the task could be garbage-collected mid-run.
    """
    import asyncio
    from datetime import datetime

    if 'id' not in message:
        message['id'] = f"{message.get('type', 'unknown')}_{datetime.now().timestamp()}"

    try:
        await context.websocket.send_json(message)
    except Exception:
        # Any send failure means the socket is unusable; tear it down asynchronously.
        tasks = getattr(self, '_background_tasks', None)
        if tasks is None:
            tasks = self._background_tasks = set()
        task = asyncio.create_task(self.disconnect_client(context.websocket))
        tasks.add(task)
        task.add_done_callback(tasks.discard)
|
|
884
|
+
|
|
885
|
+
async def _send_error(self, context: ConnectionContext, error_message: str):
    """Report an error to the client as a message of type ``error``.

    Args:
        context: Connection to notify.
        error_message: Text placed in the message's ``content`` field.
    """
    payload = {"type": "error", "content": error_message}
    await self._send_message(context, payload)
|
|
891
|
+
|
|
892
|
+
def _get_session_last_update_time(self, session: "Session") -> datetime:
    """Return the best known time at which a session was last updated.

    Resolution order:

    1. ``session.state['metadata']['last_message_at']`` parsed with
       ``datetime.fromisoformat``, when present and parseable;
    2. ``session.last_update_time`` interpreted as a Unix timestamp, when
       present, not ``None`` and in range;
    3. ``datetime.min`` as a sentinel that sorts before any real time.

    NOTE(review): step 1 returns a timezone-aware value if the stored string
    has an offset, while steps 2 and 3 return naive values. Ordering aware and
    naive datetimes raises TypeError, so confirm how callers compare these.
    """
    state = getattr(session, 'state', None)
    if isinstance(state, dict):
        metadata = state.get('metadata')
        last_message_at = metadata.get('last_message_at') if isinstance(metadata, dict) else None
        if last_message_at:
            try:
                return datetime.fromisoformat(last_message_at)
            except (TypeError, ValueError):
                # Malformed or non-string value: fall back to the native time.
                pass

    last_update_time = getattr(session, 'last_update_time', None)
    if last_update_time is not None:
        try:
            return datetime.fromtimestamp(last_update_time)
        except (TypeError, ValueError, OverflowError, OSError):
            # Not a usable timestamp; use the sentinel below.
            pass

    return datetime.min
|
|
909
|
+
|
|
910
|
+
def _get_base_url(self, context: "ConnectionContext") -> str:
    """Work out the base URL (``scheme://host[:port]``) the client used.

    Resolution order:

    1. the ``Origin`` request header, unless empty or the literal ``"null"``
       that browsers send for opaque origins;
    2. the ``Host`` header, with scheme ``https`` when the first entry of
       ``X-Forwarded-Proto`` is ``https`` (case-insensitive), else ``http``;
    3. the ``AGENT_API_URL`` environment variable, trailing slashes removed;
    4. ``http://localhost:8000``.

    Header values come from the request and are not validated here. Do not
    rely on the result for security decisions.
    """
    import os

    headers = getattr(context, 'request_headers', None) or {}
    if isinstance(headers, dict):
        # Plain dicts are case-sensitive, but HTTP header names are not.
        headers = {str(name).lower(): value for name, value in headers.items()}

    # 1. Origin header
    origin = (headers.get('origin') or '').strip()
    if origin and origin.lower() != 'null':
        return origin.rstrip('/')

    # 2. Host header (+ scheme from a reverse proxy, if any)
    host = (headers.get('host') or '').strip()
    if host:
        # Proxy chains may send "https, http"; the first entry is the client-facing scheme.
        forwarded_proto = (headers.get('x-forwarded-proto') or '').split(',')[0].strip().lower()
        protocol = 'https' if forwarded_proto == 'https' else 'http'
        return f"{protocol}://{host}"

    # 3. Environment variable
    base_url = os.environ.get('AGENT_API_URL', '')
    if base_url:
        return base_url.rstrip('/')

    # 4. Default
    return "http://localhost:8000"
|
|
933
|
+
|
|
934
|
+
def _build_message_content(self, session, message: str, attachments: list = None) -> types.Content:
    """Wrap the user's message, plus any attachment paths, into a ``user`` Content.

    When ``attachments`` is non-empty and the handler has a ``current_context``
    attribute, a note listing each attachment's ``relative_path`` is appended
    to the message. Attachments that are not dicts, or that lack a
    ``relative_path``, are skipped. If none remain, the message is left
    unchanged.

    Args:
        session: Not used by this method; kept for interface compatibility
            (no history context is added here).
        message: Raw user text; may be empty.
        attachments: Optional list of attachment dicts sent by the client.

    Returns:
        A ``types.Content`` with role ``'user'`` and a single text part.
    """
    enhanced_message = message

    # NOTE(review): attachments are only announced when ``current_context`` is
    # set, although the context is not otherwise used here. Confirm this gate
    # is still intended.
    if attachments and hasattr(self, 'current_context'):
        paths = [
            att.get('relative_path')
            for att in attachments
            if isinstance(att, dict) and att.get('relative_path')
        ]

        if paths:
            # Prompt text (Chinese): tells the model the user uploaded files and
            # asks it to pass the listed file paths to the tools it wants to call.
            file_info = "\n\n用户已上传文件,请你跟据上传的文件路径,调用传入所希望调用的工具中,:"
            file_info += "".join(f"\n 文件路径: {path}" for path in paths)

            enhanced_message = message + file_info if message else file_info.strip()

    return types.Content(
        role='user',
        parts=[types.Part(text=enhanced_message)]
    )
|
|
954
|
+
|
|
955
|
+
def _format_response_data(self, response_data):
    """Render a tool response as display text.

    Strings are returned unchanged. Dicts, lists and tuples are pretty-printed
    as JSON with 2-space indentation and non-ASCII characters kept as-is. If
    JSON serialisation fails (non-JSON values, circular references, excessive
    nesting), their ``str()`` is used instead. Any other value goes through
    ``str()``.
    """
    import json

    if isinstance(response_data, str):
        return response_data
    if isinstance(response_data, (dict, list, tuple)):
        try:
            return json.dumps(response_data, indent=2, ensure_ascii=False)
        except (TypeError, ValueError, RecursionError):
            # Unserialisable content: fall back to the Python representation.
            return str(response_data)
    return str(response_data)
|
|
963
|
+
|
|
964
|
+
def _extract_final_response(self, events: list) -> Optional[str]:
    """Return the text of the most recent event that contains any text.

    Events are scanned from newest to oldest. For the first event whose
    content has at least one non-empty ``text`` part, those texts are joined
    with newlines and returned.

    Returns:
        The joined text, or ``None`` if no event carries text.
    """
    for candidate in reversed(events):
        content = getattr(candidate, 'content', None)
        if not content:
            continue
        parts = getattr(content, 'parts', None)
        if not parts:
            continue
        texts = [piece.text for piece in parts if getattr(piece, 'text', None)]
        if texts:
            return '\n'.join(texts)
    return None
|
|
977
|
+
|
|
978
|
+
def _event_to_message_data(self, event) -> Optional[dict]:
    """Convert an event into the message dict format used by the frontend.

    The result always has a fresh ``id`` and a formatted ``timestamp``. The
    following keys are added when the event provides them:

    * ``type``: the event's ``type`` attribute, overridden by its ``role``
      (``'model'`` becomes ``'assistant'``, ``'user'`` stays ``'user'``, any
      other role is used verbatim). The raw role is kept under ``role``.
    * ``content``: newline-joined text parts, or ``str(content)`` for content
      objects without ``parts``.
    * ``tool_calls`` / ``tool_responses``: one entry per part that actually
      carries a ``function_call`` / ``function_response``.

    Events whose content is ``None`` and content whose ``parts`` is ``None``
    contribute no content.

    Returns:
        The dict, or ``None`` if the event is falsy or yields no content,
        tool calls or tool responses.
    """
    import uuid

    if not event:
        return None

    message_data = {
        "id": str(uuid.uuid4()),
        "timestamp": self._format_timestamp(getattr(event, "timestamp", None))
    }

    if hasattr(event, 'type'):
        message_data["type"] = event.type

    if hasattr(event, 'role'):
        # Map the role onto the frontend's message type; keep the raw role too.
        role = event.role
        if role == 'model':
            message_data["type"] = "assistant"
        elif role == 'user':
            message_data["type"] = "user"
        else:
            message_data["type"] = role
        message_data["role"] = role

    content = getattr(event, 'content', None)
    if content is not None:
        if hasattr(content, 'parts'):
            text_parts = []
            tool_calls = []
            tool_responses = []

            # ``parts`` may be None; treat that as "no parts".
            for part in content.parts or ():
                text = getattr(part, 'text', None)
                if text is not None:
                    text_parts.append(text)

                # Parts may define these attributes with a None value, so test
                # the value rather than the attribute's existence.
                func_call = getattr(part, 'function_call', None)
                if func_call:
                    tool_calls.append({
                        "id": getattr(func_call, 'id', ''),
                        "name": getattr(func_call, 'name', ''),
                        "args": getattr(func_call, 'args', {})
                    })

                func_resp = getattr(part, 'function_response', None)
                if func_resp:
                    tool_responses.append({
                        "id": getattr(func_resp, 'id', ''),
                        "name": getattr(func_resp, 'name', ''),
                        "response": getattr(func_resp, 'response', {})
                    })

            if text_parts:
                message_data["content"] = '\n'.join(text_parts)
            if tool_calls:
                message_data["tool_calls"] = tool_calls
            if tool_responses:
                message_data["tool_responses"] = tool_responses
        else:
            message_data["content"] = str(content)

    # Only return messages that carry something displayable.
    if any(key in message_data for key in ("content", "tool_calls", "tool_responses")):
        return message_data
    return None
|
|
1054
|
+
|
|
1055
|
+
def _format_timestamp(self, timestamp) -> str:
    """Normalise an event timestamp to an ISO-8601 string.

    * ``str`` values are assumed to be pre-formatted and returned unchanged.
    * ``int``/``float`` values are Unix seconds, rendered in UTC.
    * ``None`` and any other type fall back to the current UTC time.
    """
    if isinstance(timestamp, str):
        return timestamp
    if isinstance(timestamp, (int, float)):
        return datetime.fromtimestamp(timestamp, tz=timezone.utc).isoformat()
    return datetime.now(timezone.utc).isoformat()
|
|
1068
|
+
|
|
1069
|
+
|
|
1070
|
+
def get_user_identifier_from_request(self, access_key: str = None, app_key: str = None) -> Optional[str]:
    """
    Resolve a user identifier by matching the access key against active connections.

    Args:
        access_key: Bohrium access key to look for among the active connections.
        app_key: Bohrium app key (reserved for future extension; currently unused).

    Returns:
        The user identifier of the first active connection whose
        ``access_key`` equals ``access_key``, or None if the key is empty or
        no connection matches.
    """
    if not access_key:
        return None

    matching = next(
        (ctx for ctx in self.active_connections.values() if ctx.access_key == access_key),
        None,
    )
    return matching.get_user_identifier() if matching is not None else None
|
|
1087
|
+
|
|
1088
|
+
async def send_sessions_list(self, context: ConnectionContext):
    """
    Send the connection user's session list to the client.

    Emits a ``sessions_list`` message with one summary per session (id,
    title, created/last-message timestamps, message count) plus the
    connection's current session id. If no session service is registered for
    the user, or building/sending the list raises, an ``error`` message is
    sent instead.

    Args:
        context: Connection context
    """
    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)

    if not session_service:
        # Client-facing text: "session service not initialized".
        await self._send_error(context, "会话服务未初始化")
        return

    try:
        listing = await session_service.list_sessions(
            app_name=self.app_name,
            user_id=user_identifier
        )

        summaries = []
        # The listing response carries the sessions in ``.sessions``; tolerate its absence.
        for sess in getattr(listing, 'sessions', []):
            meta = await self._get_session_metadata(session_service, user_identifier, sess.id)
            if not meta:
                # Fresh lookup failed or was empty: fall back to the listed state.
                meta = sess.state.get('metadata', {}) if sess.state else {}

            summaries.append({
                "id": sess.id,
                "title": meta.get("title", "Untitled"),
                "created_at": meta.get("created_at", datetime.now().isoformat()),
                "last_message_at": meta.get("last_message_at", datetime.now().isoformat()),
                "message_count": meta.get("message_count", 0)
            })

        await self._send_message(context, {
            "type": "sessions_list",
            "sessions": summaries,
            "current_session_id": context.current_session_id
        })

    except Exception:
        # Client-facing text: "failed to get session list".
        await self._send_error(context, "获取会话列表失败")
|
|
1137
|
+
|
|
1138
|
+
async def send_session_messages(self, context: ConnectionContext, session_id: str):
    """
    Send the message history of one session to the client.

    Delegates to ``_send_session_messages`` using the session service
    registered for the connection's user. An ``error`` message is sent if no
    such service exists or if sending the history raises.

    Args:
        context: Connection context
        session_id: Session ID
    """
    session_service = self.session_services.get(context.get_user_identifier())

    if not session_service:
        # Client-facing text: "session service not initialized".
        await self._send_error(context, "会话服务未初始化")
        return

    try:
        await self._send_session_messages(context, session_service, session_id)
    except Exception:
        # Client-facing text: "failed to get session messages".
        await self._send_error(context, "获取会话消息失败")
|