bohr-agent-sdk 0.1.100__py3-none-any.whl → 0.1.102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {bohr_agent_sdk-0.1.100.dist-info → bohr_agent_sdk-0.1.102.dist-info}/METADATA +6 -2
  2. bohr_agent_sdk-0.1.102.dist-info/RECORD +80 -0
  3. dp/agent/adapter/adk/client/calculation_mcp_tool.py +10 -2
  4. dp/agent/cli/cli.py +126 -25
  5. dp/agent/cli/templates/__init__.py +1 -0
  6. dp/agent/cli/templates/calculation/simple.py.template +15 -0
  7. dp/agent/cli/templates/device/tescan_device.py.template +158 -0
  8. dp/agent/cli/templates/main.py.template +67 -0
  9. dp/agent/cli/templates/ui/__init__.py +1 -0
  10. dp/agent/cli/templates/ui/api/__init__.py +1 -0
  11. dp/agent/cli/templates/ui/api/config.py +32 -0
  12. dp/agent/cli/templates/ui/api/constants.py +61 -0
  13. dp/agent/cli/templates/ui/api/debug.py +257 -0
  14. dp/agent/cli/templates/ui/api/files.py +469 -0
  15. dp/agent/cli/templates/ui/api/files_upload.py +115 -0
  16. dp/agent/cli/templates/ui/api/files_user.py +50 -0
  17. dp/agent/cli/templates/ui/api/messages.py +161 -0
  18. dp/agent/cli/templates/ui/api/projects.py +146 -0
  19. dp/agent/cli/templates/ui/api/sessions.py +93 -0
  20. dp/agent/cli/templates/ui/api/utils.py +161 -0
  21. dp/agent/cli/templates/ui/api/websocket.py +184 -0
  22. dp/agent/cli/templates/ui/config/__init__.py +1 -0
  23. dp/agent/cli/templates/ui/config/agent_config.py +257 -0
  24. dp/agent/cli/templates/ui/frontend/index.html +13 -0
  25. dp/agent/cli/templates/ui/frontend/package.json +46 -0
  26. dp/agent/cli/templates/ui/frontend/tsconfig.json +26 -0
  27. dp/agent/cli/templates/ui/frontend/tsconfig.node.json +10 -0
  28. dp/agent/cli/templates/ui/frontend/ui-static/assets/index-DdAmKhul.js +105 -0
  29. dp/agent/cli/templates/ui/frontend/ui-static/assets/index-DfN2raU9.css +1 -0
  30. dp/agent/cli/templates/ui/frontend/ui-static/index.html +14 -0
  31. dp/agent/cli/templates/ui/frontend/vite.config.ts +37 -0
  32. dp/agent/cli/templates/ui/scripts/build_ui.py +56 -0
  33. dp/agent/cli/templates/ui/server/__init__.py +0 -0
  34. dp/agent/cli/templates/ui/server/app.py +98 -0
  35. dp/agent/cli/templates/ui/server/connection.py +210 -0
  36. dp/agent/cli/templates/ui/server/file_watcher.py +85 -0
  37. dp/agent/cli/templates/ui/server/middleware.py +43 -0
  38. dp/agent/cli/templates/ui/server/models.py +53 -0
  39. dp/agent/cli/templates/ui/server/session_manager.py +1158 -0
  40. dp/agent/cli/templates/ui/server/user_files.py +85 -0
  41. dp/agent/cli/templates/ui/server/utils.py +50 -0
  42. dp/agent/cli/templates/ui/test_download.py +98 -0
  43. dp/agent/cli/templates/ui/ui_utils.py +260 -0
  44. dp/agent/cli/templates/ui/websocket-server.py +87 -0
  45. dp/agent/server/storage/http_storage.py +1 -1
  46. bohr_agent_sdk-0.1.100.dist-info/RECORD +0 -40
  47. {bohr_agent_sdk-0.1.100.dist-info → bohr_agent_sdk-0.1.102.dist-info}/WHEEL +0 -0
  48. {bohr_agent_sdk-0.1.100.dist-info → bohr_agent_sdk-0.1.102.dist-info}/entry_points.txt +0 -0
  49. {bohr_agent_sdk-0.1.100.dist-info → bohr_agent_sdk-0.1.102.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1158 @@
1
+ """
2
+ Session manager - using ADK native DatabaseSessionService implementation
3
+ """
4
+ import os
5
+ import json
6
+ import uuid
7
+ import asyncio
8
+ import traceback
9
+ import logging
10
+ from typing import Dict, Optional, Any
11
+ from datetime import datetime, timezone
12
+ from pathlib import Path
13
+ from fastapi import WebSocket
14
+ from google.adk import Runner
15
+ from google.genai import types
16
+ from google.adk.sessions import DatabaseSessionService, InMemorySessionService, Session
17
+
18
+ from server.connection import ConnectionContext
19
+ from server.user_files import UserFileManager
20
+ from config.agent_config import agentconfig
21
+
22
# Logging setup.
# Records are written to a log file and also echoed through a stream handler.
# The file path is read from the WEBSOCKET_LOG_PATH environment variable and
# falls back to ./websocket.log.
log_file_path = os.environ.get('WEBSOCKET_LOG_PATH', './websocket.log')
logging.basicConfig(
    handlers=[
        logging.FileHandler(log_file_path, encoding='utf-8'),  # persistent copy
        logging.StreamHandler(),  # console copy
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
34
+
35
+
36
+ class SessionManager:
37
+ """
38
+ Session manager
39
+ Based on ADK native DatabaseSessionService, providing session management, persistence, user isolation
40
+ """
41
+
42
+ # Constants
43
+ MAX_WAIT_TIME = 5 # Max wait time for runner initialization (seconds)
44
+ WAIT_INTERVAL = 0.1 # Wait interval (seconds)
45
+ MAX_CONTEXT_MESSAGES = 8 # Max messages in context
46
+
47
def __init__(self):
    """Create the empty caches and resolve the session storage directory."""
    # WebSocket -> ConnectionContext for every accepted connection.
    self.active_connections: Dict[WebSocket, ConnectionContext] = {}

    # Application name taken from the agent config, "Agent" when not configured.
    agent_section = agentconfig.config.get("agent", {})
    self.app_name = agent_section.get("name", "Agent")

    # A relative sessionsDir is anchored at the user working directory;
    # an absolute one is used as-is.
    user_working_dir = os.environ.get('USER_WORKING_DIR', os.getcwd())
    sessions_dir = agentconfig.get_files_config().get('sessionsDir', '.agent_sessions')
    configured_path = Path(sessions_dir)
    if configured_path.is_absolute():
        self.sessions_dir = configured_path
    else:
        self.sessions_dir = Path(user_working_dir) / sessions_dir
    self.sessions_dir.mkdir(parents=True, exist_ok=True)

    # One SessionService per user identifier.
    self.session_services: Dict[str, Any] = {}

    # Per-user file handling, rooted at the working and session directories.
    self.user_file_manager = UserFileManager(user_working_dir, str(self.sessions_dir))

    # Runner instances and the last initialization error, both keyed by
    # "<user_identifier>_<session_id>".
    self.runners: Dict[str, Runner] = {}
    self._runner_errors: Dict[str, str] = {}
81
+
82
def _create_session_service(self, user_identifier: str, is_registered: bool):
    """
    Build the SessionService backing one user.

    Args:
        user_identifier: User identifier; for registered users it also names
            the per-user directory created under ``sessions_dir/users``.
        is_registered: Whether user is registered

    Returns:
        A DatabaseSessionService bound to the user's ``sessions.db`` SQLite
        file for registered users, otherwise an InMemorySessionService.
    """
    if not is_registered:
        # Temporary users keep their sessions in memory only.
        return InMemorySessionService()

    # Registered users get a persistent SQLite database of their own.
    storage_dir = self.sessions_dir / "users" / user_identifier
    storage_dir.mkdir(parents=True, exist_ok=True)
    database_path = storage_dir / "sessions.db"
    return DatabaseSessionService(db_url=f"sqlite:///{database_path}")
105
+
106
async def connect_client(self, websocket: WebSocket, access_key: str = "", app_key: str = ""):
    """
    Accept a WebSocket and prepare everything the new client needs.

    The connection is accepted and registered, the user is resolved, a
    SessionService is created and cached for that user, existing sessions
    are loaded (or a default one created), and the initial data is sent.

    Args:
        websocket: WebSocket connection
        access_key: Bohrium access key
        app_key: Bohrium app key
    """
    await websocket.accept()

    # Register the connection before the user lookup.
    ctx = ConnectionContext(websocket, access_key, app_key)
    self.active_connections[websocket] = ctx

    # Resolve the user asynchronously, then read the resulting identity.
    await ctx.init_bohrium_user_id()
    user_identifier = ctx.get_user_identifier()
    registered = ctx.is_registered_user()

    # Each connect stores a fresh service under the user's identifier.
    service = self._create_session_service(user_identifier, registered)
    self.session_services[user_identifier] = service

    await self._load_or_create_sessions(ctx, service)
    await self._send_initial_data(ctx, service)
135
+
136
async def disconnect_client(self, websocket: WebSocket):
    """
    Disconnect client and release the resources tied to it.

    The SessionService, runners and runner errors are cached per user
    identifier, so they are only dropped when this was the user's last
    active connection; otherwise the remaining connections keep using them.
    Calling this for a websocket that is not registered is a no-op.

    Args:
        websocket: WebSocket connection
    """
    # Remove the entry first so a failure further down cannot leave a dead
    # connection registered.
    context = self.active_connections.pop(websocket, None)
    if context is None:
        return

    user_identifier = context.get_user_identifier()

    # Mark connection as disconnected to prevent further operations
    context.is_connected = False

    user_still_connected = any(
        other.get_user_identifier() == user_identifier
        for other in self.active_connections.values()
    )

    if not user_still_connected:
        self.session_services.pop(user_identifier, None)

        # Runner and error caches use "<user_identifier>_<session_id>" keys.
        key_prefix = f"{user_identifier}_"
        for cache in (self.runners, self._runner_errors):
            for key in [k for k in cache if k.startswith(key_prefix)]:
                del cache[key]

    # Clean up connection context
    context.cleanup()
169
+
170
async def create_session(self, context: ConnectionContext) -> Optional[Session]:
    """
    Create a new session for the connection's user and make it current.

    Runner initialization for the new session is started as a background
    task and is not awaited here.

    Args:
        context: Connection context

    Returns:
        The created Session, or None when no SessionService is cached for
        the user.
    """
    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)

    if not session_service:
        return None

    # A single timestamp for both fields, so a brand-new session never
    # reports a last_message_at that differs from its created_at.
    created_at = datetime.now().isoformat()
    metadata = {
        "created_at": created_at,
        "last_message_at": created_at,
        "message_count": 0,
        "title": "Untitled",
        "project_id": context.project_id
    }

    session = await session_service.create_session(
        app_name=self.app_name,
        user_id=user_identifier,
        session_id=str(uuid.uuid4()),
        state={"metadata": metadata}
    )

    # Update current session, then initialize the Runner in the background.
    context.current_session_id = session.id
    asyncio.create_task(self._init_runner(context, session.id))

    return session
212
+
213
async def delete_session(self, context: ConnectionContext, session_id: str) -> bool:
    """
    Delete session and the cached runner state that belongs to it.

    Args:
        context: Connection context
        session_id: Session ID

    Returns:
        True when the SessionService deleted the session without raising;
        False when no SessionService is cached for the user or the deletion
        raised (the exception is logged, not propagated).
    """
    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)

    if not session_service:
        return False

    try:
        # The service call returns None rather than a boolean, so finishing
        # without an exception is what counts as success.
        await session_service.delete_session(
            app_name=self.app_name,
            user_id=user_identifier,
            session_id=session_id
        )
    except Exception:
        logger.exception("Failed to delete session %s for user %s", session_id, user_identifier)
        return False

    # Drop the Runner and any recorded initialization error for the session.
    runner_key = f"{user_identifier}_{session_id}"
    self.runners.pop(runner_key, None)
    self._runner_errors.pop(runner_key, None)

    return True
248
+
249
async def switch_session(self, context: ConnectionContext, session_id: str) -> bool:
    """
    Make an existing session the connection's current session.

    Args:
        context: Connection context
        session_id: Session ID

    Returns:
        True after switching; False when the user has no cached
        SessionService or the session cannot be found.
    """
    user_identifier = context.get_user_identifier()
    service = self.session_services.get(user_identifier)
    if not service:
        return False

    # Only switch to a session the service actually knows about.
    target = await service.get_session(
        app_name=self.app_name,
        user_id=user_identifier,
        session_id=session_id
    )
    if not target:
        return False

    context.current_session_id = session_id

    # Start Runner initialization in the background when none is cached yet.
    if f"{user_identifier}_{session_id}" not in self.runners:
        asyncio.create_task(self._init_runner(context, session_id))

    return True
285
+
286
async def process_message(self, context: ConnectionContext, message: str, attachments: list = None):
    """
    Process user message

    Validates that a project ID, a current session and a SessionService are
    available, waits for the session's Runner, updates the session metadata
    and then streams the agent's response to the client. Every failure is
    reported to the client through ``_send_error``; nothing is raised.

    Args:
        context: Connection context
        message: User message
        attachments: Optional attachments, passed unchanged to
            ``_build_message_content``
    """
    # Save context reference for URL generation.
    # NOTE(review): this lives on the manager, so it is overwritten by each
    # message from any connection - confirm readers only use it synchronously.
    self.current_context = context

    # Check project_id
    if not context.project_id and not os.environ.get('BOHR_PROJECT_ID'):
        await self._send_error(context, "🔒 请先设置项目 ID")
        return

    if not context.current_session_id:
        await self._send_error(context, "没有活动的会话")
        return

    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)

    if not session_service:
        await self._send_error(context, "会话服务未初始化")
        return

    # Fetch the session
    session = await session_service.get_session(
        app_name=self.app_name,
        user_id=user_identifier,
        session_id=context.current_session_id
    )

    if not session:
        await self._send_error(context, "会话不存在")
        return

    # Wait for Runner initialization
    runner = await self._get_or_wait_runner(context, context.current_session_id)
    if not runner:
        error_details = self._runner_errors.get(f"{user_identifier}_{context.current_session_id}", "未知错误")
        await self._send_error(
            context,
            f"会话初始化失败\n\n可能的原因:\n"
            f"1. Agent 配置文件路径错误\n"
            f"2. Agent 模块导入失败\n"
            f"3. Project ID 无效\n\n"
            f"错误详情:{error_details}\n\n"
            f"请检查 config/agent-config.json 中的配置"
        )
        return

    # Update session metadata (before the message is processed)
    await self._update_session_metadata(context, session, message)

    # Get user file directory
    user_files_dir = self.user_file_manager.get_user_files_dir(user_id=user_identifier)
    original_cwd = os.getcwd()

    try:
        # Switch to user file directory.
        # NOTE(review): os.chdir changes the working directory of the whole
        # process, and it stays changed while the stream below is awaited, so
        # messages handled concurrently share it - confirm this is acceptable.
        os.chdir(user_files_dir)

        # Build message content
        content = self._build_message_content(session, message, attachments)

        # Process message stream
        await self._process_message_stream(
            context,
            runner,
            content,
            user_identifier,
            context.current_session_id
        )

    except Exception as e:
        logger.exception("Failed to process message for session %s", context.current_session_id)
        await self._send_error(context, f"处理消息失败: {str(e)}")

    finally:
        # Restore working directory; a failure here is logged and otherwise
        # ignored so it cannot mask the outcome of the message itself.
        try:
            os.chdir(original_cwd)
        except Exception:
            logger.exception("Failed to restore working directory to %s", original_cwd)
