AstrBot 4.3.3__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. astrbot/core/agent/mcp_client.py +18 -4
  2. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  3. astrbot/core/astr_agent_context.py +1 -0
  4. astrbot/core/astrbot_config_mgr.py +23 -51
  5. astrbot/core/config/default.py +139 -14
  6. astrbot/core/conversation_mgr.py +36 -1
  7. astrbot/core/core_lifecycle.py +24 -5
  8. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  9. astrbot/core/db/vec_db/base.py +33 -2
  10. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  11. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  12. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  13. astrbot/core/file_token_service.py +6 -1
  14. astrbot/core/initial_loader.py +6 -3
  15. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  16. astrbot/core/knowledge_base/chunking/base.py +24 -0
  17. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  18. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  19. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  20. astrbot/core/knowledge_base/kb_helper.py +348 -0
  21. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  22. astrbot/core/knowledge_base/models.py +114 -0
  23. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  24. astrbot/core/knowledge_base/parsers/base.py +50 -0
  25. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  26. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  27. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  28. astrbot/core/knowledge_base/parsers/util.py +13 -0
  29. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  30. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  31. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  32. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  33. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  34. astrbot/core/pipeline/process_stage/method/llm_request.py +61 -21
  35. astrbot/core/pipeline/process_stage/utils.py +80 -0
  36. astrbot/core/pipeline/scheduler.py +1 -1
  37. astrbot/core/platform/astr_message_event.py +8 -7
  38. astrbot/core/platform/manager.py +4 -0
  39. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  40. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  41. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  42. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  43. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  44. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  45. astrbot/core/platform/sources/satori/satori_event.py +270 -77
  46. astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
  47. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
  48. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
  49. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
  50. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
  51. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
  52. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
  53. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
  54. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
  55. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
  56. astrbot/core/provider/manager.py +14 -9
  57. astrbot/core/provider/provider.py +67 -0
  58. astrbot/core/provider/sources/anthropic_source.py +4 -4
  59. astrbot/core/provider/sources/dashscope_source.py +10 -9
  60. astrbot/core/provider/sources/dify_source.py +6 -8
  61. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  62. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  63. astrbot/core/provider/sources/openai_source.py +18 -15
  64. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  65. astrbot/core/star/context.py +3 -0
  66. astrbot/core/star/star.py +6 -0
  67. astrbot/core/star/star_manager.py +13 -7
  68. astrbot/core/umop_config_router.py +81 -0
  69. astrbot/core/updator.py +1 -1
  70. astrbot/core/utils/io.py +23 -12
  71. astrbot/dashboard/routes/__init__.py +2 -0
  72. astrbot/dashboard/routes/config.py +137 -9
  73. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  74. astrbot/dashboard/routes/plugin.py +24 -5
  75. astrbot/dashboard/routes/tools.py +14 -0
  76. astrbot/dashboard/routes/update.py +1 -1
  77. astrbot/dashboard/server.py +6 -0
  78. astrbot/dashboard/utils.py +161 -0
  79. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/METADATA +91 -55
  80. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/RECORD +83 -50
  81. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
  82. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
  83. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,166 @@
