AstrBot 4.3.5__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  2. astrbot/core/astrbot_config_mgr.py +23 -51
  3. astrbot/core/config/default.py +92 -12
  4. astrbot/core/conversation_mgr.py +36 -1
  5. astrbot/core/core_lifecycle.py +24 -5
  6. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  7. astrbot/core/db/vec_db/base.py +33 -2
  8. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  9. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  10. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  11. astrbot/core/file_token_service.py +6 -1
  12. astrbot/core/initial_loader.py +6 -3
  13. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  14. astrbot/core/knowledge_base/chunking/base.py +24 -0
  15. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  16. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  17. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  18. astrbot/core/knowledge_base/kb_helper.py +348 -0
  19. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  20. astrbot/core/knowledge_base/models.py +114 -0
  21. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  22. astrbot/core/knowledge_base/parsers/base.py +50 -0
  23. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  24. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  25. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  26. astrbot/core/knowledge_base/parsers/util.py +13 -0
  27. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  28. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  29. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  30. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  31. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  32. astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
  33. astrbot/core/pipeline/process_stage/utils.py +80 -0
  34. astrbot/core/platform/astr_message_event.py +8 -7
  35. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  36. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  37. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  38. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  39. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  40. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  41. astrbot/core/platform/sources/satori/satori_event.py +270 -99
  42. astrbot/core/provider/manager.py +14 -9
  43. astrbot/core/provider/provider.py +67 -0
  44. astrbot/core/provider/sources/anthropic_source.py +4 -4
  45. astrbot/core/provider/sources/dashscope_source.py +10 -9
  46. astrbot/core/provider/sources/dify_source.py +6 -8
  47. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  48. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  49. astrbot/core/provider/sources/openai_source.py +18 -15
  50. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  51. astrbot/core/star/context.py +3 -0
  52. astrbot/core/star/star.py +6 -0
  53. astrbot/core/star/star_manager.py +13 -7
  54. astrbot/core/umop_config_router.py +81 -0
  55. astrbot/core/updator.py +1 -1
  56. astrbot/core/utils/io.py +23 -12
  57. astrbot/dashboard/routes/__init__.py +2 -0
  58. astrbot/dashboard/routes/config.py +137 -9
  59. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  60. astrbot/dashboard/routes/plugin.py +24 -5
  61. astrbot/dashboard/routes/update.py +1 -1
  62. astrbot/dashboard/server.py +6 -0
  63. astrbot/dashboard/utils.py +161 -0
  64. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/METADATA +29 -13
  65. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/RECORD +68 -44
  66. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
  67. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
  68. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,18 @@ from typing import TYPE_CHECKING
2
2
  from astrbot.api import logger
3
3
  from astrbot.api.event import AstrMessageEvent, MessageChain
4
4
  from astrbot.api.platform import AstrBotMessage, PlatformMetadata
5
- from astrbot.api.message_components import Plain, Image, At, File, Record, Video, Reply
5
+ from astrbot.api.message_components import (
6
+ Plain,
7
+ Image,
8
+ At,
9
+ File,
10
+ Record,
11
+ Video,
12
+ Reply,
13
+ Forward,
14
+ Node,
15
+ Nodes,
16
+ )
6
17
 
7
18
  if TYPE_CHECKING:
8
19
  from .satori_adapter import SatoriPlatformAdapter
@@ -48,55 +59,24 @@ class SatoriPlatformEvent(AstrMessageEvent):
48
59
  content_parts = []
49
60
 
50
61
  for component in message.chain:
