AstrBot 4.5.8__py3-none-any.whl → 4.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. astrbot/core/agent/mcp_client.py +152 -26
  2. astrbot/core/agent/message.py +7 -0
  3. astrbot/core/config/default.py +31 -1
  4. astrbot/core/core_lifecycle.py +8 -0
  5. astrbot/core/db/__init__.py +50 -1
  6. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  7. astrbot/core/db/po.py +49 -13
  8. astrbot/core/db/sqlite.py +102 -3
  9. astrbot/core/knowledge_base/kb_helper.py +314 -33
  10. astrbot/core/knowledge_base/kb_mgr.py +45 -1
  11. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  12. astrbot/core/knowledge_base/prompts.py +65 -0
  13. astrbot/core/pipeline/process_stage/method/llm_request.py +28 -14
  14. astrbot/core/pipeline/process_stage/utils.py +60 -16
  15. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +13 -10
  16. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -4
  17. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +0 -4
  18. astrbot/core/provider/entities.py +22 -9
  19. astrbot/core/provider/func_tool_manager.py +12 -9
  20. astrbot/core/provider/manager.py +4 -0
  21. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  22. astrbot/core/provider/sources/gemini_source.py +25 -8
  23. astrbot/core/provider/sources/openai_source.py +9 -16
  24. astrbot/dashboard/routes/chat.py +134 -77
  25. astrbot/dashboard/routes/knowledge_base.py +172 -0
  26. {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/METADATA +5 -4
  27. {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/RECORD +30 -26
  28. {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/WHEEL +0 -0
  29. {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/entry_points.txt +0 -0
  30. {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,236 @@
1
+ import os
2
+
3
+ import aiohttp
4
+
5
+ from astrbot import logger
6
+
7
+ from ..entities import ProviderType, RerankResult
8
+ from ..provider import RerankProvider
9
+ from ..register import register_provider_adapter
10
+
11
+
12
class BailianRerankError(Exception):
    """Base class for every error raised by the Bailian rerank provider."""


class BailianAPIError(BailianRerankError):
    """The Bailian API returned an error."""


class BailianNetworkError(BailianRerankError):
    """The HTTP request to the Bailian API failed."""
28
+
29
+
30
@register_provider_adapter(
    "bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
    """Alibaba Cloud Bailian (DashScope) text-rerank provider adapter.

    Posts a query plus candidate documents to the DashScope text-rerank
    endpoint and converts the response into ``RerankResult`` objects.
    """

    # Maximum number of documents sent in one request; longer lists are
    # truncated (with a warning) before the request is built.
    MAX_DOCUMENTS = 500

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

        # API configuration; the key falls back to the DASHSCOPE_API_KEY env var.
        self.api_key = provider_config.get("rerank_api_key") or os.getenv(
            "DASHSCOPE_API_KEY", ""
        )
        if not self.api_key:
            raise ValueError("阿里云百炼 API Key 不能为空。")

        self.model = provider_config.get("rerank_model", "qwen3-rerank")
        self.timeout = provider_config.get("timeout", 30)
        self.return_documents = provider_config.get("return_documents", False)
        self.instruct = provider_config.get("instruct", "")

        self.base_url = provider_config.get(
            "rerank_api_base",
            "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
        )

        self._headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # The HTTP session is created lazily by _get_client(). aiohttp
        # sessions should be created inside a running event loop, and lazy
        # creation also lets rerank() work again after terminate() has
        # closed and cleared the previous session.
        self.client: aiohttp.ClientSession | None = None

        self.set_model(self.model)

        logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")

    def _get_client(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, (re)creating it if missing or closed."""
        # Synchronous on purpose: no await between the check and the
        # assignment, so concurrent rerank() calls cannot create two sessions.
        if self.client is None or self.client.closed:
            self.client = aiohttp.ClientSession(
                headers=self._headers,
                timeout=aiohttp.ClientTimeout(total=self.timeout),
            )
        return self.client

    def _build_payload(
        self, query: str, documents: list[str], top_n: int | None
    ) -> dict:
        """Build the JSON request body.

        Args:
            query: Query text.
            documents: Candidate documents.
            top_n: Number of top results to return; ``None`` or a
                non-positive value omits the parameter (all results).

        Returns:
            The request payload. A ``parameters`` object is only included
            when at least one optional parameter applies.
        """
        payload: dict = {
            "model": self.model,
            "input": {"query": query, "documents": documents},
        }

        params: dict = {}
        if top_n is not None and top_n > 0:
            params["top_n"] = top_n
        if self.return_documents:
            params["return_documents"] = True
        # "instruct" is only sent for the qwen3-rerank model.
        if self.instruct and self.model == "qwen3-rerank":
            params["instruct"] = self.instruct

        if params:
            payload["parameters"] = params
        return payload

    def _parse_results(self, data: dict) -> list[RerankResult]:
        """Convert a successful API response into ``RerankResult`` objects.

        Args:
            data: Decoded JSON response body.

        Returns:
            Rerank results. Entries that cannot be parsed are logged and
            skipped. A missing ``index`` falls back to the entry's position,
            and a missing or null ``relevance_score`` falls back to 0.0.

        Raises:
            BailianAPIError: The body carries a non-empty ``code`` other
                than "200".
        """
        code = data.get("code")
        if code and str(code) != "200":
            raise BailianAPIError(
                f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
            )

        # `or {}` / `or []` also cover explicit nulls in the response.
        output = data.get("output") or {}
        results = output.get("results") or []
        if not results:
            logger.warning(f"百炼 Rerank 返回空结果: {data}")
            return []

        rerank_results = []
        for idx, result in enumerate(results):
            try:
                index = result.get("index", idx)
                relevance_score = result.get("relevance_score", 0.0)

                if relevance_score is None:
                    logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
                    relevance_score = 0.0

                rerank_results.append(
                    RerankResult(index=index, relevance_score=relevance_score)
                )
            except Exception as e:
                # Best-effort: one malformed entry must not drop the others.
                logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
                continue

        return rerank_results

    def _log_usage(self, data: dict) -> None:
        """Log token usage from the response at debug level, if reported."""
        # Guard against a null "usage" or a non-numeric total_tokens so that
        # usage logging can never turn an already-parsed response into an error.
        usage = data.get("usage") or {}
        tokens = usage.get("total_tokens") or 0
        if isinstance(tokens, (int, float)) and tokens > 0:
            logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")

    @staticmethod
    async def _describe_http_error(response: aiohttp.ClientResponse) -> str:
        """Build an error message for an HTTP error response.

        Uses the ``code``/``message`` fields of a JSON error body when
        present, otherwise falls back to the HTTP status and reason.
        """
        try:
            body = await response.json(content_type=None)
        except (aiohttp.ClientError, ValueError):
            body = None
        if isinstance(body, dict) and body.get("code"):
            return (
                f"百炼 API 错误 (HTTP {response.status}): "
                f"{body.get('code')} – {body.get('message', '')}"
            )
        return f"百炼 API 错误: HTTP {response.status} {response.reason or ''}".rstrip()

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            query: Query text.
            documents: Candidate documents. Lists longer than
                ``MAX_DOCUMENTS`` are truncated with a warning.
            top_n: Return only the top N results; ``None`` (or a
                non-positive value) returns every reranked result.

        Returns:
            Rerank results. The list is empty when ``documents`` is empty,
            ``query`` is blank, or the API returns no results.

        Raises:
            BailianAPIError: The API answered with an HTTP error status or
                an error code in the response body.
            BailianNetworkError: The request failed at the transport level
                or timed out.
            BailianRerankError: Any other failure while reranking.
        """
        import asyncio  # only needed for the timeout exception type below

        if not documents:
            logger.warning("文档列表为空,返回空结果")
            return []

        if not query.strip():
            logger.warning("查询文本为空,返回空结果")
            return []

        limit = self.MAX_DOCUMENTS
        if len(documents) > limit:
            logger.warning(
                f"文档数量({len(documents)})超过限制({limit}),将截断前{limit}个文档"
            )
            documents = documents[:limit]

        try:
            # top_n=None -> all reranked results are returned.
            payload = self._build_payload(query, documents, top_n)

            logger.debug(
                f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
            )

            client = self._get_client()
            async with client.post(self.base_url, json=payload) as response:
                if response.status >= 400:
                    # raise_for_status() would discard the error code/message
                    # in the response body, so read the body first.
                    raise BailianAPIError(await self._describe_http_error(response))
                response_data = await response.json()

            results = self._parse_results(response_data)
            self._log_usage(response_data)

            logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")
            return results

        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # aiohttp's total ClientTimeout may surface as a plain
            # asyncio.TimeoutError (not a ClientError) with an empty message.
            detail = str(e) or type(e).__name__
            error_msg = f"网络请求失败: {detail}"
            logger.error(f"百炼 Rerank 网络请求失败: {detail}")
            raise BailianNetworkError(error_msg) from e
        except BailianRerankError:
            raise
        except Exception as e:
            error_msg = f"重排序失败: {e}"
            logger.error(f"百炼 Rerank 处理失败: {e}")
            raise BailianRerankError(error_msg) from e

    async def terminate(self) -> None:
        """Close the HTTP client session, if one was created."""
        if self.client is not None:
            logger.info("关闭 百炼 Rerank 客户端会话")
            try:
                await self.client.close()
            except Exception as e:
                logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
            finally:
                self.client = None
@@ -290,13 +290,24 @@ class ProviderGoogleGenAI(Provider):
290
290
  parts = [types.Part.from_text(text=content)]
291
291
  append_or_extend(gemini_contents, parts, types.ModelContent)
292
292
  elif not native_tool_enabled and "tool_calls" in message:
293
- parts = [
294
- types.Part.from_function_call(
293
+ parts = []
294
+ for tool in message["tool_calls"]:
295
+ part = types.Part.from_function_call(
295
296
  name=tool["function"]["name"],
296
297
  args=json.loads(tool["function"]["arguments"]),
297
298
  )
298
- for tool in message["tool_calls"]
299
- ]
299
+ # we should set thought_signature back to part if exists
300
+ # for more info about thought_signature, see:
301
+ # https://ai.google.dev/gemini-api/docs/thought-signatures
302
+ if "extra_content" in tool and tool["extra_content"]:
303
+ ts_bs64 = (
304
+ tool["extra_content"]
305
+ .get("google", {})
306
+ .get("thought_signature")
307
+ )
308
+ if ts_bs64:
309
+ part.thought_signature = base64.b64decode(ts_bs64)
310
+ parts.append(part)
300
311
  append_or_extend(gemini_contents, parts, types.ModelContent)
301
312
  else:
302
313
  logger.warning("assistant 角色的消息内容为空,已添加空格占位")
@@ -393,10 +404,15 @@ class ProviderGoogleGenAI(Provider):
393
404
  llm_response.role = "tool"
394
405
  llm_response.tools_call_name.append(part.function_call.name)
395
406
  llm_response.tools_call_args.append(part.function_call.args)
396
- # gemini 返回的 function_call.id 可能为 None
397
- llm_response.tools_call_ids.append(
398
- part.function_call.id or part.function_call.name,
399
- )
407
+ # function_call.id might be None, use name as fallback
408
+ tool_call_id = part.function_call.id or part.function_call.name
409
+ llm_response.tools_call_ids.append(tool_call_id)
410
+ # extra_content
411
+ if part.thought_signature:
412
+ ts_bs64 = base64.b64encode(part.thought_signature).decode("utf-8")
413
+ llm_response.tools_call_extra_content[tool_call_id] = {
414
+ "google": {"thought_signature": ts_bs64}
415
+ }
400
416
  elif (
401
417
  part.inline_data
402
418
  and part.inline_data.mime_type
@@ -435,6 +451,7 @@ class ProviderGoogleGenAI(Provider):
435
451
  contents=conversation,
436
452
  config=config,
437
453
  )
454
+ logger.debug(f"genai result: {result}")
438
455
 
439
456
  if not result.candidates:
440
457
  logger.error(f"请求失败, 返回的 candidates 为空: {result}")
@@ -8,7 +8,7 @@ import re
8
8
  from collections.abc import AsyncGenerator
9
9
 
10
10
  from openai import AsyncAzureOpenAI, AsyncOpenAI
11
- from openai._exceptions import NotFoundError, UnprocessableEntityError
11
+ from openai._exceptions import NotFoundError
12
12
  from openai.lib.streaming.chat._completions import ChatCompletionStreamState
13
13
  from openai.types.chat.chat_completion import ChatCompletion
14
14
  from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -279,6 +279,7 @@ class ProviderOpenAIOfficial(Provider):
279
279
  args_ls = []
280
280
  func_name_ls = []
281
281
  tool_call_ids = []
282
+ tool_call_extra_content_dict = {}
282
283
  for tool_call in choice.message.tool_calls:
283
284
  if isinstance(tool_call, str):
284
285
  # workaround for #1359
@@ -296,11 +297,16 @@ class ProviderOpenAIOfficial(Provider):
296
297
  args_ls.append(args)
297
298
  func_name_ls.append(tool_call.function.name)
298
299
  tool_call_ids.append(tool_call.id)
300
+
301
+ # gemini-2.5 / gemini-3 series extra_content handling
302
+ extra_content = getattr(tool_call, "extra_content", None)
303
+ if extra_content is not None:
304
+ tool_call_extra_content_dict[tool_call.id] = extra_content
299
305
  llm_response.role = "tool"
300
306
  llm_response.tools_call_args = args_ls
301
307
  llm_response.tools_call_name = func_name_ls
302
308
  llm_response.tools_call_ids = tool_call_ids
303
-
309
+ llm_response.tools_call_extra_content = tool_call_extra_content_dict
304
310
  # specially handle finish reason
305
311
  if choice.finish_reason == "content_filter":
306
312
  raise Exception(
@@ -353,7 +359,7 @@ class ProviderOpenAIOfficial(Provider):
353
359
 
354
360
  payloads = {"messages": context_query, **model_config}
355
361
 
356
- # xAI 原生搜索参数(最小侵入地在此处注入)
362
+ # xAI origin search tool inject
357
363
  self._maybe_inject_xai_search(payloads, **kwargs)
358
364
 
359
365
  return payloads, context_query
@@ -475,12 +481,6 @@ class ProviderOpenAIOfficial(Provider):
475
481
  self.client.api_key = chosen_key
476
482
  llm_response = await self._query(payloads, func_tool)
477
483
  break
478
- except UnprocessableEntityError as e:
479
- logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
480
- # 尝试删除所有 image
481
- new_contexts = await self._remove_image_from_context(context_query)
482
- payloads["messages"] = new_contexts
483
- context_query = new_contexts
484
484
  except Exception as e:
485
485
  last_exception = e
486
486
  (
@@ -545,12 +545,6 @@ class ProviderOpenAIOfficial(Provider):
545
545
  async for response in self._query_stream(payloads, func_tool):
546
546
  yield response
547
547
  break
548
- except UnprocessableEntityError as e:
549
- logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
550
- # 尝试删除所有 image
551
- new_contexts = await self._remove_image_from_context(context_query)
552
- payloads["messages"] = new_contexts
553
- context_query = new_contexts
554
548
  except Exception as e:
555
549
  last_exception = e
556
550
  (
@@ -646,4 +640,3 @@ class ProviderOpenAIOfficial(Provider):
646
640
  with open(image_url, "rb") as f:
647
641
  image_bs64 = base64.b64encode(f.read()).decode("utf-8")
648
642
  return "data:image/jpeg;base64," + image_bs64
649
- return ""
@@ -10,7 +10,6 @@ from quart import g, make_response, request
10
10
  from astrbot.core import logger
11
11
  from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
12
12
  from astrbot.core.db import BaseDatabase
13
- from astrbot.core.platform.astr_message_event import MessageSession
14
13
  from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
15
14
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
16
15
 
@@ -36,11 +35,14 @@ class ChatRoute(Route):
36
35
  super().__init__(context)
37
36
  self.routes = {
38
37
  "/chat/send": ("POST", self.chat),
39
- "/chat/new_conversation": ("GET", self.new_conversation),
40
- "/chat/conversations": ("GET", self.get_conversations),
41
- "/chat/get_conversation": ("GET", self.get_conversation),
42
- "/chat/delete_conversation": ("GET", self.delete_conversation),
43
- "/chat/rename_conversation": ("POST", self.rename_conversation),
38
+ "/chat/new_session": ("GET", self.new_session),
39
+ "/chat/sessions": ("GET", self.get_sessions),
40
+ "/chat/get_session": ("GET", self.get_session),
41
+ "/chat/delete_session": ("GET", self.delete_webchat_session),
42
+ "/chat/update_session_display_name": (
43
+ "POST",
44
+ self.update_session_display_name,
45
+ ),
44
46
  "/chat/get_file": ("GET", self.get_file),
45
47
  "/chat/post_image": ("POST", self.post_image),
46
48
  "/chat/post_file": ("POST", self.post_file),
@@ -53,6 +55,7 @@ class ChatRoute(Route):
53
55
  self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
54
56
  self.conv_mgr = core_lifecycle.conversation_manager
55
57
  self.platform_history_mgr = core_lifecycle.platform_message_history_manager
58
+ self.db = db
56
59
 
57
60
  self.running_convs: dict[str, bool] = {}
58
61
 
@@ -116,11 +119,14 @@ class ChatRoute(Route):
116
119
  if "message" not in post_data and "image_url" not in post_data:
117
120
  return Response().error("Missing key: message or image_url").__dict__
118
121
 
119
- if "conversation_id" not in post_data:
120
- return Response().error("Missing key: conversation_id").__dict__
122
+ if "session_id" not in post_data and "conversation_id" not in post_data:
123
+ return (
124
+ Response().error("Missing key: session_id or conversation_id").__dict__
125
+ )
121
126
 
122
127
  message = post_data["message"]
123
- conversation_id = post_data["conversation_id"]
128
+ # conversation_id = post_data["conversation_id"]
129
+ session_id = post_data.get("session_id", post_data.get("conversation_id"))
124
130
  image_url = post_data.get("image_url")
125
131
  audio_url = post_data.get("audio_url")
126
132
  selected_provider = post_data.get("selected_provider")
@@ -133,11 +139,11 @@ class ChatRoute(Route):
133
139
  .error("Message and image_url and audio_url are empty")
134
140
  .__dict__
135
141
  )
136
- if not conversation_id:
137
- return Response().error("conversation_id is empty").__dict__
142
+ if not session_id:
143
+ return Response().error("session_id is empty").__dict__
138
144
 
139
145
  # 追加用户消息
140
- webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
146
+ webchat_conv_id = session_id
141
147
 
142
148
  # 获取会话特定的队列
143
149
  back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
@@ -245,88 +251,110 @@ class ChatRoute(Route):
245
251
  response.timeout = None # fix SSE auto disconnect issue
246
252
  return response
247
253
 
248
- async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
249
- """从对话 ID 中提取 WebChat 会话 ID
250
-
251
- NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。
252
- """
253
- conversation = await self.conv_mgr.get_conversation(
254
- unified_msg_origin="webchat",
255
- conversation_id=conversation_id,
256
- )
257
- if not conversation:
258
- raise ValueError(f"Conversation with ID {conversation_id} not found.")
259
- conv_user_id = conversation.user_id
260
- webchat_session_id = MessageSession.from_str(conv_user_id).session_id
261
- if "!" not in webchat_session_id:
262
- raise ValueError(f"Invalid conv user ID: {conv_user_id}")
263
- return webchat_session_id.split("!")[-1]
264
-
265
- async def delete_conversation(self):
266
- conversation_id = request.args.get("conversation_id")
267
- if not conversation_id:
268
- return Response().error("Missing key: conversation_id").__dict__
254
+ async def delete_webchat_session(self):
255
+ """Delete a Platform session and all its related data."""
256
+ session_id = request.args.get("session_id")
257
+ if not session_id:
258
+ return Response().error("Missing key: session_id").__dict__
269
259
  username = g.get("username", "guest")
270
260
 
271
- # Clean up queues when deleting conversation
272
- webchat_queue_mgr.remove_queues(conversation_id)
273
- webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
274
- await self.conv_mgr.delete_conversation(
275
- unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
276
- conversation_id=conversation_id,
277
- )
261
+ # 验证会话是否存在且属于当前用户
262
+ session = await self.db.get_platform_session_by_id(session_id)
263
+ if not session:
264
+ return Response().error(f"Session {session_id} not found").__dict__
265
+ if session.creator != username:
266
+ return Response().error("Permission denied").__dict__
267
+
268
+ # 删除该会话下的所有对话
269
+ unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}"
270
+ await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
271
+
272
+ # 删除消息历史
278
273
  await self.platform_history_mgr.delete(
279
- platform_id="webchat",
280
- user_id=webchat_conv_id,
274
+ platform_id=session.platform_id,
275
+ user_id=session_id,
281
276
  offset_sec=99999999,
282
277
  )
278
+
279
+ # 清理队列(仅对 webchat)
280
+ if session.platform_id == "webchat":
281
+ webchat_queue_mgr.remove_queues(session_id)
282
+
283
+ # 删除会话
284
+ await self.db.delete_platform_session(session_id)
285
+
283
286
  return Response().ok().__dict__
284
287
 
285
- async def new_conversation(self):
288
+ async def new_session(self):
289
+ """Create a new Platform session (default: webchat)."""
286
290
  username = g.get("username", "guest")
287
- webchat_conv_id = str(uuid.uuid4())
288
- conv_id = await self.conv_mgr.new_conversation(
289
- unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
290
- platform_id="webchat",
291
- content=[],
291
+
292
+ # 获取可选的 platform_id 参数,默认为 webchat
293
+ platform_id = request.args.get("platform_id", "webchat")
294
+
295
+ # 创建新会话
296
+ session = await self.db.create_platform_session(
297
+ creator=username,
298
+ platform_id=platform_id,
299
+ is_group=0,
292
300
  )
293
- return Response().ok(data={"conversation_id": conv_id}).__dict__
294
301
 
295
- async def rename_conversation(self):
296
- post_data = await request.json
297
- if "conversation_id" not in post_data or "title" not in post_data:
298
- return Response().error("Missing key: conversation_id or title").__dict__
302
+ return (
303
+ Response()
304
+ .ok(
305
+ data={
306
+ "session_id": session.session_id,
307
+ "platform_id": session.platform_id,
308
+ }
309
+ )
310
+ .__dict__
311
+ )
299
312
 
300
- conversation_id = post_data["conversation_id"]
301
- title = post_data["title"]
313
+ async def get_sessions(self):
314
+ """Get all Platform sessions for the current user."""
315
+ username = g.get("username", "guest")
316
+
317
+ # 获取可选的 platform_id 参数
318
+ platform_id = request.args.get("platform_id")
302
319
 
303
- await self.conv_mgr.update_conversation(
304
- unified_msg_origin="webchat", # fake
305
- conversation_id=conversation_id,
306
- title=title,
320
+ sessions = await self.db.get_platform_sessions_by_creator(
321
+ creator=username,
322
+ platform_id=platform_id,
323
+ page=1,
324
+ page_size=100, # 暂时返回前100个
307
325
  )
308
- return Response().ok(message="重命名成功!").__dict__
309
326
 
310
- async def get_conversations(self):
311
- conversations = await self.conv_mgr.get_conversations(platform_id="webchat")
312
- # remove content
313
- conversations_ = []
314
- for conv in conversations:
315
- conv.history = None
316
- conversations_.append(conv)
317
- return Response().ok(data=conversations_).__dict__
327
+ # 转换为字典格式,并添加额外信息
328
+ sessions_data = []
329
+ for session in sessions:
330
+ sessions_data.append(
331
+ {
332
+ "session_id": session.session_id,
333
+ "platform_id": session.platform_id,
334
+ "creator": session.creator,
335
+ "display_name": session.display_name,
336
+ "is_group": session.is_group,
337
+ "created_at": session.created_at.astimezone().isoformat(),
338
+ "updated_at": session.updated_at.astimezone().isoformat(),
339
+ }
340
+ )
341
+
342
+ return Response().ok(data=sessions_data).__dict__
318
343
 
319
- async def get_conversation(self):
320
- conversation_id = request.args.get("conversation_id")
321
- if not conversation_id:
322
- return Response().error("Missing key: conversation_id").__dict__
344
+ async def get_session(self):
345
+ """Get session information and message history by session_id."""
346
+ session_id = request.args.get("session_id")
347
+ if not session_id:
348
+ return Response().error("Missing key: session_id").__dict__
323
349
 
324
- webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
350
+ # 获取会话信息以确定 platform_id
351
+ session = await self.db.get_platform_session_by_id(session_id)
352
+ platform_id = session.platform_id if session else "webchat"
325
353
 
326
- # Get platform message history
354
+ # Get platform message history using session_id
327
355
  history_ls = await self.platform_history_mgr.get(
328
- platform_id="webchat",
329
- user_id=webchat_conv_id,
356
+ platform_id=platform_id,
357
+ user_id=session_id,
330
358
  page=1,
331
359
  page_size=1000,
332
360
  )
@@ -338,8 +366,37 @@ class ChatRoute(Route):
338
366
  .ok(
339
367
  data={
340
368
  "history": history_res,
341
- "is_running": self.running_convs.get(webchat_conv_id, False),
369
+ "is_running": self.running_convs.get(session_id, False),
342
370
  },
343
371
  )
344
372
  .__dict__
345
373
  )
374
+
375
async def update_session_display_name(self):
    """Update the display name of a platform session owned by the current user.

    Expects a JSON body with ``session_id`` and ``display_name``. Returns an
    error response when a key is missing, ``display_name`` is not a string,
    the session does not exist, or it belongs to a different user.
    """
    post_data = await request.json
    # The parsed body can be None (e.g. a non-JSON request) or a non-object;
    # treat that as "no keys" instead of crashing on .get() with a 500.
    if not isinstance(post_data, dict):
        post_data = {}

    session_id = post_data.get("session_id")
    display_name = post_data.get("display_name")

    if not session_id:
        return Response().error("Missing key: session_id").__dict__
    if display_name is None:
        return Response().error("Missing key: display_name").__dict__
    if not isinstance(display_name, str):
        return Response().error("display_name must be a string").__dict__

    username = g.get("username", "guest")

    # Verify the session exists and belongs to the current user.
    session = await self.db.get_platform_session_by_id(session_id)
    if not session:
        return Response().error(f"Session {session_id} not found").__dict__
    if session.creator != username:
        return Response().error("Permission denied").__dict__

    await self.db.update_platform_session(
        session_id=session_id,
        display_name=display_name,
    )

    return Response().ok().__dict__