1
+ """
2
+ 企业微信智能机器人 HTTP 服务器
3
+ 处理企业微信智能机器人的 HTTP 回调请求
4
+ """
5
+
6
+ import asyncio
7
+ from typing import Dict, Any, Optional, Callable
8
+
9
+ import quart
10
+ from astrbot.api import logger
11
+
12
+ from .wecomai_api import WecomAIBotAPIClient
13
+ from .wecomai_utils import WecomAIBotConstants
14
+
15
+
16
class WecomAIBotServer:
    """HTTP server receiving WeCom AI-bot webhook callbacks.

    A single path, ``/webhook/wecom-ai-bot``, is served: GET answers the
    callback-URL verification handshake, POST receives encrypted message
    callbacks that are decrypted via ``api_client`` and forwarded to
    ``message_handler``.
    """

    def __init__(
        self,
        host: str,
        port: int,
        api_client: WecomAIBotAPIClient,
        message_handler: Optional[
            Callable[[Dict[str, Any], Dict[str, str]], Any]
        ] = None,
    ):
        """Build the Quart app and register the webhook routes.

        Args:
            host: address to bind.
            port: port to bind.
            api_client: client used for URL verification and decryption.
            message_handler: optional callable that is awaited with the
                decrypted message and ``{"nonce": ..., "timestamp": ...}``;
                a truthy return value becomes the response body.
        """
        self.host, self.port = host, port
        self.api_client = api_client
        self.message_handler = message_handler

        self.app = quart.Quart(__name__)
        self._setup_routes()

        # Set by shutdown(); awaited by shutdown_trigger() so run_task exits.
        self.shutdown_event = asyncio.Event()

    def _setup_routes(self):
        """Register the GET (verification) and POST (message) views."""
        webhook_path = "/webhook/wecom-ai-bot"
        for view, method in (
            (self.verify_url, "GET"),
            (self.handle_message, "POST"),
        ):
            self.app.add_url_rule(webhook_path, view_func=view, methods=[method])

    async def verify_url(self):
        """Answer the callback-URL verification request.

        Returns 400 when any query parameter is missing or empty; otherwise
        the result of ``api_client.verify_url`` as plain text with 200.
        """
        query = quart.request.args
        msg_signature, timestamp, nonce, echostr = (
            query.get(name)
            for name in ("msg_signature", "timestamp", "nonce", "echostr")
        )

        if not (msg_signature and timestamp and nonce and echostr):
            logger.error("URL 验证参数缺失")
            return "verify fail", 400

        # Narrow Optional[str] -> str for type checkers (guaranteed above).
        assert msg_signature is not None
        assert timestamp is not None
        assert nonce is not None
        assert echostr is not None

        logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
        plain = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
        return plain, 200, {"Content-Type": "text/plain"}

    async def handle_message(self):
        """Decrypt a POSTed callback and dispatch it to the message handler.

        Responses: 400 for missing query parameters or failed decryption,
        500 if the handler (or anything else) raises, otherwise 200 with the
        handler's truthy result or ``"success"``.
        """
        query = quart.request.args
        msg_signature = query.get("msg_signature")
        timestamp = query.get("timestamp")
        nonce = query.get("nonce")

        if not (msg_signature and timestamp and nonce):
            logger.error("消息回调参数缺失")
            return "缺少必要参数", 400

        # Narrow Optional[str] -> str for type checkers (guaranteed above).
        assert msg_signature is not None
        assert timestamp is not None
        assert nonce is not None

        logger.debug(
            f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}"
        )

        try:
            raw_body = await quart.request.get_data()
            # Decryption expects bytes; normalise if a str comes back.
            if isinstance(raw_body, str):
                raw_body = raw_body.encode("utf-8")

            ret_code, message_data = await self.api_client.decrypt_message(
                raw_body, msg_signature, timestamp, nonce
            )
            if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
                logger.error("消息解密失败,错误码: %d", ret_code)
                return "消息解密失败", 400

            reply = None
            if self.message_handler:
                try:
                    reply = await self.message_handler(
                        message_data, {"nonce": nonce, "timestamp": timestamp}
                    )
                except Exception as exc:
                    logger.error("消息处理器执行异常: %s", exc)
                    return "消息处理异常", 500

            body = reply if reply else "success"
            return body, 200, {"Content-Type": "text/plain"}
        except Exception as exc:
            logger.error("处理消息时发生异常: %s", exc)
            return "内部服务器错误", 500

    async def start_server(self):
        """Serve until shutdown() is called; errors are logged, then re-raised."""
        logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)
        try:
            await self.app.run_task(
                host=self.host,
                port=self.port,
                shutdown_trigger=self.shutdown_trigger,
            )
        except Exception as exc:
            logger.error("服务器运行异常: %s", exc)
            raise

    async def shutdown_trigger(self):
        """Block until shutdown() sets the shutdown event."""
        await self.shutdown_event.wait()

    async def shutdown(self):
        """Ask the running server to stop."""
        logger.info("企业微信智能机器人服务器正在关闭...")
        self.shutdown_event.set()

    def get_app(self):
        """Return the underlying Quart application."""
        return self.app