51
- if isinstance(component, Plain):
52
- text = (
53
- component.text.replace("&", "&")
54
- .replace("<", "&lt;")
55
- .replace(">", "&gt;")
56
- )
57
- content_parts.append(text)
58
-
59
- elif isinstance(component, At):
60
- if component.qq:
61
- content_parts.append(f'<at id="{component.qq}"/>')
62
- elif component.name:
63
- content_parts.append(f'<at name="{component.name}"/>')
64
-
65
- elif isinstance(component, Image):
66
- try:
67
- image_base64 = await component.convert_to_base64()
68
- if image_base64:
69
- content_parts.append(
70
- f'<img src="data:image/jpeg;base64,{image_base64}"/>'
71
- )
72
- except Exception as e:
73
- logger.error(f"图片转换为base64失败: {e}")
74
-
75
- elif isinstance(component, File):
76
- content_parts.append(
77
- f'<file src="{component.file}" name="{component.name or "文件"}"/>'
78
- )
79
-
80
- elif isinstance(component, Record):
81
- try:
82
- record_base64 = await component.convert_to_base64()
83
- if record_base64:
84
- content_parts.append(
85
- f'<audio src="data:audio/wav;base64,{record_base64}"/>'
86
- )
87
- except Exception as e:
88
- logger.error(f"语音转换为base64失败: {e}")
89
-
90
- elif isinstance(component, Reply):
91
- content_parts.append(f'<reply id="{component.id}"/>')
92
-
93
- elif isinstance(component, Video):
94
- try:
95
- video_path_url = await component.convert_to_file_path()
96
- if video_path_url:
97
- content_parts.append(f'<video src="{video_path_url}"/>')
98
- except Exception as e:
99
- logger.error(f"视频文件转换失败: {e}")
62
+ component_content = await cls._convert_component_to_satori_static(
63
+ component
64
+ )
65
+ if component_content:
66
+ content_parts.append(component_content)
67
+
68
+ # 特殊处理 Node 和 Nodes 组件
69
+ if isinstance(component, Node):
70
+ # 单个转发节点
71
+ node_content = await cls._convert_node_to_satori_static(component)
72
+ if node_content:
73
+ content_parts.append(node_content)
74
+
75
+ elif isinstance(component, Nodes):
76
+ # 合并转发消息
77
+ node_content = await cls._convert_nodes_to_satori_static(component)
78
+ if node_content:
79
+ content_parts.append(node_content)
100
80
 
101
81
  content = "".join(content_parts)
102
82
  channel_id = session_id
@@ -138,55 +118,22 @@ class SatoriPlatformEvent(AstrMessageEvent):
138
118
  content_parts = []
139
119
 
140
120
  for component in message.chain:
141
- if isinstance(component, Plain):
142
- text = (
143
- component.text.replace("&", "&amp;")
144
- .replace("<", "&lt;")
145
- .replace(">", "&gt;")
146
- )
147
- content_parts.append(text)
148
-
149
- elif isinstance(component, At):
150
- if component.qq:
151
- content_parts.append(f'<at id="{component.qq}"/>')
152
- elif component.name:
153
- content_parts.append(f'<at name="{component.name}"/>')
154
-
155
- elif isinstance(component, Image):
156
- try:
157
- image_base64 = await component.convert_to_base64()
158
- if image_base64:
159
- content_parts.append(
160
- f'<img src="data:image/jpeg;base64,{image_base64}"/>'
161
- )
162
- except Exception as e:
163
- logger.error(f"图片转换为base64失败: {e}")
164
-
165
- elif isinstance(component, File):
166
- content_parts.append(
167
- f'<file src="{component.file}" name="{component.name or "文件"}"/>'
168
- )
169
-
170
- elif isinstance(component, Record):
171
- try:
172
- record_base64 = await component.convert_to_base64()
173
- if record_base64:
174
- content_parts.append(
175
- f'<audio src="data:audio/wav;base64,{record_base64}"/>'
176
- )
177
- except Exception as e:
178
- logger.error(f"语音转换为base64失败: {e}")
179
-
180
- elif isinstance(component, Reply):
181
- content_parts.append(f'<reply id="{component.id}"/>')
182
-
183
- elif isinstance(component, Video):
184
- try:
185
- video_path_url = await component.convert_to_file_path()
186
- if video_path_url:
187
- content_parts.append(f'<video src="{video_path_url}"/>')
188
- except Exception as e:
189
- logger.error(f"视频文件转换失败: {e}")
121
+ component_content = await self._convert_component_to_satori(component)
122
+ if component_content:
123
+ content_parts.append(component_content)
124
+
125
+ # 特殊处理 Node 和 Nodes 组件
126
+ if isinstance(component, Node):
127
+ # 单个转发节点
128
+ node_content = await self._convert_node_to_satori(component)
129
+ if node_content:
130
+ content_parts.append(node_content)
131
+
132
+ elif isinstance(component, Nodes):
133
+ # 合并转发消息
134
+ node_content = await self._convert_nodes_to_satori(component)
135
+ if node_content:
136
+ content_parts.append(node_content)
190
137
 