370
+
371
+ async def _load_or_create_sessions(self, context: ConnectionContext, session_service):
372
+ """Load or create sessions"""
373
+ user_identifier = context.get_user_identifier()
374
+
375
+ try:
376
+ response = await session_service.list_sessions(
377
+ app_name=self.app_name,
378
+ user_id=user_identifier
379
+ )
380
+
381
+ # Get session list from ListSessionsResponse object
382
+ sessions = response.sessions if hasattr(response, 'sessions') else []
383
+
384
+ if sessions:
385
+ # Sort by last message time
386
+ sessions.sort(
387
+ key=lambda s: self._get_session_last_update_time(s),
388
+ reverse=True
389
+ )
390
+
391
+ # Select most recent session as current
392
+ context.current_session_id = sessions[0].id
393
+
394
+ # Initialize Runner for each session
395
+ for session in sessions:
396
+ asyncio.create_task(self._init_runner(context, session.id))
397
+
398
+ else:
399
+ # Create new session
400
+ await self._create_default_session(context, session_service)
401
+
402
+ except Exception as e:
403
+ await self._create_default_session(context, session_service)
404
+
405
+ async def _create_default_session(self, context: ConnectionContext, session_service):
406
+ """Create default session"""
407
+ user_identifier = context.get_user_identifier()
408
+ session_id = str(uuid.uuid4())
409
+
410
+ # Create session metadata
411
+ metadata = {
412
+ "created_at": datetime.now().isoformat(),
413
+ "last_message_at": datetime.now().isoformat(),
414
+ "message_count": 0,
415
+ "title": "Untitled",
416
+ "project_id": context.project_id
417
+ }
418
+
419
+ session = await session_service.create_session(
420
+ app_name=self.app_name,
421
+ user_id=user_identifier,
422
+ session_id=session_id,
423
+ state={"metadata": metadata}
424
+ )
425
+
426
+ context.current_session_id = session.id
427
+
428
+ # Initialize Runner
429
+ asyncio.create_task(self._init_runner(context, session.id))
430
+
431
async def _init_runner(self, context: ConnectionContext, session_id: str, retry_count: int = 0):
    """
    Asynchronously initialize the Runner for one session, with retries.

    The agent is created in the default executor and wrapped in a Runner
    that is cached under "<user_identifier>_<session_id>". A failed attempt
    is retried after one second until three attempts have been made in
    total; the final failure is stored in ``_runner_errors``. Nothing is
    raised to the caller.

    Args:
        context: Connection context
        session_id: Session ID
        retry_count: Number of attempts already made (0 for the first call)
    """
    user_identifier = context.get_user_identifier()
    runner_key = f"{user_identifier}_{session_id}"
    max_retries = 3
    attempt = retry_count

    # Retries are a loop, not a recursive call from inside the except block,
    # so each recorded traceback describes only the failure of that attempt.
    while True:
        logger.info(f"🚀 开始初始化 Runner: {runner_key} (尝试 {attempt + 1}/{max_retries})")
        logger.debug(f" 用户标识: {user_identifier}")
        logger.debug(f" 会话ID: {session_id}")
        logger.debug(f" Access Key: {'有' if context.access_key else '无'}")
        logger.debug(f" App Key: {'有' if context.app_key else '无'}")

        try:
            # Do not build an agent for a client that has already left.
            if not context.is_connected:
                logger.warning(f"⚠️ 连接已断开,跳过 Runner 初始化: {runner_key}")
                return

            # Get project_id
            project_id = context.project_id or os.environ.get('BOHR_PROJECT_ID')
            if project_id:
                project_id = int(project_id) if isinstance(project_id, str) else project_id
            logger.debug(f" Project ID: {project_id}")

            # Create agent
            logger.info(f"📦 创建 Agent...")
            logger.debug(f" 配置模块: {agentconfig.config.get('agent', {}).get('module')}")
            logger.debug(f" Agent名称: {agentconfig.config.get('agent', {}).get('name')}")

            # Agent creation is synchronous, so it runs in the default
            # executor rather than on the event loop.
            loop = asyncio.get_running_loop()
            user_agent = await loop.run_in_executor(
                None,
                agentconfig.get_agent,
                context.access_key or "",
                context.app_key or "",
                project_id
            )

            if not user_agent:
                raise ValueError("Agent 创建失败: 返回 None")

            logger.info(f"✅ Agent 创建成功: {type(user_agent).__name__}")

            # Create Runner
            logger.info(f"🏃 创建 Runner...")

            # The client may have disconnected while the agent was created.
            if not context.is_connected:
                logger.warning(f"⚠️ 连接已断开,跳过 Runner 初始化: {runner_key}")
                return

            session_service = self.session_services.get(user_identifier)
            if not session_service:
                logger.warning(f"⚠️ SessionService 已被清理,跳过 Runner 初始化: {runner_key}")
                return

            runner = Runner(
                agent=user_agent,
                session_service=session_service,
                app_name=self.app_name
            )

            self.runners[runner_key] = runner
            logger.info(f"✅ Runner 初始化成功: {runner_key}")
            logger.debug(f" 当前 Runner 数量: {len(self.runners)}")

            # Clear any error recorded by an earlier failed initialization.
            self._runner_errors.pop(runner_key, None)
            return

        except Exception as e:
            error_type = "导入错误" if isinstance(e, ImportError) else "Runner 初始化失败"
            error_msg = f"❌ {error_type}: {str(e)}\n类型: {type(e).__name__}\n{traceback.format_exc()}"
            logger.error(error_msg)

        if attempt >= max_retries - 1:
            # Every attempt failed: keep the error for whoever waits on this runner.
            self._runner_errors[runner_key] = f"{error_msg}\n\n已尝试 {max_retries} 次初始化,全部失败。"
            logger.error(f"❌ Runner 初始化彻底失败: {runner_key},已尝试 {max_retries} 次")
            return

        logger.info(f"🔄 准备重试 Runner 初始化: {runner_key}")
        # Drop any partially initialized state before trying again.
        self.runners.pop(runner_key, None)

        # Wait briefly before the next attempt.
        await asyncio.sleep(1)
        attempt += 1
