AstrBot: astrbot-4.1.7-py3-none-any.whl → astrbot-4.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/config/default.py +33 -1
- astrbot/core/conversation_mgr.py +12 -4
- astrbot/core/db/__init__.py +5 -0
- astrbot/core/db/sqlite.py +8 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +25 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -1
- astrbot/core/pipeline/waking_check/stage.py +10 -5
- astrbot/core/platform/astr_message_event.py +9 -5
- astrbot/core/platform/sources/wecom/wecom_adapter.py +1 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +1 -0
- astrbot/core/provider/manager.py +2 -0
- astrbot/core/provider/sources/coze_api_client.py +314 -0
- astrbot/core/provider/sources/coze_source.py +635 -0
- astrbot/core/star/filter/command.py +23 -11
- astrbot/core/star/filter/command_group.py +15 -5
- astrbot/core/star/session_llm_manager.py +0 -4
- astrbot/core/utils/dify_api_client.py +44 -57
- astrbot/dashboard/routes/chat.py +70 -36
- astrbot/dashboard/routes/session_management.py +235 -78
- {astrbot-4.1.7.dist-info → astrbot-4.2.1.dist-info}/METADATA +1 -1
- {astrbot-4.1.7.dist-info → astrbot-4.2.1.dist-info}/RECORD +24 -22
- {astrbot-4.1.7.dist-info → astrbot-4.2.1.dist-info}/WHEEL +0 -0
- {astrbot-4.1.7.dist-info → astrbot-4.2.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.1.7.dist-info → astrbot-4.2.1.dist-info}/licenses/LICENSE +0 -0
astrbot/core/provider/sources/coze_source.py
@@ -0,0 +1,635 @@
+import json
+import os
+import base64
+import hashlib
+from typing import AsyncGenerator, Dict
+from astrbot.core.message.message_event_result import MessageChain
+import astrbot.core.message.components as Comp
+from astrbot.api.provider import Provider
+from astrbot import logger
+from astrbot.core.provider.entities import LLMResponse
+from ..register import register_provider_adapter
+from .coze_api_client import CozeAPIClient
+
+
+@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
+class ProviderCoze(Provider):
+    def __init__(
+        self,
+        provider_config,
+        provider_settings,
+        default_persona=None,
+    ) -> None:
+        super().__init__(
+            provider_config,
+            provider_settings,
+            default_persona,
+        )
+        self.api_key = provider_config.get("coze_api_key", "")
+        if not self.api_key:
+            raise Exception("Coze API Key 不能为空。")
+        self.bot_id = provider_config.get("bot_id", "")
+        if not self.bot_id:
+            raise Exception("Coze Bot ID 不能为空。")
+        self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
+
+        if not isinstance(self.api_base, str) or not self.api_base.startswith(
+            ("http://", "https://")
+        ):
+            raise Exception(
+                "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
+            )
+
+        self.timeout = provider_config.get("timeout", 120)
+        if isinstance(self.timeout, str):
+            self.timeout = int(self.timeout)
+        self.auto_save_history = provider_config.get("auto_save_history", True)
+        self.conversation_ids: Dict[str, str] = {}
+        self.file_id_cache: Dict[str, Dict[str, str]] = {}
+
+        # 创建 API 客户端
+        self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
+
+    def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
+        """生成统一的缓存键
+
+        Args:
+            data: 图片数据或路径
+            is_base64: 是否是 base64 数据
+
+        Returns:
+            str: 缓存键
+        """
+
+        try:
+            if is_base64 and data.startswith("data:image/"):
+                try:
+                    header, encoded = data.split(",", 1)
+                    image_bytes = base64.b64decode(encoded)
+                    cache_key = hashlib.md5(image_bytes).hexdigest()
+                    return cache_key
+                except Exception:
+                    cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
+                    return cache_key
+            else:
+                if data.startswith(("http://", "https://")):
+                    # URL图片,使用URL作为缓存键
+                    cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
+                    return cache_key
+                else:
+                    clean_path = (
+                        data.split("_")[0]
+                        if "_" in data and len(data.split("_")) >= 3
+                        else data
+                    )
+
+                    if os.path.exists(clean_path):
+                        with open(clean_path, "rb") as f:
+                            file_content = f.read()
+                        cache_key = hashlib.md5(file_content).hexdigest()
+                        return cache_key
+                    else:
+                        cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
+                        return cache_key
+
+        except Exception as e:
+            cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
+            logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
+            return cache_key
+
+    async def _upload_file(
+        self,
+        file_data: bytes,
+        session_id: str | None = None,
+        cache_key: str | None = None,
+    ) -> str:
+        """上传文件到 Coze 并返回 file_id"""
+        # 使用 API 客户端上传文件
+        file_id = await self.api_client.upload_file(file_data)
+
+        # 缓存 file_id
+        if session_id and cache_key:
+            if session_id not in self.file_id_cache:
+                self.file_id_cache[session_id] = {}
+            self.file_id_cache[session_id][cache_key] = file_id
+            logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
+
+        return file_id
+
+    async def _download_and_upload_image(
+        self, image_url: str, session_id: str | None = None
+    ) -> str:
+        """下载图片并上传到 Coze,返回 file_id"""
+        # 计算哈希实现缓存
+        cache_key = self._generate_cache_key(image_url) if session_id else None
+
+        if session_id and cache_key:
+            if session_id not in self.file_id_cache:
+                self.file_id_cache[session_id] = {}
+
+            if cache_key in self.file_id_cache[session_id]:
+                file_id = self.file_id_cache[session_id][cache_key]
+                return file_id
+
+        try:
+            image_data = await self.api_client.download_image(image_url)
+
+            file_id = await self._upload_file(image_data, session_id, cache_key)
+
+            if session_id and cache_key:
+                self.file_id_cache[session_id][cache_key] = file_id
+
+            return file_id
+
+        except Exception as e:
+            logger.error(f"处理图片失败 {image_url}: {str(e)}")
+            raise Exception(f"处理图片失败: {str(e)}")
+
+    async def _process_context_images(
+        self, content: str | list, session_id: str
+    ) -> str:
+        """处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
+
+        try:
+            if isinstance(content, str):
+                return content
+
+            processed_content = []
+            if session_id not in self.file_id_cache:
+                self.file_id_cache[session_id] = {}
+
+            for item in content:
+                if not isinstance(item, dict):
+                    processed_content.append(item)
+                    continue
+                if item.get("type") == "text":
+                    processed_content.append(item)
+                elif item.get("type") == "image_url":
+                    # 处理图片逻辑
+                    if "file_id" in item:
+                        # 已经有 file_id
+                        logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
+                        processed_content.append(item)
+                    else:
+                        # 获取图片数据
+                        image_data = ""
+                        if "image_url" in item and isinstance(item["image_url"], dict):
+                            image_data = item["image_url"].get("url", "")
+                        elif "data" in item:
+                            image_data = item.get("data", "")
+                        elif "url" in item:
+                            image_data = item.get("url", "")
+
+                        if not image_data:
+                            continue
+                        # 计算哈希用于缓存
+                        cache_key = self._generate_cache_key(
+                            image_data, is_base64=image_data.startswith("data:image/")
+                        )
+
+                        # 检查缓存
+                        if cache_key in self.file_id_cache[session_id]:
+                            file_id = self.file_id_cache[session_id][cache_key]
+                            processed_content.append(
+                                {"type": "image", "file_id": file_id}
+                            )
+                        else:
+                            # 上传图片并缓存
+                            if image_data.startswith("data:image/"):
+                                # base64 处理
+                                _, encoded = image_data.split(",", 1)
+                                image_bytes = base64.b64decode(encoded)
+                                file_id = await self._upload_file(
+                                    image_bytes,
+                                    session_id,
+                                    cache_key,
+                                )
+                            elif image_data.startswith(("http://", "https://")):
+                                # URL 图片
+                                file_id = await self._download_and_upload_image(
+                                    image_data, session_id
+                                )
+                                # 为URL图片也添加缓存
+                                self.file_id_cache[session_id][cache_key] = file_id
+                            elif os.path.exists(image_data):
+                                # 本地文件
+                                with open(image_data, "rb") as f:
+                                    image_bytes = f.read()
+                                file_id = await self._upload_file(
+                                    image_bytes,
+                                    session_id,
+                                    cache_key,
+                                )
+                            else:
+                                logger.warning(
+                                    f"无法处理的图片格式: {image_data[:50]}..."
+                                )
+                                continue
+
+                            processed_content.append(
+                                {"type": "image", "file_id": file_id}
+                            )
+
+            result = json.dumps(processed_content, ensure_ascii=False)
+            return result
+        except Exception as e:
+            logger.error(f"处理上下文图片失败: {str(e)}")
+            if isinstance(content, str):
+                return content
+            else:
+                return json.dumps(content, ensure_ascii=False)
+
+    async def text_chat(
+        self,
+        prompt: str,
+        session_id=None,
+        image_urls=None,
+        func_tool=None,
+        contexts=None,
+        system_prompt=None,
+        tool_calls_result=None,
+        model=None,
+        **kwargs,
+    ) -> LLMResponse:
+        """文本对话, 内部使用流式接口实现非流式
+
+        Args:
+            prompt (str): 用户提示词
+            session_id (str): 会话ID
+            image_urls (List[str]): 图片URL列表
+            func_tool (FuncCall): 函数调用工具(不支持)
+            contexts (List): 上下文列表
+            system_prompt (str): 系统提示语
+            tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
+            model (str): 模型名称(不支持)
+        Returns:
+            LLMResponse: LLM响应对象
+        """
+        accumulated_content = ""
+        final_response = None
+
+        async for llm_response in self.text_chat_stream(
+            prompt=prompt,
+            session_id=session_id,
+            image_urls=image_urls,
+            func_tool=func_tool,
+            contexts=contexts,
+            system_prompt=system_prompt,
+            tool_calls_result=tool_calls_result,
+            model=model,
+            **kwargs,
+        ):
+            if llm_response.is_chunk:
+                if llm_response.completion_text:
+                    accumulated_content += llm_response.completion_text
+            else:
+                final_response = llm_response
+
+        if final_response:
+            return final_response
+
+        if accumulated_content:
+            chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
+            return LLMResponse(role="assistant", result_chain=chain)
+        else:
+            return LLMResponse(role="assistant", completion_text="")
+
+    async def text_chat_stream(
+        self,
+        prompt: str,
+        session_id=None,
+        image_urls=None,
+        func_tool=None,
+        contexts=None,
+        system_prompt=None,
+        tool_calls_result=None,
+        model=None,
+        **kwargs,
+    ) -> AsyncGenerator[LLMResponse, None]:
+        """流式对话接口"""
+        # 用户ID参数(参考文档, 可以自定义)
+        user_id = session_id or kwargs.get("user", "default_user")
+
+        # 获取或创建会话ID
+        conversation_id = self.conversation_ids.get(user_id)
+
+        # 构建消息
+        additional_messages = []
+
+        if system_prompt:
+            if not self.auto_save_history or not conversation_id:
+                additional_messages.append(
+                    {"role": "system", "content": system_prompt, "content_type": "text"}
+                )
+
+        if not self.auto_save_history and contexts:
+            # 如果关闭了自动保存历史,传入上下文
+            for ctx in contexts:
+                if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
+                    content = ctx["content"]
+                    content_type = ctx.get("content_type", "text")
+
+                    # 处理可能包含图片的上下文
+                    if (
+                        content_type == "object_string"
+                        or (isinstance(content, str) and content.startswith("["))
+                        or (
+                            isinstance(content, list)
+                            and any(
+                                isinstance(item, dict)
+                                and item.get("type") == "image_url"
+                                for item in content
+                            )
+                        )
+                    ):
+                        processed_content = await self._process_context_images(
+                            content, user_id
+                        )
+                        additional_messages.append(
+                            {
+                                "role": ctx["role"],
+                                "content": processed_content,
+                                "content_type": "object_string",
+                            }
+                        )
+                    else:
+                        # 纯文本
+                        additional_messages.append(
+                            {
+                                "role": ctx["role"],
+                                "content": (
+                                    content
+                                    if isinstance(content, str)
+                                    else json.dumps(content, ensure_ascii=False)
+                                ),
+                                "content_type": "text",
+                            }
+                        )
+                else:
+                    logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
+
+        if prompt or image_urls:
+            if image_urls:
+                # 多模态
+                object_string_content = []
+                if prompt:
+                    object_string_content.append({"type": "text", "text": prompt})
+
+                for url in image_urls:
+                    try:
+                        if url.startswith(("http://", "https://")):
+                            # 网络图片
+                            file_id = await self._download_and_upload_image(
+                                url, user_id
+                            )
+                        else:
+                            # 本地文件或 base64
+                            if url.startswith("data:image/"):
+                                # base64
+                                _, encoded = url.split(",", 1)
+                                image_data = base64.b64decode(encoded)
+                                cache_key = self._generate_cache_key(
+                                    url, is_base64=True
+                                )
+                                file_id = await self._upload_file(
+                                    image_data, user_id, cache_key
+                                )
+                            else:
+                                # 本地文件
+                                if os.path.exists(url):
+                                    with open(url, "rb") as f:
+                                        image_data = f.read()
+                                    # 用文件路径和修改时间来缓存
+                                    file_stat = os.stat(url)
+                                    cache_key = self._generate_cache_key(
+                                        f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
+                                        is_base64=False,
+                                    )
+                                    file_id = await self._upload_file(
+                                        image_data, user_id, cache_key
+                                    )
+                                else:
+                                    logger.warning(f"图片文件不存在: {url}")
+                                    continue
+
+                        object_string_content.append(
+                            {
+                                "type": "image",
+                                "file_id": file_id,
+                            }
+                        )
+                    except Exception as e:
+                        logger.error(f"处理图片失败 {url}: {str(e)}")
+                        continue
+
+                if object_string_content:
+                    content = json.dumps(object_string_content, ensure_ascii=False)
+                    additional_messages.append(
+                        {
+                            "role": "user",
+                            "content": content,
+                            "content_type": "object_string",
+                        }
+                    )
+            else:
+                # 纯文本
+                if prompt:
+                    additional_messages.append(
+                        {
+                            "role": "user",
+                            "content": prompt,
+                            "content_type": "text",
+                        }
+                    )
+
+        try:
+            accumulated_content = ""
+            message_started = False
+
+            async for chunk in self.api_client.chat_messages(
+                bot_id=self.bot_id,
+                user_id=user_id,
+                additional_messages=additional_messages,
+                conversation_id=conversation_id,
+                auto_save_history=self.auto_save_history,
+                stream=True,
+                timeout=self.timeout,
+            ):
+                event_type = chunk.get("event")
+                data = chunk.get("data", {})
+
+                if event_type == "conversation.chat.created":
+                    if isinstance(data, dict) and "conversation_id" in data:
+                        self.conversation_ids[user_id] = data["conversation_id"]
+
+                elif event_type == "conversation.message.delta":
+                    if isinstance(data, dict):
+                        content = data.get("content", "")
+                        if not content and "delta" in data:
+                            content = data["delta"].get("content", "")
+                        if not content and "text" in data:
+                            content = data.get("text", "")
+
+                        if content:
+                            message_started = True
+                            accumulated_content += content
+                            yield LLMResponse(
+                                role="assistant",
+                                completion_text=content,
+                                is_chunk=True,
+                            )
+
+                elif event_type == "conversation.message.completed":
+                    if isinstance(data, dict):
+                        msg_type = data.get("type")
+                        if msg_type == "answer" and data.get("role") == "assistant":
+                            final_content = data.get("content", "")
+                            if not accumulated_content and final_content:
+                                chain = MessageChain(chain=[Comp.Plain(final_content)])
+                                yield LLMResponse(
+                                    role="assistant",
+                                    result_chain=chain,
+                                    is_chunk=False,
+                                )
+
+                elif event_type == "conversation.chat.completed":
+                    if accumulated_content:
+                        chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
+                        yield LLMResponse(
+                            role="assistant",
+                            result_chain=chain,
+                            is_chunk=False,
+                        )
+                    break
+
+                elif event_type == "done":
+                    break
+
+                elif event_type == "error":
+                    error_msg = (
+                        data.get("message", "未知错误")
+                        if isinstance(data, dict)
+                        else str(data)
+                    )
+                    logger.error(f"Coze 流式响应错误: {error_msg}")
+                    yield LLMResponse(
+                        role="err",
+                        completion_text=f"Coze 错误: {error_msg}",
+                        is_chunk=False,
+                    )
+                    break
+
+            if not message_started and not accumulated_content:
+                yield LLMResponse(
+                    role="assistant",
+                    completion_text="LLM 未响应任何内容。",
+                    is_chunk=False,
+                )
+            elif message_started and accumulated_content:
+                chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
+                yield LLMResponse(
+                    role="assistant",
+                    result_chain=chain,
+                    is_chunk=False,
+                )
+
+        except Exception as e:
+            logger.error(f"Coze 流式请求失败: {str(e)}")
+            yield LLMResponse(
+                role="err",
+                completion_text=f"Coze 流式请求失败: {str(e)}",
+                is_chunk=False,
+            )
+
+    async def forget(self, session_id: str):
+        """清空指定会话的上下文"""
+        user_id = session_id
+        conversation_id = self.conversation_ids.get(user_id)
+
+        if user_id in self.file_id_cache:
+            self.file_id_cache.pop(user_id, None)
+
+        if not conversation_id:
+            return True
+
+        try:
+            response = await self.api_client.clear_context(conversation_id)
+
+            if "code" in response and response["code"] == 0:
+                self.conversation_ids.pop(user_id, None)
+                return True
+            else:
+                logger.warning(f"清空 Coze 会话上下文失败: {response}")
+                return False
+
+        except Exception as e:
+            logger.error(f"清空 Coze 会话失败: {str(e)}")
+            return False
+
+    async def get_current_key(self):
+        """获取当前API Key"""
+        return self.api_key
+
+    async def set_key(self, key: str):
+        """设置新的API Key"""
+        raise NotImplementedError("Coze 适配器不支持设置 API Key。")
+
+    async def get_models(self):
+        """获取可用模型列表"""
+        return [f"bot_{self.bot_id}"]
+
+    def get_model(self):
+        """获取当前模型"""
+        return f"bot_{self.bot_id}"
+
+    def set_model(self, model: str):
+        """设置模型(在Coze中是Bot ID)"""
+        if model.startswith("bot_"):
+            self.bot_id = model[4:]
+        else:
+            self.bot_id = model
+
+    async def get_human_readable_context(
+        self, session_id: str, page: int = 1, page_size: int = 10
+    ):
+        """获取人类可读的上下文历史"""
+        user_id = session_id
+        conversation_id = self.conversation_ids.get(user_id)
+
+        if not conversation_id:
+            return []
+
+        try:
+            data = await self.api_client.get_message_list(
+                conversation_id=conversation_id,
+                order="desc",
+                limit=page_size,
+                offset=(page - 1) * page_size,
+            )
+
+            if data.get("code") != 0:
+                logger.warning(f"获取 Coze 消息历史失败: {data}")
+                return []
+
+            messages = data.get("data", {}).get("messages", [])
+
+            readable_history = []
+            for msg in messages:
+                role = msg.get("role", "unknown")
+                content = msg.get("content", "")
+                msg_type = msg.get("type", "")
+
+                if role == "user":
+                    readable_history.append(f"用户: {content}")
+                elif role == "assistant" and msg_type == "answer":
+                    readable_history.append(f"助手: {content}")
+
+            return readable_history
+
+        except Exception as e:
+            logger.error(f"获取 Coze 消息历史失败: {str(e)}")
+            return []
+
+    async def terminate(self):
+        """清理资源"""
+        await self.api_client.close()
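For reference, the constructor above reads a handful of keys from provider_config. A minimal configuration sketch, with key names taken from ProviderCoze.__init__; the values are placeholders and the surrounding provider registration flow is not shown here:

# Hypothetical provider_config for the new Coze adapter; key names mirror
# what ProviderCoze.__init__ reads and validates above.
coze_provider_config = {
    "coze_api_key": "pat_xxxxxxxx",          # required, an empty value raises
    "bot_id": "73xxxxxxxxxxxxxxx",           # required, an empty value raises
    "coze_api_base": "https://api.coze.cn",  # must start with http:// or https://
    "timeout": 120,                          # string values are coerced via int()
    "auto_save_history": True,               # when False, contexts are sent explicitly
}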
astrbot/core/star/filter/command.py
@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
         self.init_handler_md(handler_md)
         self.custom_filter_list: List[CustomFilter] = []
 
+        # Cache for complete command names list
+        self._cmpl_cmd_names: list | None = None
+
     def print_types(self):
         result = ""
         for k, v in self.handler_params.items():
@@ -136,6 +139,22 @@ class CommandFilter(HandlerFilter):
         )
         return result
 
+    def get_complete_command_names(self):
+        if self._cmpl_cmd_names is not None:
+            return self._cmpl_cmd_names
+        self._cmpl_cmd_names = [
+            f"{parent} {cmd}" if parent else cmd
+            for cmd in [self.command_name] + list(self.alias)
+            for parent in self.parent_command_names or [""]
+        ]
+        return self._cmpl_cmd_names
+
+    def equals(self, message_str: str) -> bool:
+        for full_cmd in self.get_complete_command_names():
+            if message_str == full_cmd:
+                return True
+        return False
+
     def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
         if not event.is_at_or_wake_command:
             return False
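The cached list pairs every command name and alias with every parent command name, or uses the bare name when there are no parents. A small illustration of the comprehension above, using hypothetical values:

# Hypothetical values, mirroring the comprehension in get_complete_command_names.
command_name = "help"
alias = {"h"}
parent_command_names = ["bot"]

complete_names = [
    f"{parent} {cmd}" if parent else cmd
    for cmd in [command_name] + list(alias)
    for parent in parent_command_names or [""]
]
# -> ["bot help", "bot h"]; with no parent commands the result is ["help", "h"]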
@@ -145,18 +164,11 @@ class CommandFilter(HandlerFilter):
 
         # 检查是否以指令开头
         message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
-        candidates = [self.command_name] + list(self.alias)
         ok = False
-        for
-
-
-
-            else:
-                _full = candidate
-            if message_str.startswith(f"{_full} ") or message_str == _full:
-                message_str = message_str[len(_full) :].strip()
-                ok = True
-                break
+        for full_cmd in self.get_complete_command_names():
+            if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
+                ok = True
+                message_str = message_str[len(full_cmd) :].strip()
         if not ok:
             return False
 
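The rewritten loop in filter() now matches against the cached complete command names and strips the matched prefix from the message string. A quick behavioral sketch, under the assumption that get_complete_command_names() yields ["bot help"]:

# Hypothetical walk-through of the new matching logic in filter().
message_str = "bot help me now"
ok = False
for full_cmd in ["bot help"]:          # stand-in for get_complete_command_names()
    if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
        ok = True
        message_str = message_str[len(full_cmd):].strip()
# ok is True and message_str is now "me now"; the rest of filter()
# (not shown in this hunk) continues with the stripped string.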