191
138
  content = "".join(content_parts)
192
139
  channel_id = self.session_id
@@ -250,3 +197,227 @@ class SatoriPlatformEvent(AstrMessageEvent):
250
197
  logger.error(f"Satori 流式消息发送异常: {e}")
251
198
 
252
199
  return await super().send_streaming(generator, use_fallback)
200
+
201
async def _convert_component_to_satori(self, component) -> str:
    """Convert a single message component into a Satori element string.

    Handled component types and their output:
      * ``Plain``   -> the text with ``&``, ``<`` and ``>`` escaped
      * ``At``      -> ``<at id=.../>`` (falls back to ``<at name=.../>``)
      * ``Image``   -> ``<img src="data:image/jpeg;base64,..."/>``
      * ``File``    -> ``<file src=... name=.../>``
      * ``Record``  -> ``<audio src="data:audio/wav;base64,..."/>``
      * ``Reply``   -> ``<reply id=.../>``
      * ``Video``   -> ``<video src=.../>``
      * ``Forward`` -> ``<message id=... forward/>``

    Args:
        component: A message-chain component instance.

    Returns:
        The Satori markup for the component, or ``""`` when the component
        type is not handled, required data is missing, or conversion fails
        (failures are logged, never raised).
    """

    def escape_attr(value) -> str:
        # Attribute values are emitted inside double quotes, so '"' has to be
        # escaped in addition to the characters escaped for plain text;
        # otherwise a quote in a user name or path terminates the attribute.
        return (
            str(value)
            .replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
        )

    try:
        if isinstance(component, Plain):
            # "&" must be replaced first so the entities added below are not
            # escaped a second time.
            text = (
                component.text.replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;")
            )
            return text

        elif isinstance(component, At):
            if component.qq:
                return f'<at id="{escape_attr(component.qq)}"/>'
            elif component.name:
                return f'<at name="{escape_attr(component.name)}"/>'

        elif isinstance(component, Image):
            try:
                image_base64 = await component.convert_to_base64()
                if image_base64:
                    return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
            except Exception as e:
                logger.error(f"图片转换为base64失败: {e}")

        elif isinstance(component, File):
            file_src = escape_attr(component.file)
            file_name = escape_attr(component.name or "文件")
            return f'<file src="{file_src}" name="{file_name}"/>'

        elif isinstance(component, Record):
            try:
                record_base64 = await component.convert_to_base64()
                if record_base64:
                    return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
            except Exception as e:
                logger.error(f"语音转换为base64失败: {e}")

        elif isinstance(component, Reply):
            return f'<reply id="{escape_attr(component.id)}"/>'

        elif isinstance(component, Video):
            try:
                video_path_url = await component.convert_to_file_path()
                if video_path_url:
                    return f'<video src="{escape_attr(video_path_url)}"/>'
            except Exception as e:
                logger.error(f"视频文件转换失败: {e}")

        elif isinstance(component, Forward):
            return f'<message id="{escape_attr(component.id)}" forward/>'

        # Unhandled component types, and handled ones that produced nothing,
        # contribute no markup.
        return ""

    except Exception as e:
        logger.error(f"转换消息组件失败: {e}")
        return ""