517
+
518
async def _get_or_wait_runner(self, context: ConnectionContext, session_id: str) -> Optional[Runner]:
    """
    Return the session's Runner, polling until it exists or the wait expires.

    The runner cache is polled every WAIT_INTERVAL seconds for at most
    MAX_WAIT_TIME seconds. If an initialization error is recorded while
    waiting, one fresh initialization is triggered; a second recorded error
    is reported to the client and ends the wait.

    Args:
        context: Connection context
        session_id: Session ID

    Returns:
        The Runner, or None when initialization failed or timed out.
    """
    user_identifier = context.get_user_identifier()
    runner_key = f"{user_identifier}_{session_id}"

    logger.debug(f"⏳ 等待 Runner 初始化: {runner_key}")

    polls = 0
    poll_limit = int(self.MAX_WAIT_TIME / self.WAIT_INTERVAL)
    recovery_used = False

    while runner_key not in self.runners and polls < poll_limit:
        init_failed = runner_key in self._runner_errors
        if init_failed:
            logger.error(f"Runner 初始化已失败: {self._runner_errors[runner_key]}")

        if init_failed and recovery_used:
            # Recovery was already tried and failed again: report the
            # details to the frontend and give up.
            await self._send_error(
                context,
                f"会话初始化失败\n\n错误详情:\n{self._runner_errors.get(runner_key, '未知错误')}\n\n"
                f"建议:\n"
                f"1. 请尝试新建一个会话\n"
                f"2. 检查 Agent 配置是否正确\n"
                f"3. 确认 Project ID 是否有效"
            )
            # Clear the error cache
            self._runner_errors.pop(runner_key, None)
            return None

        if init_failed:
            # First recorded failure: clear it and start one new attempt.
            recovery_used = True
            logger.info(f"🔧 尝试自动恢复 Runner: {runner_key}")
            del self._runner_errors[runner_key]
            asyncio.create_task(self._init_runner(context, session_id))

        await asyncio.sleep(self.WAIT_INTERVAL)
        polls += 1

        # Progress is logged once per ten polls, except on the poll that
        # triggered recovery.
        if not init_failed and polls % 10 == 0:
            logger.debug(f" 仍在等待... ({polls * self.WAIT_INTERVAL:.1f}秒)")

    runner = self.runners.get(runner_key)
    if runner:
        logger.info(f"✅ 获取 Runner 成功: {runner_key}")
        return runner

    logger.error(f"❌ 超时等待 Runner: {runner_key} (等待了 {self.MAX_WAIT_TIME} 秒)")
    # Timed out without a recorded error: initialization may just be slow.
    if runner_key not in self._runner_errors:
        await self._send_error(
            context,
            f"会话初始化超时\n\n"
            f"可能的原因:\n"
            f"1. Agent 初始化时间过长\n"
            f"2. 系统资源不足\n\n"
            f"建议尝试新建会话"
        )

    return runner
