AstrBot 4.1.7__py3-none-any.whl → 4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,635 @@
1
+ import json
2
+ import os
3
+ import base64
4
+ import hashlib
5
+ from typing import AsyncGenerator, Dict
6
+ from astrbot.core.message.message_event_result import MessageChain
7
+ import astrbot.core.message.components as Comp
8
+ from astrbot.api.provider import Provider
9
+ from astrbot import logger
10
+ from astrbot.core.provider.entities import LLMResponse
11
+ from ..register import register_provider_adapter
12
+ from .coze_api_client import CozeAPIClient
13
+
14
+
15
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
    """Provider adapter that talks to a Coze (扣子) bot through ``CozeAPIClient``.

    Per-instance state:

    * ``conversation_ids`` maps a user/session id to the Coze conversation id
      reported by the ``conversation.chat.created`` stream event, so later
      turns reuse the same conversation.
    * ``file_id_cache`` maps a session id to ``{cache_key: file_id}`` so an
      image already uploaded in that session is not uploaded again.
    """

    def __init__(
        self,
        provider_config,
        provider_settings,
        default_persona=None,
    ) -> None:
        super().__init__(
            provider_config,
            provider_settings,
            default_persona,
        )
        self.api_key = provider_config.get("coze_api_key", "")
        if not self.api_key:
            raise Exception("Coze API Key 不能为空。")
        self.bot_id = provider_config.get("bot_id", "")
        if not self.bot_id:
            raise Exception("Coze Bot ID 不能为空。")
        self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")

        if not isinstance(self.api_base, str) or not self.api_base.startswith(
            ("http://", "https://")
        ):
            raise Exception(
                "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
            )

        self.timeout = provider_config.get("timeout", 120)
        if isinstance(self.timeout, str):
            self.timeout = int(self.timeout)
        self.auto_save_history = provider_config.get("auto_save_history", True)
        # user/session id -> Coze conversation id
        self.conversation_ids: Dict[str, str] = {}
        # session id -> {cache_key -> Coze file_id}
        self.file_id_cache: Dict[str, Dict[str, str]] = {}

        # Client used for all HTTP calls to Coze.
        self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)

    # ------------------------------------------------------------------ #
    # Image / file helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _strip_stat_suffix(path: str) -> str:
        """Return the real file path behind a ``{path}_{mtime}_{size}`` string.

        The string is returned unchanged when it already names an existing
        file, so underscores belonging to real file or directory names are
        never stripped. Otherwise exactly the last two ``_``-separated
        segments are removed, but only if what remains is an existing file.
        """
        if os.path.isfile(path):
            return path
        parts = path.rsplit("_", 2)
        if len(parts) == 3 and os.path.isfile(parts[0]):
            return parts[0]
        return path

    def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
        """Build a cache key for an image reference.

        Args:
            data: A ``data:image/...`` base64 URI, an http(s) URL, or a local
                file path. A local path may carry a ``_{mtime}_{size}``
                suffix, which is stripped before the file is read.
            is_base64: Whether ``data`` should be treated as a base64 data URI.

        Returns:
            str: An MD5 hex digest. It covers:

            * the decoded image bytes, for a valid data URI;
            * the file content, for an existing local file;
            * the reference string itself, otherwise.
        """
        try:
            if is_base64 and data.startswith("data:image/"):
                if "," not in data:
                    # Malformed data URI (no payload): hash the whole string.
                    return hashlib.md5(data.encode("utf-8")).hexdigest()
                encoded = data.split(",", 1)[1]
                try:
                    image_bytes = base64.b64decode(encoded)
                except ValueError:
                    # Invalid base64 (binascii.Error is a ValueError): hash the text.
                    return hashlib.md5(encoded.encode("utf-8")).hexdigest()
                return hashlib.md5(image_bytes).hexdigest()

            if data.startswith(("http://", "https://")):
                # Remote image: the URL itself identifies it.
                return hashlib.md5(data.encode("utf-8")).hexdigest()

            clean_path = self._strip_stat_suffix(data)
            if os.path.isfile(clean_path):
                with open(clean_path, "rb") as f:
                    return hashlib.md5(f.read()).hexdigest()
            return hashlib.md5(clean_path.encode("utf-8")).hexdigest()

        except Exception as e:
            cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
            logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
            return cache_key

    def _get_cached_file_id(
        self, session_id: str | None, cache_key: str | None
    ) -> str | None:
        """Return the cached Coze file_id for ``cache_key`` in a session, or None."""
        if not session_id or not cache_key:
            return None
        return self.file_id_cache.get(session_id, {}).get(cache_key)

    async def _upload_file(
        self,
        file_data: bytes,
        session_id: str | None = None,
        cache_key: str | None = None,
    ) -> str:
        """Upload raw bytes to Coze and return the resulting file_id.

        When both ``session_id`` and ``cache_key`` are given, the file_id is
        also stored in ``file_id_cache`` for later reuse.
        """
        file_id = await self.api_client.upload_file(file_data)

        if session_id and cache_key:
            self.file_id_cache.setdefault(session_id, {})[cache_key] = file_id
            logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")

        return file_id

    async def _download_and_upload_image(
        self, image_url: str, session_id: str | None = None
    ) -> str:
        """Download a remote image and upload it to Coze, returning its file_id.

        A cached file_id is returned when the same URL was already uploaded in
        this session.

        Raises:
            Exception: if downloading or uploading fails (original error chained).
        """
        cache_key = self._generate_cache_key(image_url) if session_id else None

        if session_id:
            self.file_id_cache.setdefault(session_id, {})
        cached = self._get_cached_file_id(session_id, cache_key)
        if cached is not None:
            return cached

        try:
            image_data = await self.api_client.download_image(image_url)
            # _upload_file records the new file_id in the per-session cache.
            return await self._upload_file(image_data, session_id, cache_key)
        except Exception as e:
            logger.error(f"处理图片失败 {image_url}: {str(e)}")
            raise Exception(f"处理图片失败: {str(e)}") from e

    async def _resolve_image_file_id(
        self, source: str, session_id: str | None
    ) -> str | None:
        """Return a Coze file_id for an image reference, uploading it if needed.

        ``source`` may be an http(s) URL, a ``data:image/...`` base64 URI or a
        local file path. Images already uploaded in this session are served
        from ``file_id_cache``.

        Returns:
            The file_id, or ``None`` when ``source`` is none of the supported
            forms (e.g. a local path that is not an existing file).

        Raises:
            Exception: propagated from decoding, reading, downloading or uploading.
        """
        if source.startswith(("http://", "https://")):
            return await self._download_and_upload_image(source, session_id)

        if source.startswith("data:image/"):
            cache_key = self._generate_cache_key(source, is_base64=True)
            cached = self._get_cached_file_id(session_id, cache_key)
            if cached is not None:
                return cached
            _, encoded = source.split(",", 1)
            return await self._upload_file(
                base64.b64decode(encoded), session_id, cache_key
            )

        if os.path.isfile(source):
            with open(source, "rb") as f:
                image_bytes = f.read()
            # Same key _generate_cache_key yields for an existing local file.
            cache_key = hashlib.md5(image_bytes).hexdigest()
            cached = self._get_cached_file_id(session_id, cache_key)
            if cached is not None:
                return cached
            return await self._upload_file(image_bytes, session_id, cache_key)

        return None

    @staticmethod
    def _extract_context_image_source(item: dict):
        """Pull the image reference out of an ``image_url`` context part."""
        if isinstance(item.get("image_url"), dict):
            return item["image_url"].get("url", "")
        if "data" in item:
            return item.get("data", "")
        if "url" in item:
            return item.get("url", "")
        return ""

    async def _process_context_images(
        self, content: str | list, session_id: str
    ) -> str:
        """Convert a multimodal context message into a Coze ``object_string``.

        Each item of a list ``content`` is handled as follows:

        * non-dict items and ``text`` parts are kept as-is;
        * ``image_url`` parts that already carry a ``file_id`` are kept as-is;
        * other ``image_url`` parts are uploaded (or looked up in the
          per-session cache) and replaced by ``{"type": "image", "file_id": ...}``;
        * dict parts of any other type are dropped.

        String content is returned unchanged.

        Returns:
            str: The JSON-encoded list of parts. On failure, the original
            content is returned (JSON-encoded if it was a list).
        """
        if isinstance(content, str):
            return content

        try:
            processed_content = []
            self.file_id_cache.setdefault(session_id, {})

            for item in content:
                if not isinstance(item, dict) or item.get("type") == "text":
                    processed_content.append(item)
                    continue
                if item.get("type") != "image_url":
                    continue

                if "file_id" in item:
                    logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
                    processed_content.append(item)
                    continue

                image_data = self._extract_context_image_source(item)
                if not image_data:
                    continue

                file_id = await self._resolve_image_file_id(image_data, session_id)
                if file_id is None:
                    logger.warning(f"无法处理的图片格式: {image_data[:50]}...")
                    continue
                processed_content.append({"type": "image", "file_id": file_id})

            return json.dumps(processed_content, ensure_ascii=False)
        except Exception as e:
            logger.error(f"处理上下文图片失败: {str(e)}")
            return json.dumps(content, ensure_ascii=False)

    # ------------------------------------------------------------------ #
    # Request building
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_object_string_context(content, content_type) -> bool:
        """Decide whether a context message must be sent as ``object_string``."""
        if content_type == "object_string":
            return True
        if isinstance(content, list):
            return any(
                isinstance(item, dict) and item.get("type") == "image_url"
                for item in content
            )
        if isinstance(content, str) and content.startswith("["):
            # Only a real JSON array is an object_string payload; plain text that
            # merely starts with "[" (e.g. "[图片] ...") must stay text.
            try:
                return isinstance(json.loads(content), list)
            except ValueError:
                return False
        return False

    async def _build_additional_messages(
        self,
        user_id: str,
        conversation_id: str | None,
        prompt: str,
        image_urls,
        contexts,
        system_prompt,
    ) -> list:
        """Assemble the ``additional_messages`` payload for one chat request.

        * The system prompt is included only when ``auto_save_history`` is off
          or there is no known conversation yet.
        * ``contexts`` are replayed only when ``auto_save_history`` is off
          (presumably because Coze keeps the history itself otherwise).
        * The user turn is plain text, or an ``object_string`` when images are
          attached. Images that fail to resolve are logged and skipped.
        """
        messages = []

        if system_prompt and (not self.auto_save_history or not conversation_id):
            messages.append(
                {"role": "system", "content": system_prompt, "content_type": "text"}
            )

        if not self.auto_save_history and contexts:
            for ctx in contexts:
                if not (isinstance(ctx, dict) and "role" in ctx and "content" in ctx):
                    logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
                    continue

                content = ctx["content"]
                content_type = ctx.get("content_type", "text")
                if self._is_object_string_context(content, content_type):
                    messages.append(
                        {
                            "role": ctx["role"],
                            "content": await self._process_context_images(
                                content, user_id
                            ),
                            "content_type": "object_string",
                        }
                    )
                else:
                    messages.append(
                        {
                            "role": ctx["role"],
                            "content": (
                                content
                                if isinstance(content, str)
                                else json.dumps(content, ensure_ascii=False)
                            ),
                            "content_type": "text",
                        }
                    )

        if image_urls:
            parts = []
            if prompt:
                parts.append({"type": "text", "text": prompt})

            for url in image_urls:
                try:
                    file_id = await self._resolve_image_file_id(url, user_id)
                except Exception as e:
                    logger.error(f"处理图片失败 {url}: {str(e)}")
                    continue
                if file_id is None:
                    logger.warning(f"图片文件不存在: {url}")
                    continue
                parts.append({"type": "image", "file_id": file_id})

            if parts:
                messages.append(
                    {
                        "role": "user",
                        "content": json.dumps(parts, ensure_ascii=False),
                        "content_type": "object_string",
                    }
                )
        elif prompt:
            messages.append(
                {"role": "user", "content": prompt, "content_type": "text"}
            )

        return messages

    @staticmethod
    def _extract_delta_text(data) -> str:
        """Extract the text of a ``conversation.message.delta`` event payload."""
        if not isinstance(data, dict):
            return ""
        content = data.get("content", "")
        if not content and isinstance(data.get("delta"), dict):
            content = data["delta"].get("content", "")
        if not content and "text" in data:
            content = data.get("text", "")
        return content

    # ------------------------------------------------------------------ #
    # Chat API
    # ------------------------------------------------------------------ #

    async def text_chat(
        self,
        prompt: str,
        session_id=None,
        image_urls=None,
        func_tool=None,
        contexts=None,
        system_prompt=None,
        tool_calls_result=None,
        model=None,
        **kwargs,
    ) -> LLMResponse:
        """Non-streaming chat, implemented on top of ``text_chat_stream``.

        Args:
            prompt (str): User prompt.
            session_id (str): Session id.
            image_urls (List[str]): Image URLs, data URIs or local paths.
            func_tool (FuncCall): Function-calling tools (not supported).
            contexts (List): Context messages.
            system_prompt (str): System prompt.
            tool_calls_result (ToolCallsResult | List[ToolCallsResult]): Tool call
                results (not supported).
            model (str): Model name (not supported).

        Returns:
            LLMResponse: The last non-chunk response from the stream. If there
            is none, a response built from the accumulated chunk text.
        """
        accumulated_content = ""
        final_response = None

        async for llm_response in self.text_chat_stream(
            prompt=prompt,
            session_id=session_id,
            image_urls=image_urls,
            func_tool=func_tool,
            contexts=contexts,
            system_prompt=system_prompt,
            tool_calls_result=tool_calls_result,
            model=model,
            **kwargs,
        ):
            if llm_response.is_chunk:
                if llm_response.completion_text:
                    accumulated_content += llm_response.completion_text
            else:
                final_response = llm_response

        if final_response:
            return final_response

        if accumulated_content:
            chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
            return LLMResponse(role="assistant", result_chain=chain)
        return LLMResponse(role="assistant", completion_text="")

    async def text_chat_stream(
        self,
        prompt: str,
        session_id=None,
        image_urls=None,
        func_tool=None,
        contexts=None,
        system_prompt=None,
        tool_calls_result=None,
        model=None,
        **kwargs,
    ) -> AsyncGenerator[LLMResponse, None]:
        """Streaming chat against the Coze bot.

        Yields ``is_chunk=True`` responses for each text delta. Non-chunk
        responses carry the full answer as a ``result_chain``, or an ``err``
        response on failure. The end-of-stream fallback (full text, or a
        "no content" placeholder) is emitted only when no non-chunk response
        was yielded during the stream, so callers never receive a duplicate
        final answer. ``func_tool``, ``tool_calls_result`` and ``model`` are
        ignored.
        """
        # Coze user id; also the key for conversation_ids and file_id_cache.
        user_id = session_id or kwargs.get("user", "default_user")
        conversation_id = self.conversation_ids.get(user_id)

        additional_messages = await self._build_additional_messages(
            user_id=user_id,
            conversation_id=conversation_id,
            prompt=prompt,
            image_urls=image_urls,
            contexts=contexts,
            system_prompt=system_prompt,
        )

        try:
            accumulated_content = ""
            # Set once any non-chunk response has been yielded. This stops the
            # fallback below from duplicating the answer or overwriting an error.
            final_sent = False

            async for chunk in self.api_client.chat_messages(
                bot_id=self.bot_id,
                user_id=user_id,
                additional_messages=additional_messages,
                conversation_id=conversation_id,
                auto_save_history=self.auto_save_history,
                stream=True,
                timeout=self.timeout,
            ):
                event_type = chunk.get("event")
                data = chunk.get("data", {})

                if event_type == "conversation.chat.created":
                    if isinstance(data, dict) and "conversation_id" in data:
                        self.conversation_ids[user_id] = data["conversation_id"]

                elif event_type == "conversation.message.delta":
                    content = self._extract_delta_text(data)
                    if content:
                        accumulated_content += content
                        yield LLMResponse(
                            role="assistant",
                            completion_text=content,
                            is_chunk=True,
                        )

                elif event_type == "conversation.message.completed":
                    # Only used when the answer was not streamed as deltas.
                    if (
                        isinstance(data, dict)
                        and data.get("type") == "answer"
                        and data.get("role") == "assistant"
                        and not accumulated_content
                    ):
                        final_content = data.get("content", "")
                        if final_content:
                            final_sent = True
                            yield LLMResponse(
                                role="assistant",
                                result_chain=MessageChain(
                                    chain=[Comp.Plain(final_content)]
                                ),
                                is_chunk=False,
                            )

                elif event_type == "conversation.chat.completed":
                    if accumulated_content:
                        final_sent = True
                        yield LLMResponse(
                            role="assistant",
                            result_chain=MessageChain(
                                chain=[Comp.Plain(accumulated_content)]
                            ),
                            is_chunk=False,
                        )
                    break

                elif event_type == "done":
                    break

                elif event_type == "error":
                    error_msg = (
                        data.get("message", "未知错误")
                        if isinstance(data, dict)
                        else str(data)
                    )
                    logger.error(f"Coze 流式响应错误: {error_msg}")
                    final_sent = True
                    yield LLMResponse(
                        role="err",
                        completion_text=f"Coze 错误: {error_msg}",
                        is_chunk=False,
                    )
                    break

            if not final_sent:
                if accumulated_content:
                    yield LLMResponse(
                        role="assistant",
                        result_chain=MessageChain(
                            chain=[Comp.Plain(accumulated_content)]
                        ),
                        is_chunk=False,
                    )
                else:
                    yield LLMResponse(
                        role="assistant",
                        completion_text="LLM 未响应任何内容。",
                        is_chunk=False,
                    )

        except Exception as e:
            logger.error(f"Coze 流式请求失败: {str(e)}")
            yield LLMResponse(
                role="err",
                completion_text=f"Coze 流式请求失败: {str(e)}",
                is_chunk=False,
            )

    # ------------------------------------------------------------------ #
    # Session / key / model management
    # ------------------------------------------------------------------ #

    async def forget(self, session_id: str):
        """Clear the Coze conversation context and image cache for a session.

        Returns:
            bool: True if there was no conversation, or Coze reported
            ``code == 0``; False otherwise. The image cache is always dropped.
        """
        user_id = session_id
        self.file_id_cache.pop(user_id, None)

        conversation_id = self.conversation_ids.get(user_id)
        if not conversation_id:
            return True

        try:
            response = await self.api_client.clear_context(conversation_id)
            if isinstance(response, dict) and response.get("code") == 0:
                self.conversation_ids.pop(user_id, None)
                return True
            logger.warning(f"清空 Coze 会话上下文失败: {response}")
            return False
        except Exception as e:
            logger.error(f"清空 Coze 会话失败: {str(e)}")
            return False

    async def get_current_key(self):
        """Return the configured Coze API key."""
        return self.api_key

    async def set_key(self, key: str):
        """Changing the API key at runtime is not supported."""
        raise NotImplementedError("Coze 适配器不支持设置 API Key。")

    async def get_models(self):
        """Return the single available "model": the configured bot."""
        return [self.get_model()]

    def get_model(self):
        """Return the current "model" name, i.e. ``bot_<bot_id>``."""
        return f"bot_{self.bot_id}"

    def set_model(self, model: str):
        """Set the bot id, accepting either ``bot_<id>`` or a bare id."""
        self.bot_id = model.removeprefix("bot_")

    async def get_human_readable_context(
        self, session_id: str, page: int = 1, page_size: int = 10
    ):
        """Return one page of this session's Coze history as display strings.

        Messages are requested newest first (``order="desc"``). Only user
        messages and assistant ``answer`` messages are included. Returns an
        empty list when there is no conversation or the request fails.
        """
        user_id = session_id
        conversation_id = self.conversation_ids.get(user_id)

        if not conversation_id:
            return []

        try:
            data = await self.api_client.get_message_list(
                conversation_id=conversation_id,
                order="desc",
                limit=page_size,
                offset=(page - 1) * page_size,
            )

            if data.get("code") != 0:
                logger.warning(f"获取 Coze 消息历史失败: {data}")
                return []

            # The payload may be the message array itself or an object
            # wrapping it under "messages"; accept both shapes.
            payload = data.get("data", {})
            if isinstance(payload, list):
                messages = payload
            elif isinstance(payload, dict):
                messages = payload.get("messages", [])
            else:
                messages = []

            readable_history = []
            for msg in messages:
                role = msg.get("role", "unknown")
                content = msg.get("content", "")
                msg_type = msg.get("type", "")

                if role == "user":
                    readable_history.append(f"用户: {content}")
                elif role == "assistant" and msg_type == "answer":
                    readable_history.append(f"助手: {content}")

            return readable_history

        except Exception as e:
            logger.error(f"获取 Coze 消息历史失败: {str(e)}")
            return []

    async def terminate(self):
        """Release resources held by the API client."""
        await self.api_client.close()