@@ -0,0 +1,199 @@
1
+ """
2
+ 企业微信智能机器人工具模块
3
+ 提供常量定义、工具函数和辅助方法
4
+ """
5
+
6
+ import string
7
+ import random
8
+ import hashlib
9
+ import base64
10
+ import aiohttp
11
+ import asyncio
12
+ from Crypto.Cipher import AES
13
+ from typing import Any, Tuple
14
+ from astrbot.api import logger
15
+
16
+
17
+ # 常量定义
18
class WecomAIBotConstants:
    """Namespace of constants for the WeCom AI-bot integration.

    Holds message-type identifiers, stream state flags and numeric
    return codes.
    """

    # Message type identifiers.
    MSG_TYPE_TEXT = "text"
    MSG_TYPE_IMAGE = "image"
    MSG_TYPE_MIXED = "mixed"
    MSG_TYPE_STREAM = "stream"
    MSG_TYPE_EVENT = "event"

    # Stream state flags; presumably used as the stream "finish" value
    # (False = more content follows, True = finished) -- confirm with callers.
    STREAM_CONTINUE, STREAM_FINISH = False, True

    # Return codes: 0 is success, negative values name crypto/parse failures
    # (presumably mirroring the WXBizMsgCrypt error codes -- verify).
    SUCCESS = 0
    DECRYPT_ERROR = -40001
    VALIDATE_SIGNATURE_ERROR = -40002
    PARSE_XML_ERROR = -40003
    COMPUTE_SIGNATURE_ERROR = -40004
    ILLEGAL_AES_KEY = -40005
    VALIDATE_APPID_ERROR = -40006
    ENCRYPT_AES_ERROR = -40007
    ILLEGAL_BUFFER = -40008
42
+
43
+
44
def generate_random_string(length: int = 10) -> str:
    """Return a string of ``length`` random ASCII letters and digits.

    Uses the (non-cryptographic) ``random`` module.

    Args:
        length: number of characters to produce; defaults to 10.

    Returns:
        The random string ("" when ``length`` <= 0).
    """
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
55
+
56
+
57
def calculate_image_md5(image_data: bytes) -> str:
    """Return the lowercase hex MD5 digest of ``image_data``.

    Args:
        image_data: raw image bytes.

    Returns:
        32-character hexadecimal digest string.
    """
    digest = hashlib.md5(image_data)
    return digest.hexdigest()
67
+
68
+
69
def encode_image_base64(image_data: bytes) -> str:
    """Return ``image_data`` encoded as a standard Base64 text string.

    Args:
        image_data: raw image bytes.

    Returns:
        The Base64 representation as ``str``.
    """
    encoded = base64.b64encode(image_data)
    return encoded.decode("utf-8")
79
+
80
+
81
def format_session_id(session_type: str, session_id: str) -> str:
    """Build the adapter-scoped session ID ``wecom_ai_bot_<type>_<id>``.

    Args:
        session_type: session kind, e.g. "user" or "group".
        session_id: the raw session ID.

    Returns:
        The prefixed session ID.
    """
    return "wecom_ai_bot_{}_{}".format(session_type, session_id)
92
+
93
+
94
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
    """Split an ID built by ``format_session_id`` back into its parts.

    The format is ``wecom_ai_bot_<session_type>_<session_id>``. The session
    type is assumed to contain no underscore (e.g. "user"/"group"), while the
    raw session ID may contain underscores and is returned intact.

    Args:
        formatted_session_id: the formatted session ID.

    Returns:
        ``(session_type, raw_session_id)``. Inputs without the
        ``wecom_ai_bot_`` prefix are returned unchanged as
        ``("user", formatted_session_id)``.
    """
    # maxsplit=4 puts the type and the raw ID in separate fields while
    # keeping any underscores inside the raw ID.
    parts = formatted_session_id.split("_", 4)
    if len(parts) >= 4 and parts[:3] == ["wecom", "ai", "bot"]:
        return parts[3], parts[4] if len(parts) > 4 else ""
    return "user", formatted_session_id