589
+
590
async def _update_session_metadata(self, context: ConnectionContext, session: Session, message: str):
    """
    Record a new user message in the session metadata.

    ``last_message_at`` is refreshed, ``message_count`` is incremented and,
    for the first message, the title becomes the first 50 characters of the
    message. The change is written as a ``state_delta`` on a "system" event
    passed to ``append_event``; it is skipped when the user has no cached
    SessionService.

    Args:
        context: Connection context
        session: Session whose metadata is updated
        message: User message
    """
    from google.adk.events import Event, EventActions

    # Work on a copy of the stored metadata.
    previous = session.state.get('metadata', {}) if session.state else {}
    updated = dict(previous)
    updated['last_message_at'] = datetime.now().isoformat()
    updated['message_count'] = updated.get('message_count', 0) + 1

    # The first message doubles as the session title.
    if updated['message_count'] == 1:
        updated['title'] = message[:50]

    metadata_event = Event(
        invocation_id=f"metadata_update_{datetime.now().timestamp()}",
        author="system",
        actions=EventActions(state_delta={'metadata': updated}),
        timestamp=datetime.now().timestamp()
    )

    service = self.session_services.get(context.get_user_identifier())
    if service:
        await service.append_event(session, metadata_event)
624
+
625
+ async def _process_message_stream(
626
+ self,
627
+ context: ConnectionContext,
628
+ runner: Runner,
629
+ content: types.Content,
630
+ user_identifier: str,
631
+ session_id: str
632
+ ):
633
+ """Process message stream - using ADK native event handling

Iterates the events produced by runner.run_async() and forwards them to the
client through _send_message: "tool" messages for function calls and
function responses, "assistant" messages for model text, then a "complete"
marker followed by the refreshed session list.

Args:
    context: Connection context
    runner: Runner used to execute the message
    content: Message content passed to the runner as new_message
    user_identifier: User identifier passed to the runner as user_id
    session_id: Session ID passed to the runner and echoed in assistant messages
"""
634
+ streaming_text = "" # Text accumulated from partial events
635
+
636
+ # Run Runner
637
+ async for event in runner.run_async(
638
+ new_message=content,
639
+ user_id=user_identifier,
640
+ session_id=session_id
641
+ ):
642
+ # 1. Read the event author (the value is not used further in this method)
643
+ author = getattr(event, 'author', None)
644
+
645
+ # 2. Check if it's streaming output
646
+ is_partial = getattr(event, 'partial', False)
647
+
648
+ # 3. Handle function calls (tool call requests)
649
+ function_calls = event.get_function_calls() if hasattr(event, 'get_function_calls') else []
650
+ if function_calls:
651
+ for call in function_calls:
652
+ await self._send_message(context, {
653
+ "type": "tool",
654
+ "tool_name": call.name,
655
+ "args": call.args, # Tool call arguments forwarded to the client
656
+ "status": "executing",
657
+ "timestamp": datetime.now().isoformat()
658
+ })
659
+ await asyncio.sleep(0.2)
660
+ # 4. Handle function responses (tool execution results)
661
+ function_responses = event.get_function_responses() if hasattr(event, 'get_function_responses') else []
662
+ if function_responses:
663
+ for response in function_responses:
664
+ # Format response result
665
+ result_str = self._format_response_data(response.response)
666
+ await self._send_message(context, {
667
+ "type": "tool",
668
+ "tool_name": response.name,
669
+ "result": result_str,
670
+ "status": "completed",
671
+ "timestamp": datetime.now().isoformat()
672
+ })
673
+
674
+ # 5. Handle text content
675
+ if hasattr(event, 'content') and event.content:
676
+ if hasattr(event.content, 'parts') and event.content.parts:
677
+ for part in event.content.parts:
678
+ if hasattr(part, 'text') and part.text:
679
+ if is_partial:
680
+ # Streaming text, accumulate
681
+ streaming_text += part.text
682
+ else:
683
+ # Complete text, prefixed with any accumulated streaming text
684
+ text_to_send = streaming_text + part.text if streaming_text else part.text
685
+ streaming_text = "" # Reset accumulator
686
+
687
+ # Only text with role 'model' is sent to the client
688
+ role = getattr(event.content, 'role', 'model')
689
+ if role == 'model':
690
+ await self._send_message(context, {
691
+ "type": "assistant",
692
+ "content": text_to_send,
693
+ "session_id": session_id
694
+ })
695
+
696
+ # 6. Check if it's final response
697
+ if hasattr(event, 'is_final_response') and event.is_final_response():
698
+ # Flush streaming text that is still accumulated at this point
699
+ if streaming_text:
700
+ await self._send_message(context, {
701
+ "type": "assistant",
702
+ "content": streaming_text,
703
+ "session_id": session_id
704
+ })
705
+ streaming_text = ""
706
+
707
+ # 7. Inspect Actions (state changes and control flow); every branch below is a no-op placeholder
708
+ if hasattr(event, 'actions') and event.actions:
709
+ # State changes
710
+ if hasattr(event.actions, 'state_delta') and event.actions.state_delta:
711
+ pass
712
+
713
+ # Skip summarization flag
714
+ if hasattr(event.actions, 'skip_summarization') and event.actions.skip_summarization:
715
+ pass
716
+
717
+ # Agent transfer
718
+ if hasattr(event.actions, 'transfer_to_agent') and event.actions.transfer_to_agent:
719
+ pass
720
+
721
+ # Tell the client the run has finished
722
+ await self._send_message(context, {
723
+ "type": "complete",
724
+ "content": ""
725
+ })
726
+
727
+ # Send updated session list
728
+ await self.send_sessions_list(context)
729
+
730
+ # _handle_tool_events method removed, functionality integrated into _process_message_stream
731
+
732
+ async def _get_session_metadata(self, session_service, user_identifier: str, session_id: str) -> dict:
733
+ """Get latest session metadata"""
734
+ try:
735
+ fresh_session = await session_service.get_session(
736
+ app_name=self.app_name,
737
+ user_id=user_identifier,
738
+ session_id=session_id
739
+ )
740
+ if fresh_session and hasattr(fresh_session, 'state') and isinstance(fresh_session.state, dict):
741
+ return fresh_session.state.get('metadata', {})
742
+ except Exception as e:
743
+ pass
744
+ return {}
745
+
746
+ async def _send_initial_data(self, context: ConnectionContext, session_service):
747
+ """Send initial data to client"""
748
+ user_identifier = context.get_user_identifier()
749
+
750
+ # Send session list
751
+ response = await session_service.list_sessions(
752
+ app_name=self.app_name,
753
+ user_id=user_identifier
754
+ )
755
+
756
+ # Get session list from ListSessionsResponse object
757
+ sessions = response.sessions if hasattr(response, 'sessions') else []
758
+
759
+ sessions_data = []
760
+ for session in sessions:
761
+ # Get latest metadata
762
+ metadata = await self._get_session_metadata(session_service, user_identifier, session.id)
763
+ if not metadata: # If fetch fails, use original data as fallback
764
+ metadata = session.state.get('metadata', {}) if session.state else {}
765
+
766
+ sessions_data.append({
767
+ "id": session.id,
768
+ "title": metadata.get("title", "Untitled"),
769
+ "created_at": metadata.get("created_at", datetime.now().isoformat()),
770
+ "last_message_at": metadata.get("last_message_at", datetime.now().isoformat()),
771
+ "message_count": metadata.get("message_count", 0)
772
+ })
773
+
774
+ await self._send_message(context, {
775
+ "type": "sessions_list",
776
+ "sessions": sessions_data,
777
+ "current_session_id": context.current_session_id
778
+ })
779
+
780
+ # Send current session message history
781
+ if context.current_session_id:
782
+ await self._send_session_messages(context, session_service, context.current_session_id)
783
+
784
+ async def _send_session_messages(self, context: ConnectionContext, session_service, session_id: str):
785
+ """Send session message history"""
786
+ user_identifier = context.get_user_identifier()
787
+
788
+ session = await session_service.get_session(
789
+ app_name=self.app_name,
790
+ user_id=user_identifier,
791
+ session_id=session_id
792
+ )
793
+
794
+ if not session or not hasattr(session, 'events'):
795
+ return
796
+
797
+ messages_data = []
798
+
799
+ for event in session.events:
800
+ # Parse events, convert to frontend-understandable format
801
+ if not hasattr(event, 'content'):
802
+ continue
803
+
804
+ content = event.content
805
+ role = getattr(content, 'role', None)
806
+ timestamp = self._format_timestamp(getattr(event, "timestamp", None))
807
+
808
+ # Handle message content
809
+ if hasattr(content, 'parts'):
810
+ for part in content.parts:
811
+ # Handle text messages
812
+ if hasattr(part, 'text') and part.text:
813
+ if role == 'user':
814
+ messages_data.append({
815
+ "id": str(uuid.uuid4()),
816
+ "role": "user",
817
+ "type": "user",
818
+ "content": part.text,
819
+ "timestamp": timestamp
820
+ })
821
+ elif role == 'model':
822
+ messages_data.append({
823
+ "id": str(uuid.uuid4()),
824
+ "role": "assistant",
825
+ "type": "assistant",
826
+ "content": part.text,
827
+ "timestamp": timestamp
828
+ })
829
+
830
+ # Handle tool calls - don't show executing state in history
831
+ elif hasattr(part, 'function_call') and part.function_call:
832
+ # Skip function_call in history, only show final results
833
+ pass
834
+
835
+ # Handle tool responses - only show completed tool calls
836
+ elif hasattr(part, 'function_response') and part.function_response:
837
+ func_resp = part.function_response
838
+ tool_name = getattr(func_resp, 'name', 'unknown')
839
+ result_str = self._format_response_data(getattr(func_resp, 'response', {}))
840
+ # Use simple UUID for history messages
841
+ messages_data.append({
842
+ "id": str(uuid.uuid4()),
843
+ "role": "tool",
844
+ "type": "tool",
845
+ "tool_name": tool_name,
846
+ "tool_status": "completed",
847
+ "content": result_str,
848
+ "timestamp": timestamp
849
+ })
850
+ else:
851
+ # Simple text content
852
+ if role == 'user':
853
+ messages_data.append({
854
+ "id": str(uuid.uuid4()),
855
+ "role": "user",
856
+ "type": "user",
857
+ "content": str(content),
858
+ "timestamp": timestamp
859
+ })
860
+ elif role == 'model':
861
+ messages_data.append({
862
+ "id": str(uuid.uuid4()),
863
+ "role": "assistant",
864
+ "type": "assistant",
865
+ "content": str(content),
866
+ "timestamp": timestamp
867
+ })
868
+
869
+ await self._send_message(context, {
870
+ "type": "session_messages",
871
+ "session_id": session_id,
872
+ "messages": messages_data
873
+ })
874
+
875
async def _send_message(self, context: ConnectionContext, message: dict):
    """Send a JSON message to the client behind ``context``.

    When ``message`` has no ``'id'`` key, one is added in place, built from
    the message's ``'type'`` (``'unknown'`` when absent) and the current
    POSIX timestamp.

    Any exception raised by ``send_json`` is treated as a broken connection:
    ``self.disconnect_client`` is scheduled as a background task for the
    websocket and the exception is not re-raised.

    Args:
        context: Connection context; ``context.websocket.send_json`` is
            awaited with the message.
        message: Payload to send. Mutated in place when an id is added.
    """
    if 'id' not in message:
        message['id'] = f"{message.get('type', 'unknown')}_{datetime.now().timestamp()}"

    try:
        await context.websocket.send_json(message)
    except Exception:
        # Best-effort delivery: a failed send only triggers connection cleanup.
        cleanup = asyncio.create_task(self.disconnect_client(context.websocket))
        # The event loop keeps only weak references to tasks, so hold a
        # strong one until the cleanup finishes or it may be collected
        # before it ever runs.
        pending = getattr(self, '_pending_disconnects', None)
        if pending is None:
            pending = set()
            self._pending_disconnects = pending
        pending.add(cleanup)
        cleanup.add_done_callback(pending.discard)