259
+
260
+ async def _convert_node_to_satori(self, node: Node) -> str:
261
+ """将单个转发节点转换为 Satori 格式"""
262
+ try:
263
+ content_parts = []
264
+ if node.content:
265
+ for content_component in node.content:
266
+ component_content = await self._convert_component_to_satori(
267
+ content_component
268
+ )
269
+ if component_content:
270
+ content_parts.append(component_content)
271
+
272
+ content = "".join(content_parts)
273
+
274
+ # 如果内容为空,添加默认内容
275
+ if not content.strip():
276
+ content = "[转发消息]"
277
+
278
+ # 构建 Satori 格式的转发节点
279
+ author_attrs = []
280
+ if node.uin:
281
+ author_attrs.append(f'id="{node.uin}"')
282
+ if node.name:
283
+ author_attrs.append(f'name="{node.name}"')
284
+
285
+ author_attr_str = " ".join(author_attrs)
286
+
287
+ return f"<message><author {author_attr_str}/>{content}</message>"
288
+
289
+ except Exception as e:
290
+ logger.error(f"转换转发节点失败: {e}")
291
+ return ""
292
+
293
@classmethod
async def _convert_component_to_satori_static(cls, component) -> str:
    """Convert a single message component into a Satori element string.

    Class-level counterpart of ``_convert_component_to_satori`` for callers
    that have no event instance.

    Handled component types and their output:
      * ``Plain``   -> the text with ``&``, ``<`` and ``>`` escaped
      * ``At``      -> ``<at id=.../>`` (falls back to ``<at name=.../>``)
      * ``Image``   -> ``<img src="data:image/jpeg;base64,..."/>``
      * ``File``    -> ``<file src=... name=.../>``
      * ``Record``  -> ``<audio src="data:audio/wav;base64,..."/>``
      * ``Reply``   -> ``<reply id=.../>``
      * ``Video``   -> ``<video src=.../>``
      * ``Forward`` -> ``<message id=... forward/>``

    Args:
        component: A message-chain component instance.

    Returns:
        The Satori markup for the component, or ``""`` when the component
        type is not handled, required data is missing, or conversion fails
        (failures are logged, never raised).
    """

    def escape_attr(value) -> str:
        # Attribute values are emitted inside double quotes, so '"' has to be
        # escaped in addition to the characters escaped for plain text;
        # otherwise a quote in a user name or path terminates the attribute.
        return (
            str(value)
            .replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
        )

    try:
        if isinstance(component, Plain):
            # "&" must be replaced first so the entities added below are not
            # escaped a second time.
            text = (
                component.text.replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;")
            )
            return text

        elif isinstance(component, At):
            if component.qq:
                return f'<at id="{escape_attr(component.qq)}"/>'
            elif component.name:
                return f'<at name="{escape_attr(component.name)}"/>'

        elif isinstance(component, Image):
            try:
                image_base64 = await component.convert_to_base64()
                if image_base64:
                    return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
            except Exception as e:
                logger.error(f"图片转换为base64失败: {e}")

        elif isinstance(component, File):
            file_src = escape_attr(component.file)
            file_name = escape_attr(component.name or "文件")
            return f'<file src="{file_src}" name="{file_name}"/>'

        elif isinstance(component, Record):
            try:
                record_base64 = await component.convert_to_base64()
                if record_base64:
                    return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
            except Exception as e:
                logger.error(f"语音转换为base64失败: {e}")

        elif isinstance(component, Reply):
            return f'<reply id="{escape_attr(component.id)}"/>'

        elif isinstance(component, Video):
            try:
                video_path_url = await component.convert_to_file_path()
                if video_path_url:
                    return f'<video src="{escape_attr(video_path_url)}"/>'
            except Exception as e:
                logger.error(f"视频文件转换失败: {e}")

        elif isinstance(component, Forward):
            return f'<message id="{escape_attr(component.id)}" forward/>'

        # Unhandled component types, and handled ones that produced nothing,
        # contribute no markup.
        return ""

    except Exception as e:
        logger.error(f"转换消息组件失败: {e}")
        return ""
352
+
353
+ @classmethod
354
+ async def _convert_node_to_satori_static(cls, node: Node) -> str:
355
+ """将单个转发节点转换为 Satori 格式"""
356
+ try:
357
+ content_parts = []
358
+ if node.content:
359
+ for content_component in node.content:
360
+ component_content = await cls._convert_component_to_satori_static(
361
+ content_component
362
+ )
363
+ if component_content:
364
+ content_parts.append(component_content)
365
+
366
+ content = "".join(content_parts)
367
+
368
+ # 如果内容为空,添加默认内容
369
+ if not content.strip():
370
+ content = "[转发消息]"
371
+
372
+ author_attrs = []
373
+ if node.uin:
374
+ author_attrs.append(f'id="{node.uin}"')
375
+ if node.name:
376
+ author_attrs.append(f'name="{node.name}"')
377
+
378
+ author_attr_str = " ".join(author_attrs)
379
+
380
+ return f"<message><author {author_attr_str}/>{content}</message>"
381
+
382
+ except Exception as e:
383
+ logger.error(f"转换转发节点失败: {e}")
384
+ return ""
385
+
386
+ async def _convert_nodes_to_satori(self, nodes: Nodes) -> str:
387
+ """将多个转发节点转换为 Satori 格式的合并转发"""
388
+ try:
389
+ node_parts = []
390
+
391
+ for node in nodes.nodes:
392
+ node_content = await self._convert_node_to_satori(node)
393
+ if node_content:
394
+ node_parts.append(node_content)
395
+
396
+ if node_parts:
397
+ return f"<message forward>{''.join(node_parts)}</message>"
398
+ else:
399
+ return ""
400
+
401
+ except Exception as e:
402
+ logger.error(f"转换合并转发消息失败: {e}")
403
+ return ""
404
+
405
+ @classmethod
406
+ async def _convert_nodes_to_satori_static(cls, nodes: Nodes) -> str:
407
+ """将多个转发节点转换为 Satori 格式的合并转发"""
408
+ try:
409
+ node_parts = []
410
+
411
+ for node in nodes.nodes:
412
+ node_content = await cls._convert_node_to_satori_static(node)
413
+ if node_content:
414
+ node_parts.append(node_content)
415
+
416
+ if node_parts:
417
+ return f"<message forward>{''.join(node_parts)}</message>"
418
+ else:
419
+ return ""
420
+
421
+ except Exception as e:
422
+ logger.error(f"转换合并转发消息失败: {e}")
423
+ return ""
@@ -1,6 +1,5 @@
1
1
  import asyncio