112
+
113
+
114
def safe_json_loads(json_str: str, default: Any = None) -> Any:
    """Parse ``json_str``, returning ``default`` instead of raising.

    Decode errors and wrong input types (``TypeError``) are logged as a
    warning, together with the raw input.

    Args:
        json_str: the JSON text to parse.
        default: value returned when parsing fails.

    Returns:
        The decoded object, or ``default`` on failure.
    """
    import json

    try:
        parsed = json.loads(json_str)
    except (json.JSONDecodeError, TypeError) as err:
        logger.warning(f"JSON 解析失败: {err}, 原始字符串: {json_str}")
        return default
    return parsed
131
+
132
+
133
def format_error_response(error_code: int, error_msg: str) -> str:
    """Render an error as ``"Error <code>: <message>"``.

    Args:
        error_code: numeric error code.
        error_msg: human-readable description.

    Returns:
        The formatted error string.
    """
    return "Error {}: {}".format(error_code, error_msg)
144
+
145
+
146
async def process_encrypted_image(
    image_url: str, aes_key_base64: str
) -> Tuple[bool, str]:
    """Download an encrypted image and decrypt it.

    The image is AES-256-CBC encrypted with the same key as the callback
    crypto; the IV is the first 16 bytes of the key. The plaintext carries
    PKCS#7 padding of up to 32 bytes (WeCom pads to 32-byte multiples even
    though the AES block itself is 16 bytes).

    Args:
        image_url: URL of the encrypted image.
        aes_key_base64: Base64 AES key; trailing "=" padding may be omitted.

    Returns:
        ``(True, base64_of_decrypted_image)`` on success, or
        ``(False, error_message)`` when the download fails.

    Raises:
        ValueError: if the key is empty or does not decode to 32 bytes
            (checked before any network I/O), or if the downloaded data is
            not a whole number of AES blocks or carries invalid padding.
    """
    # Validate the key first so a bad configuration fails fast, without a
    # pointless download.
    if not aes_key_base64:
        raise ValueError("AES密钥不能为空")

    # The key may arrive without its "=" padding; restore it before decoding.
    aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
    if len(aes_key) != 32:
        raise ValueError("无效的AES密钥长度: 应为32字节")
    iv = aes_key[:16]  # IV is the first 16 bytes of the key

    logger.info("开始下载加密图片: %s", image_url)
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                image_url, timeout=aiohttp.ClientTimeout(total=15)
            ) as response:
                response.raise_for_status()
                encrypted_data = await response.read()
                logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
        error_msg = f"下载图片失败: {str(e)}"
        logger.error(error_msg)
        return False, error_msg

    # CBC decryption needs a non-empty, whole number of 16-byte AES blocks;
    # otherwise decrypted_data[-1] below would fail opaquely.
    if not encrypted_data or len(encrypted_data) % 16 != 0:
        raise ValueError(
            f"无效的加密数据长度: {len(encrypted_data)} 字节 (应为16的正整数倍)"
        )

    cipher = AES.new(aes_key, AES.MODE_CBC, iv)
    decrypted_data = cipher.decrypt(encrypted_data)

    # Strip PKCS#7 padding. A pad length of 0 is never valid: slicing with
    # [:-0] would silently yield b"" and report an empty image as success.
    pad_len = decrypted_data[-1]
    if not 1 <= pad_len <= min(32, len(decrypted_data)):
        raise ValueError(f"无效的填充长度: {pad_len} (应在1到32之间)")
    decrypted_data = decrypted_data[:-pad_len]
    logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))

    base64_data = base64.b64encode(decrypted_data).decode("utf-8")
    logger.info("图片已转换为base64编码,编码后长度: %d", len(base64_data))
    return True, base64_data
@@ -1,6 +1,5 @@
1
1
  import asyncio
2
2
  import traceback
3
- from typing import List
4
3
 
5
4
  from astrbot.core import logger, sp
6
5
  from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -28,7 +27,7 @@ class ProviderManager:
28
27
  self.persona_mgr = persona_mgr
29
28
  self.acm = acm