884
+
885
async def _send_error(self, context: ConnectionContext, error_message: str):
    """Deliver ``error_message`` to the client as an ``error``-typed message.

    Args:
        context: Connection context forwarded to ``self._send_message``.
        error_message: Text placed in the payload's ``content`` field.
    """
    payload = dict(type="error", content=error_message)
    await self._send_message(context, payload)
891
+
892
def _get_session_last_update_time(self, session: Session) -> datetime:
    """Return the most recent activity time known for ``session``.

    Sources are tried in order:

    1. ``session.state['metadata']['last_message_at']`` parsed as an ISO-8601
       string, when ``state`` and ``metadata`` are dicts and the value is set.
    2. ``session.last_update_time`` interpreted as a POSIX timestamp and
       converted to a naive local ``datetime``.
    3. ``datetime.min`` when neither source yields a usable value.

    Malformed values in either source are skipped rather than raised.

    NOTE(review): source 1 is timezone-aware when the stored string carries
    an offset, while sources 2 and 3 are naive. Confirm callers never
    compare the two kinds.

    Args:
        session: Session object; only ``state`` and ``last_update_time`` are
            read, and both are optional.

    Returns:
        The resolved ``datetime``.
    """
    state = getattr(session, 'state', None)
    if isinstance(state, dict):
        metadata = state.get('metadata')
        last_message_at = metadata.get('last_message_at') if isinstance(metadata, dict) else None
        if last_message_at:
            try:
                return datetime.fromisoformat(last_message_at)
            except (TypeError, ValueError):
                # Not a string or not ISO-8601: fall through to the native field.
                pass

    last_update_time = getattr(session, 'last_update_time', None)
    if last_update_time is not None:
        try:
            return datetime.fromtimestamp(last_update_time)
        except (TypeError, ValueError, OverflowError, OSError):
            # Non-numeric or out-of-range timestamp: treat as unknown.
            pass

    return datetime.min
