AstrBot 4.5.8__py3-none-any.whl → 4.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +152 -26
- astrbot/core/agent/message.py +7 -0
- astrbot/core/config/default.py +31 -1
- astrbot/core/core_lifecycle.py +8 -0
- astrbot/core/db/__init__.py +50 -1
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/po.py +49 -13
- astrbot/core/db/sqlite.py +102 -3
- astrbot/core/knowledge_base/kb_helper.py +314 -33
- astrbot/core/knowledge_base/kb_mgr.py +45 -1
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +28 -14
- astrbot/core/pipeline/process_stage/utils.py +60 -16
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +13 -10
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -4
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +0 -4
- astrbot/core/provider/entities.py +22 -9
- astrbot/core/provider/func_tool_manager.py +12 -9
- astrbot/core/provider/manager.py +4 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/gemini_source.py +25 -8
- astrbot/core/provider/sources/openai_source.py +9 -16
- astrbot/dashboard/routes/chat.py +134 -77
- astrbot/dashboard/routes/knowledge_base.py +172 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/METADATA +5 -4
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/RECORD +30 -26
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/WHEEL +0 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.8.dist-info → astrbot-4.6.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import aiohttp
|
|
4
|
+
|
|
5
|
+
from astrbot import logger
|
|
6
|
+
|
|
7
|
+
from ..entities import ProviderType, RerankResult
|
|
8
|
+
from ..provider import RerankProvider
|
|
9
|
+
from ..register import register_provider_adapter
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BailianRerankError(Exception):
    """Root of the Bailian rerank exception hierarchy; catch this to handle any rerank failure."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BailianAPIError(BailianRerankError):
    """Signals that the Bailian API itself reported an error (error code or HTTP error status)."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BailianNetworkError(BailianRerankError):
    """Signals a transport-level failure while talking to the Bailian API."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@register_provider_adapter(
    "bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
)
class BailianRerankProvider(RerankProvider):
    """Alibaba Cloud Bailian (DashScope) text rerank adapter.

    Posts a query plus candidate documents to the DashScope text-rerank
    endpoint and converts the response into ``RerankResult`` objects.
    """

    # Maximum number of documents sent per request; extra documents are dropped.
    MAX_DOCUMENTS = 500

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Read configuration and open the HTTP session.

        Args:
            provider_config: Provider config; reads ``rerank_api_key``,
                ``rerank_model``, ``timeout``, ``return_documents``,
                ``instruct`` and ``rerank_api_base``.
            provider_settings: Global provider settings, passed to the base class.

        Raises:
            ValueError: No API key in the config or in ``DASHSCOPE_API_KEY``.
        """
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

        # API key: an explicit config value wins, else fall back to the environment.
        self.api_key = provider_config.get("rerank_api_key") or os.getenv(
            "DASHSCOPE_API_KEY", ""
        )
        if not self.api_key:
            raise ValueError("阿里云百炼 API Key 不能为空。")

        self.model = provider_config.get("rerank_model", "qwen3-rerank")
        self.timeout = provider_config.get("timeout", 30)
        self.return_documents = provider_config.get("return_documents", False)
        self.instruct = provider_config.get("instruct", "")

        self.base_url = provider_config.get(
            "rerank_api_base",
            "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
        )

        # Shared HTTP session carrying auth headers and the total request timeout.
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self.client = aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        )

        self.set_model(self.model)

        logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")

    def _build_payload(
        self, query: str, documents: list[str], top_n: int | None
    ) -> dict:
        """Build the JSON request body.

        Args:
            query: Query text.
            documents: Documents to rank.
            top_n: Number of top results to request; ``None`` or a non-positive
                value omits the parameter so all results are returned.

        Returns:
            The request payload; ``parameters`` is only present when at least
            one optional parameter applies.
        """
        payload: dict = {
            "model": self.model,
            "input": {"query": query, "documents": documents},
        }

        params: dict = {}
        if top_n is not None and top_n > 0:
            params["top_n"] = top_n
        if self.return_documents:
            params["return_documents"] = True
        # `instruct` is only forwarded when the model is qwen3-rerank.
        if self.instruct and self.model == "qwen3-rerank":
            params["instruct"] = self.instruct

        if params:
            payload["parameters"] = params

        return payload

    def _parse_results(self, data: dict) -> list[RerankResult]:
        """Convert an API response body into rerank results.

        Args:
            data: Decoded JSON response body.

        Returns:
            One ``RerankResult`` per well-formed entry in ``output.results``.
            Entries that fail to convert are logged and skipped; an empty list
            is returned when there are no results.

        Raises:
            BailianAPIError: The body carries a non-empty error ``code`` other
                than "200".
        """
        code = data.get("code")
        # A missing or empty `code` is treated as success (NOTE(review): assumed
        # for success bodies — confirm against the DashScope API reference).
        if code not in (None, "", "200", 200):
            raise BailianAPIError(
                f"百炼 API 错误: {code} – {data.get('message', '')}"
            )

        # `or {}` / `or []` guard against explicit JSON nulls, not just missing keys.
        output = data.get("output") or {}
        results = output.get("results") or []
        if not results:
            logger.warning(f"百炼 Rerank 返回空结果: {data}")
            return []

        rerank_results = []
        for idx, result in enumerate(results):
            try:
                # Fall back to the position in the list if `index` is absent.
                index = result.get("index", idx)
                relevance_score = result.get("relevance_score", 0.0)

                if relevance_score is None:
                    logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
                    relevance_score = 0.0

                rerank_results.append(
                    RerankResult(index=index, relevance_score=relevance_score)
                )
            except Exception as e:
                # Best effort: one malformed entry should not discard the rest.
                logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
                continue

        return rerank_results

    def _log_usage(self, data: dict) -> None:
        """Log token usage at debug level, if the response reports any.

        Never raises: a malformed or null ``usage`` section is ignored so that
        logging cannot discard already-parsed results.

        Args:
            data: Decoded JSON response body.
        """
        usage = data.get("usage")
        tokens = usage.get("total_tokens") if isinstance(usage, dict) else None
        if isinstance(tokens, (int, float)) and tokens > 0:
            logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            query: Query text.
            documents: Candidate documents; only the first ``MAX_DOCUMENTS``
                are sent.
            top_n: Number of top results to return. ``None`` or a non-positive
                value omits the parameter, so results for all submitted
                documents are returned.

        Returns:
            The parsed rerank results; an empty list when ``documents`` is
            empty or ``query`` is blank.

        Raises:
            BailianNetworkError: Transport-level HTTP failure.
            BailianAPIError: HTTP error status or an error code in the body.
            BailianRerankError: Any other failure, including use after
                ``terminate()``.
        """
        if not documents:
            logger.warning("文档列表为空,返回空结果")
            return []

        if not query.strip():
            logger.warning("查询文本为空,返回空结果")
            return []

        if len(documents) > self.MAX_DOCUMENTS:
            logger.warning(
                f"文档数量({len(documents)})超过限制({self.MAX_DOCUMENTS}),"
                f"将截断前{self.MAX_DOCUMENTS}个文档"
            )
            documents = documents[: self.MAX_DOCUMENTS]

        # terminate() closes the session and sets it to None; fail clearly
        # instead of with an AttributeError on None.
        if self.client is None or self.client.closed:
            raise BailianRerankError("百炼 Rerank 客户端已关闭,无法发送请求")

        try:
            payload = self._build_payload(query, documents, top_n)

            logger.debug(
                f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
            )

            async with self.client.post(self.base_url, json=payload) as response:
                if response.status >= 400:
                    # Read the body before failing: it holds the API's error
                    # code and message, which raise_for_status() would discard.
                    error_body = await response.text()
                    raise BailianAPIError(
                        f"百炼 API 请求失败 (HTTP {response.status}): {error_body[:500]}"
                    )
                response_data = await response.json()

            results = self._parse_results(response_data)
            self._log_usage(response_data)

            logger.debug(f"百炼 Rerank 成功返回 {len(results)} 个结果")

            return results

        except aiohttp.ClientError as e:
            logger.error(f"百炼 Rerank 网络请求失败: {e}")
            raise BailianNetworkError(f"网络请求失败: {e}") from e
        except BailianRerankError as e:
            # Already a domain error; log it and propagate unchanged.
            logger.error(f"百炼 Rerank 处理失败: {e}")
            raise
        except Exception as e:
            logger.error(f"百炼 Rerank 处理失败: {e}")
            raise BailianRerankError(f"重排序失败: {e}") from e

    async def terminate(self) -> None:
        """Close the HTTP session; errors while closing are logged, not raised."""
        if self.client:
            logger.info("关闭 百炼 Rerank 客户端会话")
            try:
                await self.client.close()
            except Exception as e:
                logger.error(f"关闭 百炼 Rerank 客户端时出错: {e}")
            finally:
                self.client = None
|
|
@@ -290,13 +290,24 @@ class ProviderGoogleGenAI(Provider):
|
|
|
290
290
|
parts = [types.Part.from_text(text=content)]
|
|
291
291
|
append_or_extend(gemini_contents, parts, types.ModelContent)
|
|
292
292
|
elif not native_tool_enabled and "tool_calls" in message:
|
|
293
|
-
parts = [
|
|
294
|
-
|
|
293
|
+
parts = []
|
|
294
|
+
for tool in message["tool_calls"]:
|
|
295
|
+
part = types.Part.from_function_call(
|
|
295
296
|
name=tool["function"]["name"],
|
|
296
297
|
args=json.loads(tool["function"]["arguments"]),
|
|
297
298
|
)
|
|
298
|
-
|
|
299
|
-
|
|
299
|
+
# we should set thought_signature back to part if exists
|
|
300
|
+
# for more info about thought_signature, see:
|
|
301
|
+
# https://ai.google.dev/gemini-api/docs/thought-signatures
|
|
302
|
+
if "extra_content" in tool and tool["extra_content"]:
|
|
303
|
+
ts_bs64 = (
|
|
304
|
+
tool["extra_content"]
|
|
305
|
+
.get("google", {})
|
|
306
|
+
.get("thought_signature")
|
|
307
|
+
)
|
|
308
|
+
if ts_bs64:
|
|
309
|
+
part.thought_signature = base64.b64decode(ts_bs64)
|
|
310
|
+
parts.append(part)
|
|
300
311
|
append_or_extend(gemini_contents, parts, types.ModelContent)
|
|
301
312
|
else:
|
|
302
313
|
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
|
|
@@ -393,10 +404,15 @@ class ProviderGoogleGenAI(Provider):
|
|
|
393
404
|
llm_response.role = "tool"
|
|
394
405
|
llm_response.tools_call_name.append(part.function_call.name)
|
|
395
406
|
llm_response.tools_call_args.append(part.function_call.args)
|
|
396
|
-
#
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
407
|
+
# function_call.id might be None, use name as fallback
|
|
408
|
+
tool_call_id = part.function_call.id or part.function_call.name
|
|
409
|
+
llm_response.tools_call_ids.append(tool_call_id)
|
|
410
|
+
# extra_content
|
|
411
|
+
if part.thought_signature:
|
|
412
|
+
ts_bs64 = base64.b64encode(part.thought_signature).decode("utf-8")
|
|
413
|
+
llm_response.tools_call_extra_content[tool_call_id] = {
|
|
414
|
+
"google": {"thought_signature": ts_bs64}
|
|
415
|
+
}
|
|
400
416
|
elif (
|
|
401
417
|
part.inline_data
|
|
402
418
|
and part.inline_data.mime_type
|
|
@@ -435,6 +451,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
435
451
|
contents=conversation,
|
|
436
452
|
config=config,
|
|
437
453
|
)
|
|
454
|
+
logger.debug(f"genai result: {result}")
|
|
438
455
|
|
|
439
456
|
if not result.candidates:
|
|
440
457
|
logger.error(f"请求失败, 返回的 candidates 为空: {result}")
|
|
@@ -8,7 +8,7 @@ import re
|
|
|
8
8
|
from collections.abc import AsyncGenerator
|
|
9
9
|
|
|
10
10
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
|
11
|
-
from openai._exceptions import NotFoundError
|
|
11
|
+
from openai._exceptions import NotFoundError
|
|
12
12
|
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
|
13
13
|
from openai.types.chat.chat_completion import ChatCompletion
|
|
14
14
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
@@ -279,6 +279,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
279
279
|
args_ls = []
|
|
280
280
|
func_name_ls = []
|
|
281
281
|
tool_call_ids = []
|
|
282
|
+
tool_call_extra_content_dict = {}
|
|
282
283
|
for tool_call in choice.message.tool_calls:
|
|
283
284
|
if isinstance(tool_call, str):
|
|
284
285
|
# workaround for #1359
|
|
@@ -296,11 +297,16 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
296
297
|
args_ls.append(args)
|
|
297
298
|
func_name_ls.append(tool_call.function.name)
|
|
298
299
|
tool_call_ids.append(tool_call.id)
|
|
300
|
+
|
|
301
|
+
# gemini-2.5 / gemini-3 series extra_content handling
|
|
302
|
+
extra_content = getattr(tool_call, "extra_content", None)
|
|
303
|
+
if extra_content is not None:
|
|
304
|
+
tool_call_extra_content_dict[tool_call.id] = extra_content
|
|
299
305
|
llm_response.role = "tool"
|
|
300
306
|
llm_response.tools_call_args = args_ls
|
|
301
307
|
llm_response.tools_call_name = func_name_ls
|
|
302
308
|
llm_response.tools_call_ids = tool_call_ids
|
|
303
|
-
|
|
309
|
+
llm_response.tools_call_extra_content = tool_call_extra_content_dict
|
|
304
310
|
# specially handle finish reason
|
|
305
311
|
if choice.finish_reason == "content_filter":
|
|
306
312
|
raise Exception(
|
|
@@ -353,7 +359,7 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
353
359
|
|
|
354
360
|
payloads = {"messages": context_query, **model_config}
|
|
355
361
|
|
|
356
|
-
# xAI
|
|
362
|
+
# xAI origin search tool inject
|
|
357
363
|
self._maybe_inject_xai_search(payloads, **kwargs)
|
|
358
364
|
|
|
359
365
|
return payloads, context_query
|
|
@@ -475,12 +481,6 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
475
481
|
self.client.api_key = chosen_key
|
|
476
482
|
llm_response = await self._query(payloads, func_tool)
|
|
477
483
|
break
|
|
478
|
-
except UnprocessableEntityError as e:
|
|
479
|
-
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
|
480
|
-
# 尝试删除所有 image
|
|
481
|
-
new_contexts = await self._remove_image_from_context(context_query)
|
|
482
|
-
payloads["messages"] = new_contexts
|
|
483
|
-
context_query = new_contexts
|
|
484
484
|
except Exception as e:
|
|
485
485
|
last_exception = e
|
|
486
486
|
(
|
|
@@ -545,12 +545,6 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
545
545
|
async for response in self._query_stream(payloads, func_tool):
|
|
546
546
|
yield response
|
|
547
547
|
break
|
|
548
|
-
except UnprocessableEntityError as e:
|
|
549
|
-
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
|
550
|
-
# 尝试删除所有 image
|
|
551
|
-
new_contexts = await self._remove_image_from_context(context_query)
|
|
552
|
-
payloads["messages"] = new_contexts
|
|
553
|
-
context_query = new_contexts
|
|
554
548
|
except Exception as e:
|
|
555
549
|
last_exception = e
|
|
556
550
|
(
|
|
@@ -646,4 +640,3 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
646
640
|
with open(image_url, "rb") as f:
|
|
647
641
|
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
|
648
642
|
return "data:image/jpeg;base64," + image_bs64
|
|
649
|
-
return ""
|
astrbot/dashboard/routes/chat.py
CHANGED
|
@@ -10,7 +10,6 @@ from quart import g, make_response, request
|
|
|
10
10
|
from astrbot.core import logger
|
|
11
11
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|
12
12
|
from astrbot.core.db import BaseDatabase
|
|
13
|
-
from astrbot.core.platform.astr_message_event import MessageSession
|
|
14
13
|
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
|
15
14
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
16
15
|
|
|
@@ -36,11 +35,14 @@ class ChatRoute(Route):
|
|
|
36
35
|
super().__init__(context)
|
|
37
36
|
self.routes = {
|
|
38
37
|
"/chat/send": ("POST", self.chat),
|
|
39
|
-
"/chat/
|
|
40
|
-
"/chat/
|
|
41
|
-
"/chat/
|
|
42
|
-
"/chat/
|
|
43
|
-
"/chat/
|
|
38
|
+
"/chat/new_session": ("GET", self.new_session),
|
|
39
|
+
"/chat/sessions": ("GET", self.get_sessions),
|
|
40
|
+
"/chat/get_session": ("GET", self.get_session),
|
|
41
|
+
"/chat/delete_session": ("GET", self.delete_webchat_session),
|
|
42
|
+
"/chat/update_session_display_name": (
|
|
43
|
+
"POST",
|
|
44
|
+
self.update_session_display_name,
|
|
45
|
+
),
|
|
44
46
|
"/chat/get_file": ("GET", self.get_file),
|
|
45
47
|
"/chat/post_image": ("POST", self.post_image),
|
|
46
48
|
"/chat/post_file": ("POST", self.post_file),
|
|
@@ -53,6 +55,7 @@ class ChatRoute(Route):
|
|
|
53
55
|
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
|
54
56
|
self.conv_mgr = core_lifecycle.conversation_manager
|
|
55
57
|
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
|
58
|
+
self.db = db
|
|
56
59
|
|
|
57
60
|
self.running_convs: dict[str, bool] = {}
|
|
58
61
|
|
|
@@ -116,11 +119,14 @@ class ChatRoute(Route):
|
|
|
116
119
|
if "message" not in post_data and "image_url" not in post_data:
|
|
117
120
|
return Response().error("Missing key: message or image_url").__dict__
|
|
118
121
|
|
|
119
|
-
if "conversation_id" not in post_data:
|
|
120
|
-
return
|
|
122
|
+
if "session_id" not in post_data and "conversation_id" not in post_data:
|
|
123
|
+
return (
|
|
124
|
+
Response().error("Missing key: session_id or conversation_id").__dict__
|
|
125
|
+
)
|
|
121
126
|
|
|
122
127
|
message = post_data["message"]
|
|
123
|
-
conversation_id = post_data["conversation_id"]
|
|
128
|
+
# conversation_id = post_data["conversation_id"]
|
|
129
|
+
session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
|
124
130
|
image_url = post_data.get("image_url")
|
|
125
131
|
audio_url = post_data.get("audio_url")
|
|
126
132
|
selected_provider = post_data.get("selected_provider")
|
|
@@ -133,11 +139,11 @@ class ChatRoute(Route):
|
|
|
133
139
|
.error("Message and image_url and audio_url are empty")
|
|
134
140
|
.__dict__
|
|
135
141
|
)
|
|
136
|
-
if not
|
|
137
|
-
return Response().error("
|
|
142
|
+
if not session_id:
|
|
143
|
+
return Response().error("session_id is empty").__dict__
|
|
138
144
|
|
|
139
145
|
# 追加用户消息
|
|
140
|
-
webchat_conv_id =
|
|
146
|
+
webchat_conv_id = session_id
|
|
141
147
|
|
|
142
148
|
# 获取会话特定的队列
|
|
143
149
|
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
|
@@ -245,88 +251,110 @@ class ChatRoute(Route):
|
|
|
245
251
|
response.timeout = None # fix SSE auto disconnect issue
|
|
246
252
|
return response
|
|
247
253
|
|
|
248
|
-
async def
|
|
249
|
-
"""
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
conversation = await self.conv_mgr.get_conversation(
|
|
254
|
-
unified_msg_origin="webchat",
|
|
255
|
-
conversation_id=conversation_id,
|
|
256
|
-
)
|
|
257
|
-
if not conversation:
|
|
258
|
-
raise ValueError(f"Conversation with ID {conversation_id} not found.")
|
|
259
|
-
conv_user_id = conversation.user_id
|
|
260
|
-
webchat_session_id = MessageSession.from_str(conv_user_id).session_id
|
|
261
|
-
if "!" not in webchat_session_id:
|
|
262
|
-
raise ValueError(f"Invalid conv user ID: {conv_user_id}")
|
|
263
|
-
return webchat_session_id.split("!")[-1]
|
|
264
|
-
|
|
265
|
-
async def delete_conversation(self):
|
|
266
|
-
conversation_id = request.args.get("conversation_id")
|
|
267
|
-
if not conversation_id:
|
|
268
|
-
return Response().error("Missing key: conversation_id").__dict__
|
|
254
|
+
async def delete_webchat_session(self):
|
|
255
|
+
"""Delete a Platform session and all its related data."""
|
|
256
|
+
session_id = request.args.get("session_id")
|
|
257
|
+
if not session_id:
|
|
258
|
+
return Response().error("Missing key: session_id").__dict__
|
|
269
259
|
username = g.get("username", "guest")
|
|
270
260
|
|
|
271
|
-
#
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
261
|
+
# 验证会话是否存在且属于当前用户
|
|
262
|
+
session = await self.db.get_platform_session_by_id(session_id)
|
|
263
|
+
if not session:
|
|
264
|
+
return Response().error(f"Session {session_id} not found").__dict__
|
|
265
|
+
if session.creator != username:
|
|
266
|
+
return Response().error("Permission denied").__dict__
|
|
267
|
+
|
|
268
|
+
# 删除该会话下的所有对话
|
|
269
|
+
unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}"
|
|
270
|
+
await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin)
|
|
271
|
+
|
|
272
|
+
# 删除消息历史
|
|
278
273
|
await self.platform_history_mgr.delete(
|
|
279
|
-
platform_id=
|
|
280
|
-
user_id=
|
|
274
|
+
platform_id=session.platform_id,
|
|
275
|
+
user_id=session_id,
|
|
281
276
|
offset_sec=99999999,
|
|
282
277
|
)
|
|
278
|
+
|
|
279
|
+
# 清理队列(仅对 webchat)
|
|
280
|
+
if session.platform_id == "webchat":
|
|
281
|
+
webchat_queue_mgr.remove_queues(session_id)
|
|
282
|
+
|
|
283
|
+
# 删除会话
|
|
284
|
+
await self.db.delete_platform_session(session_id)
|
|
285
|
+
|
|
283
286
|
return Response().ok().__dict__
|
|
284
287
|
|
|
285
|
-
async def
|
|
288
|
+
async def new_session(self):
|
|
289
|
+
"""Create a new Platform session (default: webchat)."""
|
|
286
290
|
username = g.get("username", "guest")
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
291
|
+
|
|
292
|
+
# 获取可选的 platform_id 参数,默认为 webchat
|
|
293
|
+
platform_id = request.args.get("platform_id", "webchat")
|
|
294
|
+
|
|
295
|
+
# 创建新会话
|
|
296
|
+
session = await self.db.create_platform_session(
|
|
297
|
+
creator=username,
|
|
298
|
+
platform_id=platform_id,
|
|
299
|
+
is_group=0,
|
|
292
300
|
)
|
|
293
|
-
return Response().ok(data={"conversation_id": conv_id}).__dict__
|
|
294
301
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
302
|
+
return (
|
|
303
|
+
Response()
|
|
304
|
+
.ok(
|
|
305
|
+
data={
|
|
306
|
+
"session_id": session.session_id,
|
|
307
|
+
"platform_id": session.platform_id,
|
|
308
|
+
}
|
|
309
|
+
)
|
|
310
|
+
.__dict__
|
|
311
|
+
)
|
|
299
312
|
|
|
300
|
-
|
|
301
|
-
|
|
313
|
+
async def get_sessions(self):
|
|
314
|
+
"""Get all Platform sessions for the current user."""
|
|
315
|
+
username = g.get("username", "guest")
|
|
316
|
+
|
|
317
|
+
# 获取可选的 platform_id 参数
|
|
318
|
+
platform_id = request.args.get("platform_id")
|
|
302
319
|
|
|
303
|
-
await self.
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
320
|
+
sessions = await self.db.get_platform_sessions_by_creator(
|
|
321
|
+
creator=username,
|
|
322
|
+
platform_id=platform_id,
|
|
323
|
+
page=1,
|
|
324
|
+
page_size=100, # 暂时返回前100个
|
|
307
325
|
)
|
|
308
|
-
return Response().ok(message="重命名成功!").__dict__
|
|
309
326
|
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
327
|
+
# 转换为字典格式,并添加额外信息
|
|
328
|
+
sessions_data = []
|
|
329
|
+
for session in sessions:
|
|
330
|
+
sessions_data.append(
|
|
331
|
+
{
|
|
332
|
+
"session_id": session.session_id,
|
|
333
|
+
"platform_id": session.platform_id,
|
|
334
|
+
"creator": session.creator,
|
|
335
|
+
"display_name": session.display_name,
|
|
336
|
+
"is_group": session.is_group,
|
|
337
|
+
"created_at": session.created_at.astimezone().isoformat(),
|
|
338
|
+
"updated_at": session.updated_at.astimezone().isoformat(),
|
|
339
|
+
}
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
return Response().ok(data=sessions_data).__dict__
|
|
318
343
|
|
|
319
|
-
async def
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
344
|
+
async def get_session(self):
|
|
345
|
+
"""Get session information and message history by session_id."""
|
|
346
|
+
session_id = request.args.get("session_id")
|
|
347
|
+
if not session_id:
|
|
348
|
+
return Response().error("Missing key: session_id").__dict__
|
|
323
349
|
|
|
324
|
-
|
|
350
|
+
# 获取会话信息以确定 platform_id
|
|
351
|
+
session = await self.db.get_platform_session_by_id(session_id)
|
|
352
|
+
platform_id = session.platform_id if session else "webchat"
|
|
325
353
|
|
|
326
|
-
# Get platform message history
|
|
354
|
+
# Get platform message history using session_id
|
|
327
355
|
history_ls = await self.platform_history_mgr.get(
|
|
328
|
-
platform_id=
|
|
329
|
-
user_id=
|
|
356
|
+
platform_id=platform_id,
|
|
357
|
+
user_id=session_id,
|
|
330
358
|
page=1,
|
|
331
359
|
page_size=1000,
|
|
332
360
|
)
|
|
@@ -338,8 +366,37 @@ class ChatRoute(Route):
|
|
|
338
366
|
.ok(
|
|
339
367
|
data={
|
|
340
368
|
"history": history_res,
|
|
341
|
-
"is_running": self.running_convs.get(
|
|
369
|
+
"is_running": self.running_convs.get(session_id, False),
|
|
342
370
|
},
|
|
343
371
|
)
|
|
344
372
|
.__dict__
|
|
345
373
|
)
|
|
374
|
+
|
|
375
|
+
async def update_session_display_name(self):
    """Rename a platform session owned by the current user.

    Expects a JSON object body with ``session_id`` and ``display_name``.
    Returns an error response when the body is not a JSON object, a key is
    missing, ``display_name`` is not a string, the session does not exist,
    or the session was created by a different user.
    """
    post_data = await request.json
    # A missing/non-JSON body, or a JSON array/scalar, would otherwise crash
    # on .get() with an AttributeError.
    if not isinstance(post_data, dict):
        return (
            Response().error("Invalid request body: expected a JSON object").__dict__
        )

    session_id = post_data.get("session_id")
    display_name = post_data.get("display_name")

    if not session_id:
        return Response().error("Missing key: session_id").__dict__
    if display_name is None:
        return Response().error("Missing key: display_name").__dict__
    if not isinstance(display_name, str):
        return Response().error("display_name must be a string").__dict__

    username = g.get("username", "guest")

    # Verify the session exists and belongs to the current user.
    session = await self.db.get_platform_session_by_id(session_id)
    if not session:
        return Response().error(f"Session {session_id} not found").__dict__
    if session.creator != username:
        return Response().error("Permission denied").__dict__

    await self.db.update_platform_session(
        session_id=session_id,
        display_name=display_name,
    )

    return Response().ok().__dict__
|