2
2
  import traceback
3
- from typing import List
4
3
 
5
4
  from astrbot.core import logger, sp
6
5
  from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -28,7 +27,7 @@ class ProviderManager:
28
27
  self.persona_mgr = persona_mgr
29
28
  self.acm = acm
30
29
  config = acm.confs["default"]
31
- self.providers_config: List = config["provider"]
30
+ self.providers_config: list = config["provider"]
32
31
  self.provider_settings: dict = config["provider_settings"]
33
32
  self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
34
33
  self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
@@ -36,15 +35,15 @@ class ProviderManager:
36
35
  # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
37
36
  self.default_persona_name = persona_mgr.default_persona
38
37
 
39
- self.provider_insts: List[Provider] = []
38
+ self.provider_insts: list[Provider] = []
40
39
  """加载的 Provider 的实例"""
41
- self.stt_provider_insts: List[STTProvider] = []
40
+ self.stt_provider_insts: list[STTProvider] = []
42
41
  """加载的 Speech To Text Provider 的实例"""
43
- self.tts_provider_insts: List[TTSProvider] = []
42
+ self.tts_provider_insts: list[TTSProvider] = []
44
43
  """加载的 Text To Speech Provider 的实例"""
45
- self.embedding_provider_insts: List[EmbeddingProvider] = []
44
+ self.embedding_provider_insts: list[EmbeddingProvider] = []
46
45
  """加载的 Embedding Provider 的实例"""
47
- self.rerank_provider_insts: List[RerankProvider] = []
46
+ self.rerank_provider_insts: list[RerankProvider] = []
48
47
  """加载的 Rerank Provider 的实例"""