909
+
910
def _get_base_url(self, context: ConnectionContext) -> str:
    """Resolve the base URL for the current connection.

    Resolution order:

    1. The ``origin`` request header, returned unchanged.
    2. The ``host`` request header, prefixed with ``https`` when the first
       entry of ``x-forwarded-proto`` is ``https`` (case-insensitive), and
       with ``http`` otherwise.
    3. The ``AGENT_API_URL`` environment variable, trailing slashes removed.
    4. The default ``"http://localhost:8000"``.

    Args:
        context: Connection context; its optional ``request_headers``
            mapping is read with lower-case header names.

    Returns:
        The resolved base URL.
    """
    # A missing or None header mapping is treated as "no headers".
    headers = getattr(context, 'request_headers', None) or {}

    # 1. Origin header
    origin = headers.get('origin', '')
    if origin:
        return origin

    # 2. Host header
    host = headers.get('host', '')
    if host:
        forwarded_proto = headers.get('x-forwarded-proto', '') or ''
        # Chained proxies may send a comma-separated list; the first entry
        # is the scheme the client used.
        client_proto = forwarded_proto.split(',')[0].strip().lower()
        protocol = 'https' if client_proto == 'https' else 'http'
        return f"{protocol}://{host}"

    # 3. Environment variable
    base_url = os.environ.get('AGENT_API_URL', '')
    if base_url:
        return base_url.rstrip('/')

    # 4. Default
    return "http://localhost:8000"