30
29
  config = acm.confs["default"]
31
- self.providers_config: List = config["provider"]
30
+ self.providers_config: list = config["provider"]
32
31
  self.provider_settings: dict = config["provider_settings"]
33
32
  self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
34
33
  self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
@@ -36,15 +35,15 @@ class ProviderManager:
36
35
  # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
37
36
  self.default_persona_name = persona_mgr.default_persona
38
37
 
39
- self.provider_insts: List[Provider] = []
38
+ self.provider_insts: list[Provider] = []
40
39
  """加载的 Provider 的实例"""
41
- self.stt_provider_insts: List[STTProvider] = []
40
+ self.stt_provider_insts: list[STTProvider] = []
42
41
  """加载的 Speech To Text Provider 的实例"""
43
- self.tts_provider_insts: List[TTSProvider] = []
42
+ self.tts_provider_insts: list[TTSProvider] = []
44
43
  """加载的 Text To Speech Provider 的实例"""
45
- self.embedding_provider_insts: List[EmbeddingProvider] = []
44
+ self.embedding_provider_insts: list[EmbeddingProvider] = []
46
45
  """加载的 Embedding Provider 的实例"""
47
- self.rerank_provider_insts: List[RerankProvider] = []
46
+ self.rerank_provider_insts: list[RerankProvider] = []
48
47
  """加载的 Rerank Provider 的实例"""