@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
32
32
  self.init_handler_md(handler_md)
33
33
  self.custom_filter_list: List[CustomFilter] = []
34
34
 
35
+ # Cache for complete command names list
36
+ self._cmpl_cmd_names: list | None = None
37
+
35
38
  def print_types(self):
36
39
  result = ""
37
40
  for k, v in self.handler_params.items():
@@ -136,6 +139,22 @@ class CommandFilter(HandlerFilter):
136
139
  )
137
140
  return result
138
141
 
142
+ def get_complete_command_names(self):
143
+ if self._cmpl_cmd_names is not None:
144
+ return self._cmpl_cmd_names
145
+ self._cmpl_cmd_names = [
146
+ f"{parent} {cmd}" if parent else cmd
147
+ for cmd in [self.command_name] + list(self.alias)
148
+ for parent in self.parent_command_names or [""]
149
+ ]
150
+ return self._cmpl_cmd_names
151
+
152
+ def equals(self, message_str: str) -> bool:
153
+ for full_cmd in self.get_complete_command_names():
154
+ if message_str == full_cmd:
155
+ return True
156
+ return False
157
+
139
158
  def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
140
159
  if not event.is_at_or_wake_command:
141
160
  return False
@@ -145,18 +164,11 @@ class CommandFilter(HandlerFilter):
145
164
 
146
165
  # 检查是否以指令开头
147
166
  message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
148
- candidates = [self.command_name] + list(self.alias)
149
167
  ok = False
150
- for candidate in candidates:
151
- for parent_command_name in self.parent_command_names:
152
- if parent_command_name:
153
- _full = f"{parent_command_name} {candidate}"
154
- else:
155
- _full = candidate
156
- if message_str.startswith(f"{_full} ") or message_str == _full:
157
- message_str = message_str[len(_full) :].strip()
158
- ok = True
159
- break
168
+ for full_cmd in self.get_complete_command_names():
169
+ if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
170
+ ok = True
171
+ message_str = message_str[len(full_cmd) :].strip()
160
172
  if not ok:
161
173
  return False
162
174