49
48
  self.inst_map: dict[
50
49
  str,
@@ -175,7 +174,11 @@ class ProviderManager:
175
174
  async def initialize(self):
176
175
  # 逐个初始化提供商
177
176
  for provider_config in self.providers_config:
178
- await self.load_provider(provider_config)
177
+ try:
178
+ await self.load_provider(provider_config)
179
+ except Exception as e:
180
+ logger.error(traceback.format_exc())
181
+ logger.error(e)
179
182
 
180
183
  # 设置默认提供商
181
184
  selected_provider_id = sp.get(
@@ -404,10 +407,12 @@ class ProviderManager:
404
407
 
405
408
  self.inst_map[provider_config["id"]] = inst
406
409
  except Exception as e:
407
- logger.error(traceback.format_exc())
408
410
  logger.error(
409
411
  f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
410
412
  )
413
+ raise Exception(
414
+ f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
415
+ )
411
416
 
412
417
  async def reload(self, provider_config: dict):
413
418
  await self.terminate_provider(provider_config["id"])
@@ -1,4 +1,5 @@
1
1
  import abc
2
+ import asyncio
2
3
  from typing import List
3
4
  from typing import AsyncGenerator
4
5
  from astrbot.core.agent.tool import ToolSet
@@ -203,6 +204,72 @@ class EmbeddingProvider(AbstractProvider):
203
204
  """获取向量的维度"""
204
205
  ...
205
206
 
207
async def get_embeddings_batch(
    self,
    texts: list[str],
    batch_size: int = 16,
    tasks_limit: int = 3,
    max_retries: int = 3,
    progress_callback=None,
) -> list[list[float]]:
    """Embed ``texts`` in concurrent batches, returning vectors in input order.

    ``texts`` is cut into consecutive slices of ``batch_size`` items and each
    slice is passed to ``self.get_embeddings``. Up to ``tasks_limit`` slices
    are processed at the same time. A failing slice is retried with
    exponential backoff (1s, 2s, 4s, ...).

    Args:
        texts: Texts to embed.
        batch_size: Number of texts per ``get_embeddings`` call.
        tasks_limit: Maximum number of batches in flight at once.
        max_retries: Attempts per batch before it counts as failed.
        progress_callback: Optional async callable awaited as
            ``progress_callback(completed, total)`` after each batch succeeds.

    Returns:
        The per-batch results concatenated in batch order, so the output
        lines up with ``texts`` regardless of which batch finished first.

    Raises:
        Exception: If at least one batch still fails after ``max_retries``
            attempts, or the progress callback raises. The message
            aggregates every individual failure.
    """
    semaphore = asyncio.Semaphore(tasks_limit)
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    # One slot per batch: concurrent batches complete in arbitrary order, so
    # results are stored by index rather than appended to a shared list.
    results_by_batch: list[list[list[float]]] = [[] for _ in batches]
    completed_count = 0
    total_count = len(texts)

    async def process_batch(batch_idx: int, batch_texts: list[str]):
        nonlocal completed_count
        async with semaphore:
            for attempt in range(max_retries):
                try:
                    batch_embeddings = await self.get_embeddings(batch_texts)
                except Exception as e:
                    if attempt == max_retries - 1:
                        # Final attempt failed: surface the batch as an error.
                        raise Exception(
                            f"批次 {batch_idx} 处理失败，已重试 {max_retries} 次: {str(e)}"
                        ) from e
                    # Exponential backoff before the next attempt.
                    await asyncio.sleep(2**attempt)
                else:
                    results_by_batch[batch_idx] = batch_embeddings
                    completed_count += len(batch_texts)
                    # Reported outside the retry scope so that a callback
                    # error cannot re-embed the batch or double-count progress.
                    if progress_callback:
                        await progress_callback(completed_count, total_count)
                    return

    tasks = [
        process_batch(batch_idx, batch_texts)
        for batch_idx, batch_texts in enumerate(batches)
    ]

    # Let every batch run to completion, collecting failures instead of
    # aborting on the first one.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    errors = [r for r in results if isinstance(r, Exception)]
    if errors:
        error_msg = (
            f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
        )
        raise Exception(error_msg)

    return [vector for batch in results_by_batch for vector in batch]
272
+
206
273
 
207
274
  class RerankProvider(AbstractProvider):
208
275
  def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -10,7 +10,7 @@ from anthropic.types import Message
10
10
  from astrbot.core.utils.io import download_image_by_url
11
11
  from astrbot.api.provider import Provider
12
12
  from astrbot import logger
13
- from astrbot.core.provider.func_tool_manager import FuncCall
13
+ from astrbot.core.provider.func_tool_manager import ToolSet
14
14
  from ..register import register_provider_adapter
15
15
  from astrbot.core.provider.entities import LLMResponse
16
16
  from typing import AsyncGenerator
@@ -104,7 +104,7 @@ class ProviderAnthropic(Provider):
104
104
 
105
105
  return system_prompt, new_messages
106
106
 
107
- async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
107
+ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
108
108
  if tools:
109
109
  if tool_list := tools.get_func_desc_anthropic_style():
110
110
  payloads["tools"] = tool_list
@@ -135,7 +135,7 @@ class ProviderAnthropic(Provider):
135
135
  return llm_response
136
136
 
137
137
  async def _query_stream(
138
- self, payloads: dict, tools: FuncCall
138
+ self, payloads: dict, tools: ToolSet | None
139
139
  ) -> AsyncGenerator[LLMResponse, None]:
140
140
  if tools:
141
141
  if tool_list := tools.get_func_desc_anthropic_style():
@@ -326,7 +326,7 @@ class ProviderAnthropic(Provider):
326
326
  async for llm_response in self._query_stream(payloads, func_tool):
327
327
  yield llm_response
328
328
 
329
- async def assemble_context(self, text: str, image_urls: List[str] = None):
329
+ async def assemble_context(self, text: str, image_urls: List[str] | None = None):
330
330
  """组装上下文,支持文本和图片"""
331
331
  if not image_urls:
332
332
  return {"role": "user", "content": text}
@@ -1,15 +1,14 @@
1
1
  import re
2
2
  import asyncio
3
3
  import functools
4
- from typing import List
5
4
  from .. import Provider, Personality
6
5
  from ..entities import LLMResponse
7
- from ..func_tool_manager import FuncCall
8
6
  from ..register import register_provider_adapter
9
7
  from astrbot.core.message.message_event_result import MessageChain
10
8
  from .openai_source import ProviderOpenAIOfficial
11
9
  from astrbot.core import logger, sp
12
10
  from dashscope import Application
11
+ from dashscope.app.application_response import ApplicationResponse
13
12
 
14
13
 
15
14
  @register_provider_adapter("dashscope", "Dashscope APP 适配器。")
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
62
61
  async def text_chat(
63
62
  self,
64
63
  prompt: str,
65
- session_id: str = None,
66
- image_urls: List[str] = [],
67
- func_tool: FuncCall = None,
68
- contexts: List = None,
69
- system_prompt: str = None,
64
+ session_id=None,
65
+ image_urls=[],
66
+ func_tool=None,
67
+ contexts=None,
68
+ system_prompt=None,
70
69
  model=None,
71
70
  **kwargs,
72
71
  ) -> LLMResponse:
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
122
121
  )
123
122
  response = await asyncio.get_event_loop().run_in_executor(None, partial)
124
123
 
124
+ assert isinstance(response, ApplicationResponse)
125
+
125
126
  logger.debug(f"dashscope resp: {response}")
126
127
 
127
128
  if response.status_code != 200:
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
135
136
  ),