49
48
  self.inst_map: dict[
50
49
  str,
@@ -175,7 +174,11 @@ class ProviderManager:
175
174
  async def initialize(self):
176
175
  # 逐个初始化提供商
177
176
  for provider_config in self.providers_config:
178
- await self.load_provider(provider_config)
177
+ try:
178
+ await self.load_provider(provider_config)
179
+ except Exception as e:
180
+ logger.error(traceback.format_exc())
181
+ logger.error(e)
179
182
 
180
183
  # 设置默认提供商
181
184
  selected_provider_id = sp.get(
@@ -404,10 +407,12 @@ class ProviderManager:
404
407
 
405
408
  self.inst_map[provider_config["id"]] = inst
406
409
  except Exception as e:
407
- logger.error(traceback.format_exc())
408
410
  logger.error(
409
411
  f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
410
412
  )
413
+ raise Exception(
414
+ f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
415
+ )
411
416
 
412
417
  async def reload(self, provider_config: dict):
413
418
  await self.terminate_provider(provider_config["id"])
@@ -1,4 +1,5 @@
1
1
  import abc
2
+ import asyncio
2
3
  from typing import List
3
4
  from typing import AsyncGenerator
4
5
  from astrbot.core.agent.tool import ToolSet
@@ -203,6 +204,72 @@ class EmbeddingProvider(AbstractProvider):
203
204
  """获取向量的维度"""
204
205
  ...
205
206
 
207
async def get_embeddings_batch(
    self,
    texts: list[str],
    batch_size: int = 16,
    tasks_limit: int = 3,
    max_retries: int = 3,
    progress_callback=None,
) -> list[list[float]]:
    """Embed ``texts`` in batches with bounded concurrency and retries.

    The result is aligned with the input: ``result[i]`` is the vector for
    ``texts[i]``, regardless of the order in which concurrent batches finish.

    Args:
        texts: texts to embed.
        batch_size: number of texts per ``get_embeddings`` call.
        tasks_limit: maximum number of batches in flight (values < 1 act as 1).
        max_retries: attempts per batch (values < 1 act as 1); retries back
            off exponentially (1s, 2s, 4s, ...).
        progress_callback: optional async callable awaited with
            ``(completed_count, total_count)`` after each successful batch.

    Returns:
        One vector per input text, in input order.

    Raises:
        Exception: if any batch still fails after all attempts (or the
            progress callback raises); the message aggregates every failure.
    """
    # Semaphore(0) would deadlock and 0 attempts would silently drop batches.
    semaphore = asyncio.Semaphore(max(1, tasks_limit))
    attempts = max(1, max_retries)
    total_count = len(texts)
    completed_count = 0

    batch_starts = range(0, total_count, batch_size)
    # One slot per batch, filled by index so output order matches input order.
    batch_results: list = [None] * len(batch_starts)

    async def embed_with_retry(batch_idx: int, batch_texts: list[str]):
        """Call get_embeddings with retries; raise once attempts run out."""
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                vectors = await self.get_embeddings(batch_texts)
            except Exception as e:
                last_error = e
            else:
                # A short/long reply would shift every later vector.
                if len(vectors) == len(batch_texts):
                    return vectors
                last_error = ValueError(
                    f"期望 {len(batch_texts)} 个向量, 实际返回 {len(vectors)} 个"
                )
            if attempt < attempts - 1:
                await asyncio.sleep(2**attempt)
        raise Exception(
            f"批次 {batch_idx} 处理失败,已重试 {attempts} 次: {str(last_error)}"
        ) from last_error

    async def process_batch(batch_idx: int, batch_texts: list[str]) -> None:
        """Embed one batch under the semaphore and record it in its slot."""
        nonlocal completed_count
        async with semaphore:
            batch_results[batch_idx] = await embed_with_retry(batch_idx, batch_texts)
            # Bookkeeping stays outside the retry loop so a failing callback
            # cannot trigger re-embedding or double counting.
            completed_count += len(batch_texts)
            if progress_callback:
                await progress_callback(completed_count, total_count)

    tasks = [
        process_batch(batch_idx, texts[start : start + batch_size])
        for batch_idx, start in enumerate(batch_starts)
    ]

    # Let every batch finish, then report all failures together.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    errors = [r for r in results if isinstance(r, Exception)]
    if errors:
        error_msg = (
            f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
        )
        raise Exception(error_msg)

    return [vector for batch in batch_results for vector in batch]
272
+
206
273
 
207
274
  class RerankProvider(AbstractProvider):
208
275
  def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -10,7 +10,7 @@ from anthropic.types import Message
10
10
  from astrbot.core.utils.io import download_image_by_url
11
11
  from astrbot.api.provider import Provider
12
12
  from astrbot import logger
13
- from astrbot.core.provider.func_tool_manager import FuncCall
13
+ from astrbot.core.provider.func_tool_manager import ToolSet
14
14
  from ..register import register_provider_adapter
15
15
  from astrbot.core.provider.entities import LLMResponse
16
16
  from typing import AsyncGenerator
@@ -104,7 +104,7 @@ class ProviderAnthropic(Provider):
104
104
 
105
105
  return system_prompt, new_messages
106
106
 
107
- async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
107
+ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
108
108
  if tools:
109
109
  if tool_list := tools.get_func_desc_anthropic_style():
110
110
  payloads["tools"] = tool_list
@@ -135,7 +135,7 @@ class ProviderAnthropic(Provider):
135
135
  return llm_response
136
136
 
137
137
  async def _query_stream(
138
- self, payloads: dict, tools: FuncCall
138
+ self, payloads: dict, tools: ToolSet | None
139
139
  ) -> AsyncGenerator[LLMResponse, None]:
140
140
  if tools:
141
141
  if tool_list := tools.get_func_desc_anthropic_style():
@@ -326,7 +326,7 @@ class ProviderAnthropic(Provider):
326
326
  async for llm_response in self._query_stream(payloads, func_tool):
327
327
  yield llm_response
328
328
 
329
- async def assemble_context(self, text: str, image_urls: List[str] = None):
329
+ async def assemble_context(self, text: str, image_urls: List[str] | None = None):
330
330
  """组装上下文,支持文本和图片"""
331
331
  if not image_urls:
332
332
  return {"role": "user", "content": text}
@@ -1,15 +1,14 @@
1
1
  import re
2
2
  import asyncio
3
3
  import functools
4
- from typing import List
5
4
  from .. import Provider, Personality
6
5
  from ..entities import LLMResponse
7
- from ..func_tool_manager import FuncCall
8
6
  from ..register import register_provider_adapter
9
7
  from astrbot.core.message.message_event_result import MessageChain
10
8
  from .openai_source import ProviderOpenAIOfficial
11
9
  from astrbot.core import logger, sp
12
10
  from dashscope import Application
11
+ from dashscope.app.application_response import ApplicationResponse
13
12
 
14
13
 
15
14
  @register_provider_adapter("dashscope", "Dashscope APP 适配器。")
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
62
61
  async def text_chat(
63
62
  self,
64
63
  prompt: str,
65
- session_id: str = None,
66
- image_urls: List[str] = [],
67
- func_tool: FuncCall = None,
68
- contexts: List = None,
69
- system_prompt: str = None,
64
+ session_id=None,
65
+ image_urls=[],
66
+ func_tool=None,
67
+ contexts=None,
68
+ system_prompt=None,
70
69
  model=None,
71
70
  **kwargs,
72
71
  ) -> LLMResponse:
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
122
121
  )
123
122
  response = await asyncio.get_event_loop().run_in_executor(None, partial)
124
123
 
124
+ assert isinstance(response, ApplicationResponse)
125
+
125
126
  logger.debug(f"dashscope resp: {response}")
126
127
 
127
128
  if response.status_code != 200:
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
135
136
  ),