933
+
934
def _build_message_content(self, session, message: str, attachments: list = None) -> types.Content:
    """Build the user ``Content`` sent to the agent, listing any attachments.

    When ``attachments`` is non-empty and this object has a
    ``current_context`` attribute, a notice followed by one line per
    attachment path is appended to ``message``. If ``message`` is empty,
    the stripped notice becomes the whole text. Otherwise ``message`` is
    used unchanged.

    Args:
        session: Unused by this method.
        message: Text typed by the user; may be empty.
        attachments: Optional list of mappings, each with a
            ``'relative_path'`` key.

    Returns:
        A ``types.Content`` with role ``'user'`` and a single text part.

    Raises:
        KeyError: If an attachment has no ``'relative_path'`` key.
    """
    enhanced_message = message

    # Only the presence of current_context gates the attachment notice; the
    # notice text itself depends solely on the attachment paths.
    if attachments and hasattr(self, 'current_context'):
        header = "\n\n用户已上传文件,请你跟据上传的文件路径,调用传入所希望调用的工具中,:"
        file_info = header + "".join(
            f"\n 文件路径: {att['relative_path']}" for att in attachments
        )
        enhanced_message = message + file_info if message else file_info.strip()

    return types.Content(
        role='user',
        parts=[types.Part(text=enhanced_message)]
    )
954
+
955
def _format_response_data(self, response_data):
    """Render a tool response as display text.

    Strings are returned unchanged. Dicts, lists and tuples are serialized
    as indented JSON with non-ASCII characters kept as-is; if they cannot be
    serialized, their ``str()`` form is returned. Any other value is
    converted with ``str()``.

    Args:
        response_data: The value to render.

    Returns:
        The rendered string.
    """
    if isinstance(response_data, str):
        return response_data

    if isinstance(response_data, (dict, list, tuple)):
        try:
            return json.dumps(response_data, indent=2, ensure_ascii=False)
        except (TypeError, ValueError, RecursionError):
            # Unserializable members, circular references or excessive
            # nesting: fall back to the plain str() form below.
            pass

    return str(response_data)
963
+
964
def _extract_final_response(self, events: list) -> Optional[str]:
    """Return the text of the latest event that carries any text.

    Events are scanned from last to first. For the first event whose
    content has at least one part with non-empty ``text``, those texts are
    joined with newlines and returned.

    Args:
        events: Events in chronological order.

    Returns:
        The joined text, or ``None`` when no event has text.
    """
    for candidate in reversed(events):
        body = getattr(candidate, 'content', None)
        if not body:
            continue

        pieces = getattr(body, 'parts', None)
        if not pieces:
            continue

        texts = [piece.text for piece in pieces if getattr(piece, 'text', None)]
        if texts:
            return '\n'.join(texts)

    return None