136
137
  )
137
138
 
138
- output_text = response.output.get("text", "")
139
+ output_text = response.output.get("text", "") or ""
139
140
  # RAG 引用脚标格式化
140
141
  output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
141
142
  if self.output_reference and response.output.get("doc_references", None):
142
143
  ref_str = ""
143
- for ref in response.output.get("doc_references", []):
144
+ for ref in response.output.get("doc_references", []) or []:
144
145
  ref_title = (
145
146
  ref.get("title", "")
146
147
  if ref.get("title")
@@ -1,9 +1,7 @@
1
1
  import astrbot.core.message.components as Comp
2
2
  import os
3
- from typing import List
4
3
  from .. import Provider
5
4
  from ..entities import LLMResponse
6
- from ..func_tool_manager import FuncCall
7
5
  from ..register import register_provider_adapter
8
6
  from astrbot.core.utils.dify_api_client import DifyAPIClient
9
7
  from astrbot.core.utils.io import download_image_by_url, download_file
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
55
53
  async def text_chat(
56
54
  self,
57
55
  prompt: str,
58
- session_id: str = None,
59
- image_urls: List[str] = None,
60
- func_tool: FuncCall = None,
61
- contexts: List = None,
62
- system_prompt: str = None,
56
+ session_id=None,
57
+ image_urls=None,
58
+ func_tool=None,
59
+ contexts=None,
60
+ system_prompt=None,
63
61
  tool_calls_result=None,
64
62
  model=None,
65
63
  **kwargs,
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
223
221
  # Chat
224
222
  return MessageChain(chain=[Comp.Plain(chunk)])
225
223
 
226
- async def parse_file(item: dict) -> Comp:
224
+ async def parse_file(item: dict):
227
225
  match item["type"]:
228
226
  case "image":
229
227
  return Comp.Image(file=item["url"], url=item["url"])
@@ -32,7 +32,6 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
32
32
  self.model = provider_config.get(
33
33
  "embedding_model", "gemini-embedding-exp-03-07"
34
34
  )
35
- self.dimension = provider_config.get("embedding_dimensions", 768)
36
35
 
37
36
  async def get_embedding(self, text: str) -> list[float]:
38
37
  """
@@ -60,4 +59,4 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
60
59
 
61
60
  def get_dim(self) -> int:
62
61
  """获取向量的维度"""
63
- return self.dimension
62
+ return self.provider_config.get("embedding_dimensions", 768)