136
137
  )
137
138
 
138
- output_text = response.output.get("text", "")
139
+ output_text = response.output.get("text", "") or ""
139
140
  # RAG 引用脚标格式化
140
141
  output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
141
142
  if self.output_reference and response.output.get("doc_references", None):
142
143
  ref_str = ""
143
- for ref in response.output.get("doc_references", []):
144
+ for ref in response.output.get("doc_references", []) or []:
144
145
  ref_title = (
145
146
  ref.get("title", "")
146
147
  if ref.get("title")
@@ -1,9 +1,7 @@
1
1
  import astrbot.core.message.components as Comp
2
2
  import os
3
- from typing import List
4
3
  from .. import Provider
5
4
  from ..entities import LLMResponse
6
- from ..func_tool_manager import FuncCall
7
5
  from ..register import register_provider_adapter
8
6
  from astrbot.core.utils.dify_api_client import DifyAPIClient
9
7
  from astrbot.core.utils.io import download_image_by_url, download_file
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
55
53
  async def text_chat(
56
54
  self,
57
55
  prompt: str,
58
- session_id: str = None,
59
- image_urls: List[str] = None,
60
- func_tool: FuncCall = None,
61
- contexts: List = None,
62
- system_prompt: str = None,
56
+ session_id=None,
57
+ image_urls=None,
58
+ func_tool=None,
59
+ contexts=None,
60
+ system_prompt=None,
63
61
  tool_calls_result=None,
64
62
  model=None,
65
63
  **kwargs,
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
223
221
  # Chat
224
222
  return MessageChain(chain=[Comp.Plain(chunk)])
225
223
 
226
- async def parse_file(item: dict) -> Comp:
224
+ async def parse_file(item: dict):
227
225
  match item["type"]:
228
226
  case "image":
229
227
  return Comp.Image(file=item["url"], url=item["url"])
@@ -32,7 +32,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
32
32
  self.model = provider_config.get(
33
33
  "embedding_model", "gemini-embedding-exp-03-07"
34
34
  )
35
- self.dimension = provider_config.get("embedding_dimensions", 768)
36
35
 
37
36
  async def get_embedding(self, text: str) -> list[float]:
38
37
  """
@@ -60,4 +59,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
60
59
 
61
60
  def get_dim(self) -> int:
62
61
  """获取向量的维度"""
63
- return self.dimension
62
+ return self.provider_config.get("embedding_dimensions", 768)
@@ -22,7 +22,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
22
22
  timeout=int(provider_config.get("timeout", 20)),
23
23
  )
24
24
  self.model = provider_config.get("embedding_model", "text-embedding-3-small")
25
- self.dimension = provider_config.get("embedding_dimensions", 1024)
26
25
 
27
26
  async def get_embedding(self, text: str) -> list[float]:
28
27
  """
@@ -40,4 +39,4 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
40
39
 
41
40
  def get_dim(self) -> int:
42
41
  """获取向量的维度"""
43
- return self.dimension
42
+ return self.provider_config.get("embedding_dimensions", 1024)