977
+
978
def _event_to_message_data(self, event) -> Optional[dict]:
    """Convert an event into the message dict sent to the frontend.

    The result always has a fresh UUID ``id`` and a ``timestamp`` produced
    by ``self._format_timestamp``. ``type`` comes from ``event.type`` and is
    then overridden by ``event.role`` when present (``'model'`` becomes
    ``'assistant'``); the raw role is kept under ``role``.

    For content that has ``parts``, text parts are joined with newlines into
    ``content``, and parts that actually carry a function call or function
    response are collected under ``tool_calls`` / ``tool_responses``. Parts
    whose ``function_call`` / ``function_response`` attribute is present but
    empty are ignored, matching how session history is rendered. Content
    without ``parts`` is stringified into ``content``.

    Args:
        event: Event object; ``timestamp``, ``type``, ``role`` and
            ``content`` are all optional attributes.

    Returns:
        The message dict, or ``None`` when the event is falsy or yields no
        content, tool calls or tool responses.
    """
    if not event:
        return None

    message_data = {
        "id": str(uuid.uuid4()),
        "timestamp": self._format_timestamp(getattr(event, "timestamp", None))
    }

    if hasattr(event, 'type'):
        message_data["type"] = event.type

    if hasattr(event, 'role'):
        # Map the model-side role onto the type names the frontend expects.
        role = event.role
        if role == 'model':
            message_data["type"] = "assistant"
        elif role == 'user':
            message_data["type"] = "user"
        else:
            message_data["type"] = role
        message_data["role"] = role  # Preserve original role info

    content = getattr(event, 'content', None)
    # A None content means "no content"; it must not become the text "None".
    if content is not None:
        if hasattr(content, 'parts'):
            text_parts = []
            tool_calls = []
            tool_responses = []

            # parts may be None on an otherwise valid content object.
            for part in content.parts or ():
                text = getattr(part, 'text', None)
                if text is not None:
                    text_parts.append(text)

                # The attribute can exist with an empty value on text-only
                # parts, so test the value rather than its presence.
                func_call = getattr(part, 'function_call', None)
                if func_call:
                    tool_calls.append({
                        "id": getattr(func_call, 'id', ''),
                        "name": getattr(func_call, 'name', ''),
                        "args": getattr(func_call, 'args', {})
                    })

                func_resp = getattr(part, 'function_response', None)
                if func_resp:
                    tool_responses.append({
                        "id": getattr(func_resp, 'id', ''),
                        "name": getattr(func_resp, 'name', ''),
                        "response": getattr(func_resp, 'response', {})
                    })

            if text_parts:
                message_data["content"] = '\n'.join(text_parts)
            if tool_calls:
                message_data["tool_calls"] = tool_calls
            if tool_responses:
                message_data["tool_responses"] = tool_responses
        else:
            message_data["content"] = str(content)

    # Only return messages that carry something to display.
    if "content" in message_data or "tool_calls" in message_data or "tool_responses" in message_data:
        return message_data

    return None
1054
+
1055
def _format_timestamp(self, timestamp) -> str:
    """Normalize a timestamp to a string.

    Strings pass through unchanged. Ints and floats are read as POSIX
    timestamps and rendered as UTC ISO-8601. ``None`` and any other type
    yield the current UTC time in ISO-8601.

    Args:
        timestamp: A string, a number of seconds since the epoch, or
            ``None``.

    Returns:
        The timestamp string.
    """
    if isinstance(timestamp, str):
        return timestamp

    if isinstance(timestamp, (int, float)):
        moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    else:
        # None or an unsupported type: stamp with the current time.
        moment = datetime.now(timezone.utc)

    return moment.isoformat()
1068
+
1069
+
1070
def get_user_identifier_from_request(self, access_key: str = None, app_key: str = None) -> Optional[str]:
    """Look up the user identifier of a connected client by access key.

    Args:
        access_key: Bohrium access key to match against active connections.
        app_key: Bohrium app key; accepted but not used.

    Returns:
        The identifier reported by the first active connection whose
        ``access_key`` equals the given one, or ``None`` when no key is
        given or no connection matches.
    """
    if not access_key:
        return None

    owner = next(
        (conn for conn in self.active_connections.values() if conn.access_key == access_key),
        None,
    )
    if owner is None:
        return None
    return owner.get_user_identifier()
1087
+
1088
async def send_sessions_list(self, context: ConnectionContext):
    """Send the connected user's session list to the client.

    A ``sessions_list`` message is sent with one entry per session (id,
    title, creation time, last message time, message count) plus the
    context's current session id. Metadata comes from
    ``self._get_session_metadata``; when that returns nothing, the
    ``metadata`` entry of the session's own state is used instead.

    An error message is sent instead when the user has no session service
    or when anything fails while building or sending the list.

    Args:
        context: Connection context
    """
    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)

    if not session_service:
        await self._send_error(context, "会话服务未初始化")
        return

    try:
        listing = await session_service.list_sessions(
            app_name=self.app_name,
            user_id=user_identifier
        )

        sessions_data = []
        # The response may lack a ``sessions`` attribute; treat that as empty.
        for entry in getattr(listing, 'sessions', []):
            meta = await self._get_session_metadata(session_service, user_identifier, entry.id)
            if not meta:
                # Fetch failed: fall back to the metadata on the listed session.
                meta = entry.state.get('metadata', {}) if entry.state else {}

            summary = {
                "id": entry.id,
                "title": meta.get("title", "Untitled"),
                "created_at": meta.get("created_at", datetime.now().isoformat()),
                "last_message_at": meta.get("last_message_at", datetime.now().isoformat()),
                "message_count": meta.get("message_count", 0)
            }
            sessions_data.append(summary)

        reply = {
            "type": "sessions_list",
            "sessions": sessions_data,
            "current_session_id": context.current_session_id
        }
        await self._send_message(context, reply)

    except Exception:
        await self._send_error(context, "获取会话列表失败")
1137
+
1138
async def send_session_messages(self, context: ConnectionContext, session_id: str):
    """Send the message history of one session to the client.

    Delegates to ``self._send_session_messages`` using the session service
    registered for the connected user. An error message is sent instead
    when the user has no session service or when the delegate raises.

    Args:
        context: Connection context
        session_id: Session ID
    """
    user_identifier = context.get_user_identifier()
    session_service = self.session_services.get(user_identifier)

    if session_service:
        try:
            # Reuse the internal implementation for the actual work.
            await self._send_session_messages(context, session_service, session_id)
        except Exception:
            await self._send_error(context, "获取会话消息失败")
        return

    await self._send_error(context, "会话服务未初始化")