AstrBot 4.3.5__py3-none-any.whl → 4.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
- astrbot/core/astrbot_config_mgr.py +23 -51
- astrbot/core/config/default.py +132 -12
- astrbot/core/conversation_mgr.py +36 -1
- astrbot/core/core_lifecycle.py +24 -5
- astrbot/core/db/migration/helper.py +6 -3
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/vec_db/base.py +33 -2
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
- astrbot/core/file_token_service.py +6 -1
- astrbot/core/initial_loader.py +6 -3
- astrbot/core/knowledge_base/chunking/__init__.py +11 -0
- astrbot/core/knowledge_base/chunking/base.py +24 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
- astrbot/core/knowledge_base/chunking/recursive.py +155 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
- astrbot/core/knowledge_base/kb_helper.py +348 -0
- astrbot/core/knowledge_base/kb_mgr.py +287 -0
- astrbot/core/knowledge_base/models.py +114 -0
- astrbot/core/knowledge_base/parsers/__init__.py +15 -0
- astrbot/core/knowledge_base/parsers/base.py +50 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +273 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
- astrbot/core/pipeline/process_stage/utils.py +80 -0
- astrbot/core/platform/astr_message_event.py +8 -7
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +5 -2
- astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
- astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
- astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
- astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
- astrbot/core/platform/sources/satori/satori_event.py +270 -99
- astrbot/core/provider/manager.py +22 -9
- astrbot/core/provider/provider.py +67 -0
- astrbot/core/provider/sources/anthropic_source.py +4 -4
- astrbot/core/provider/sources/dashscope_source.py +10 -9
- astrbot/core/provider/sources/dify_source.py +6 -8
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_source.py +43 -15
- astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
- astrbot/core/provider/sources/xinference_rerank_source.py +108 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +187 -0
- astrbot/core/star/context.py +19 -13
- astrbot/core/star/star.py +6 -0
- astrbot/core/star/star_manager.py +13 -7
- astrbot/core/umop_config_router.py +81 -0
- astrbot/core/updator.py +1 -1
- astrbot/core/utils/io.py +23 -12
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/config.py +137 -9
- astrbot/dashboard/routes/knowledge_base.py +1065 -0
- astrbot/dashboard/routes/plugin.py +24 -5
- astrbot/dashboard/routes/update.py +1 -1
- astrbot/dashboard/server.py +6 -0
- astrbot/dashboard/utils.py +161 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/METADATA +30 -13
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/RECORD +72 -46
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/WHEEL +0 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.3.5.dist-info → astrbot-4.5.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,18 @@ from typing import TYPE_CHECKING
|
|
|
2
2
|
from astrbot.api import logger
|
|
3
3
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
|
4
4
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
|
5
|
-
from astrbot.api.message_components import
|
|
5
|
+
from astrbot.api.message_components import (
|
|
6
|
+
Plain,
|
|
7
|
+
Image,
|
|
8
|
+
At,
|
|
9
|
+
File,
|
|
10
|
+
Record,
|
|
11
|
+
Video,
|
|
12
|
+
Reply,
|
|
13
|
+
Forward,
|
|
14
|
+
Node,
|
|
15
|
+
Nodes,
|
|
16
|
+
)
|
|
6
17
|
|
|
7
18
|
if TYPE_CHECKING:
|
|
8
19
|
from .satori_adapter import SatoriPlatformAdapter
|
|
@@ -48,55 +59,24 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
|
|
48
59
|
content_parts = []
|
|
49
60
|
|
|
50
61
|
for component in message.chain:
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
content_parts.append(
|
|
70
|
-
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
|
71
|
-
)
|
|
72
|
-
except Exception as e:
|
|
73
|
-
logger.error(f"图片转换为base64失败: {e}")
|
|
74
|
-
|
|
75
|
-
elif isinstance(component, File):
|
|
76
|
-
content_parts.append(
|
|
77
|
-
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
|
78
|
-
)
|
|
79
|
-
|
|
80
|
-
elif isinstance(component, Record):
|
|
81
|
-
try:
|
|
82
|
-
record_base64 = await component.convert_to_base64()
|
|
83
|
-
if record_base64:
|
|
84
|
-
content_parts.append(
|
|
85
|
-
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
|
86
|
-
)
|
|
87
|
-
except Exception as e:
|
|
88
|
-
logger.error(f"语音转换为base64失败: {e}")
|
|
89
|
-
|
|
90
|
-
elif isinstance(component, Reply):
|
|
91
|
-
content_parts.append(f'<reply id="{component.id}"/>')
|
|
92
|
-
|
|
93
|
-
elif isinstance(component, Video):
|
|
94
|
-
try:
|
|
95
|
-
video_path_url = await component.convert_to_file_path()
|
|
96
|
-
if video_path_url:
|
|
97
|
-
content_parts.append(f'<video src="{video_path_url}"/>')
|
|
98
|
-
except Exception as e:
|
|
99
|
-
logger.error(f"视频文件转换失败: {e}")
|
|
62
|
+
component_content = await cls._convert_component_to_satori_static(
|
|
63
|
+
component
|
|
64
|
+
)
|
|
65
|
+
if component_content:
|
|
66
|
+
content_parts.append(component_content)
|
|
67
|
+
|
|
68
|
+
# 特殊处理 Node 和 Nodes 组件
|
|
69
|
+
if isinstance(component, Node):
|
|
70
|
+
# 单个转发节点
|
|
71
|
+
node_content = await cls._convert_node_to_satori_static(component)
|
|
72
|
+
if node_content:
|
|
73
|
+
content_parts.append(node_content)
|
|
74
|
+
|
|
75
|
+
elif isinstance(component, Nodes):
|
|
76
|
+
# 合并转发消息
|
|
77
|
+
node_content = await cls._convert_nodes_to_satori_static(component)
|
|
78
|
+
if node_content:
|
|
79
|
+
content_parts.append(node_content)
|
|
100
80
|
|
|
101
81
|
content = "".join(content_parts)
|
|
102
82
|
channel_id = session_id
|
|
@@ -138,55 +118,22 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
|
|
138
118
|
content_parts = []
|
|
139
119
|
|
|
140
120
|
for component in message.chain:
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
image_base64 = await component.convert_to_base64()
|
|
158
|
-
if image_base64:
|
|
159
|
-
content_parts.append(
|
|
160
|
-
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
|
161
|
-
)
|
|
162
|
-
except Exception as e:
|
|
163
|
-
logger.error(f"图片转换为base64失败: {e}")
|
|
164
|
-
|
|
165
|
-
elif isinstance(component, File):
|
|
166
|
-
content_parts.append(
|
|
167
|
-
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
|
168
|
-
)
|
|
169
|
-
|
|
170
|
-
elif isinstance(component, Record):
|
|
171
|
-
try:
|
|
172
|
-
record_base64 = await component.convert_to_base64()
|
|
173
|
-
if record_base64:
|
|
174
|
-
content_parts.append(
|
|
175
|
-
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
|
176
|
-
)
|
|
177
|
-
except Exception as e:
|
|
178
|
-
logger.error(f"语音转换为base64失败: {e}")
|
|
179
|
-
|
|
180
|
-
elif isinstance(component, Reply):
|
|
181
|
-
content_parts.append(f'<reply id="{component.id}"/>')
|
|
182
|
-
|
|
183
|
-
elif isinstance(component, Video):
|
|
184
|
-
try:
|
|
185
|
-
video_path_url = await component.convert_to_file_path()
|
|
186
|
-
if video_path_url:
|
|
187
|
-
content_parts.append(f'<video src="{video_path_url}"/>')
|
|
188
|
-
except Exception as e:
|
|
189
|
-
logger.error(f"视频文件转换失败: {e}")
|
|
121
|
+
component_content = await self._convert_component_to_satori(component)
|
|
122
|
+
if component_content:
|
|
123
|
+
content_parts.append(component_content)
|
|
124
|
+
|
|
125
|
+
# 特殊处理 Node 和 Nodes 组件
|
|
126
|
+
if isinstance(component, Node):
|
|
127
|
+
# 单个转发节点
|
|
128
|
+
node_content = await self._convert_node_to_satori(component)
|
|
129
|
+
if node_content:
|
|
130
|
+
content_parts.append(node_content)
|
|
131
|
+
|
|
132
|
+
elif isinstance(component, Nodes):
|
|
133
|
+
# 合并转发消息
|
|
134
|
+
node_content = await self._convert_nodes_to_satori(component)
|
|
135
|
+
if node_content:
|
|
136
|
+
content_parts.append(node_content)
|
|
190
137
|
|
|
191
138
|
content = "".join(content_parts)
|
|
192
139
|
channel_id = self.session_id
|
|
@@ -250,3 +197,227 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
|
|
250
197
|
logger.error(f"Satori 流式消息发送异常: {e}")
|
|
251
198
|
|
|
252
199
|
return await super().send_streaming(generator, use_fallback)
|
|
200
|
+
|
|
201
|
+
async def _convert_component_to_satori(self, component) -> str:
|
|
202
|
+
"""将单个消息组件转换为 Satori 格式"""
|
|
203
|
+
try:
|
|
204
|
+
if isinstance(component, Plain):
|
|
205
|
+
text = (
|
|
206
|
+
component.text.replace("&", "&amp;")
|
|
207
|
+
.replace("<", "&lt;")
|
|
208
|
+
.replace(">", "&gt;")
|
|
209
|
+
)
|
|
210
|
+
return text
|
|
211
|
+
|
|
212
|
+
elif isinstance(component, At):
|
|
213
|
+
if component.qq:
|
|
214
|
+
return f'<at id="{component.qq}"/>'
|
|
215
|
+
elif component.name:
|
|
216
|
+
return f'<at name="{component.name}"/>'
|
|
217
|
+
|
|
218
|
+
elif isinstance(component, Image):
|
|
219
|
+
try:
|
|
220
|
+
image_base64 = await component.convert_to_base64()
|
|
221
|
+
if image_base64:
|
|
222
|
+
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
|
223
|
+
except Exception as e:
|
|
224
|
+
logger.error(f"图片转换为base64失败: {e}")
|
|
225
|
+
|
|
226
|
+
elif isinstance(component, File):
|
|
227
|
+
return (
|
|
228
|
+
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
elif isinstance(component, Record):
|
|
232
|
+
try:
|
|
233
|
+
record_base64 = await component.convert_to_base64()
|
|
234
|
+
if record_base64:
|
|
235
|
+
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
|
236
|
+
except Exception as e:
|
|
237
|
+
logger.error(f"语音转换为base64失败: {e}")
|
|
238
|
+
|
|
239
|
+
elif isinstance(component, Reply):
|
|
240
|
+
return f'<reply id="{component.id}"/>'
|
|
241
|
+
|
|
242
|
+
elif isinstance(component, Video):
|
|
243
|
+
try:
|
|
244
|
+
video_path_url = await component.convert_to_file_path()
|
|
245
|
+
if video_path_url:
|
|
246
|
+
return f'<video src="{video_path_url}"/>'
|
|
247
|
+
except Exception as e:
|
|
248
|
+
logger.error(f"视频文件转换失败: {e}")
|
|
249
|
+
|
|
250
|
+
elif isinstance(component, Forward):
|
|
251
|
+
return f'<message id="{component.id}" forward/>'
|
|
252
|
+
|
|
253
|
+
# 对于其他未处理的组件类型,返回空字符串
|
|
254
|
+
return ""
|
|
255
|
+
|
|
256
|
+
except Exception as e:
|
|
257
|
+
logger.error(f"转换消息组件失败: {e}")
|
|
258
|
+
return ""
|
|
259
|
+
|
|
260
|
+
async def _convert_node_to_satori(self, node: Node) -> str:
|
|
261
|
+
"""将单个转发节点转换为 Satori 格式"""
|
|
262
|
+
try:
|
|
263
|
+
content_parts = []
|
|
264
|
+
if node.content:
|
|
265
|
+
for content_component in node.content:
|
|
266
|
+
component_content = await self._convert_component_to_satori(
|
|
267
|
+
content_component
|
|
268
|
+
)
|
|
269
|
+
if component_content:
|
|
270
|
+
content_parts.append(component_content)
|
|
271
|
+
|
|
272
|
+
content = "".join(content_parts)
|
|
273
|
+
|
|
274
|
+
# 如果内容为空,添加默认内容
|
|
275
|
+
if not content.strip():
|
|
276
|
+
content = "[转发消息]"
|
|
277
|
+
|
|
278
|
+
# 构建 Satori 格式的转发节点
|
|
279
|
+
author_attrs = []
|
|
280
|
+
if node.uin:
|
|
281
|
+
author_attrs.append(f'id="{node.uin}"')
|
|
282
|
+
if node.name:
|
|
283
|
+
author_attrs.append(f'name="{node.name}"')
|
|
284
|
+
|
|
285
|
+
author_attr_str = " ".join(author_attrs)
|
|
286
|
+
|
|
287
|
+
return f"<message><author {author_attr_str}/>{content}</message>"
|
|
288
|
+
|
|
289
|
+
except Exception as e:
|
|
290
|
+
logger.error(f"转换转发节点失败: {e}")
|
|
291
|
+
return ""
|
|
292
|
+
|
|
293
|
+
@classmethod
|
|
294
|
+
async def _convert_component_to_satori_static(cls, component) -> str:
|
|
295
|
+
"""将单个消息组件转换为 Satori 格式"""
|
|
296
|
+
try:
|
|
297
|
+
if isinstance(component, Plain):
|
|
298
|
+
text = (
|
|
299
|
+
component.text.replace("&", "&amp;")
|
|
300
|
+
.replace("<", "&lt;")
|
|
301
|
+
.replace(">", "&gt;")
|
|
302
|
+
)
|
|
303
|
+
return text
|
|
304
|
+
|
|
305
|
+
elif isinstance(component, At):
|
|
306
|
+
if component.qq:
|
|
307
|
+
return f'<at id="{component.qq}"/>'
|
|
308
|
+
elif component.name:
|
|
309
|
+
return f'<at name="{component.name}"/>'
|
|
310
|
+
|
|
311
|
+
elif isinstance(component, Image):
|
|
312
|
+
try:
|
|
313
|
+
image_base64 = await component.convert_to_base64()
|
|
314
|
+
if image_base64:
|
|
315
|
+
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
|
316
|
+
except Exception as e:
|
|
317
|
+
logger.error(f"图片转换为base64失败: {e}")
|
|
318
|
+
|
|
319
|
+
elif isinstance(component, File):
|
|
320
|
+
return (
|
|
321
|
+
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
elif isinstance(component, Record):
|
|
325
|
+
try:
|
|
326
|
+
record_base64 = await component.convert_to_base64()
|
|
327
|
+
if record_base64:
|
|
328
|
+
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
|
329
|
+
except Exception as e:
|
|
330
|
+
logger.error(f"语音转换为base64失败: {e}")
|
|
331
|
+
|
|
332
|
+
elif isinstance(component, Reply):
|
|
333
|
+
return f'<reply id="{component.id}"/>'
|
|
334
|
+
|
|
335
|
+
elif isinstance(component, Video):
|
|
336
|
+
try:
|
|
337
|
+
video_path_url = await component.convert_to_file_path()
|
|
338
|
+
if video_path_url:
|
|
339
|
+
return f'<video src="{video_path_url}"/>'
|
|
340
|
+
except Exception as e:
|
|
341
|
+
logger.error(f"视频文件转换失败: {e}")
|
|
342
|
+
|
|
343
|
+
elif isinstance(component, Forward):
|
|
344
|
+
return f'<message id="{component.id}" forward/>'
|
|
345
|
+
|
|
346
|
+
# 对于其他未处理的组件类型,返回空字符串
|
|
347
|
+
return ""
|
|
348
|
+
|
|
349
|
+
except Exception as e:
|
|
350
|
+
logger.error(f"转换消息组件失败: {e}")
|
|
351
|
+
return ""
|
|
352
|
+
|
|
353
|
+
@classmethod
|
|
354
|
+
async def _convert_node_to_satori_static(cls, node: Node) -> str:
|
|
355
|
+
"""将单个转发节点转换为 Satori 格式"""
|
|
356
|
+
try:
|
|
357
|
+
content_parts = []
|
|
358
|
+
if node.content:
|
|
359
|
+
for content_component in node.content:
|
|
360
|
+
component_content = await cls._convert_component_to_satori_static(
|
|
361
|
+
content_component
|
|
362
|
+
)
|
|
363
|
+
if component_content:
|
|
364
|
+
content_parts.append(component_content)
|
|
365
|
+
|
|
366
|
+
content = "".join(content_parts)
|
|
367
|
+
|
|
368
|
+
# 如果内容为空,添加默认内容
|
|
369
|
+
if not content.strip():
|
|
370
|
+
content = "[转发消息]"
|
|
371
|
+
|
|
372
|
+
author_attrs = []
|
|
373
|
+
if node.uin:
|
|
374
|
+
author_attrs.append(f'id="{node.uin}"')
|
|
375
|
+
if node.name:
|
|
376
|
+
author_attrs.append(f'name="{node.name}"')
|
|
377
|
+
|
|
378
|
+
author_attr_str = " ".join(author_attrs)
|
|
379
|
+
|
|
380
|
+
return f"<message><author {author_attr_str}/>{content}</message>"
|
|
381
|
+
|
|
382
|
+
except Exception as e:
|
|
383
|
+
logger.error(f"转换转发节点失败: {e}")
|
|
384
|
+
return ""
|
|
385
|
+
|
|
386
|
+
async def _convert_nodes_to_satori(self, nodes: Nodes) -> str:
|
|
387
|
+
"""将多个转发节点转换为 Satori 格式的合并转发"""
|
|
388
|
+
try:
|
|
389
|
+
node_parts = []
|
|
390
|
+
|
|
391
|
+
for node in nodes.nodes:
|
|
392
|
+
node_content = await self._convert_node_to_satori(node)
|
|
393
|
+
if node_content:
|
|
394
|
+
node_parts.append(node_content)
|
|
395
|
+
|
|
396
|
+
if node_parts:
|
|
397
|
+
return f"<message forward>{''.join(node_parts)}</message>"
|
|
398
|
+
else:
|
|
399
|
+
return ""
|
|
400
|
+
|
|
401
|
+
except Exception as e:
|
|
402
|
+
logger.error(f"转换合并转发消息失败: {e}")
|
|
403
|
+
return ""
|
|
404
|
+
|
|
405
|
+
@classmethod
|
|
406
|
+
async def _convert_nodes_to_satori_static(cls, nodes: Nodes) -> str:
|
|
407
|
+
"""将多个转发节点转换为 Satori 格式的合并转发"""
|
|
408
|
+
try:
|
|
409
|
+
node_parts = []
|
|
410
|
+
|
|
411
|
+
for node in nodes.nodes:
|
|
412
|
+
node_content = await cls._convert_node_to_satori_static(node)
|
|
413
|
+
if node_content:
|
|
414
|
+
node_parts.append(node_content)
|
|
415
|
+
|
|
416
|
+
if node_parts:
|
|
417
|
+
return f"<message forward>{''.join(node_parts)}</message>"
|
|
418
|
+
else:
|
|
419
|
+
return ""
|
|
420
|
+
|
|
421
|
+
except Exception as e:
|
|
422
|
+
logger.error(f"转换合并转发消息失败: {e}")
|
|
423
|
+
return ""
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import traceback
|
|
3
|
-
from typing import List
|
|
4
3
|
|
|
5
4
|
from astrbot.core import logger, sp
|
|
6
5
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
@@ -28,7 +27,7 @@ class ProviderManager:
|
|
|
28
27
|
self.persona_mgr = persona_mgr
|
|
29
28
|
self.acm = acm
|
|
30
29
|
config = acm.confs["default"]
|
|
31
|
-
self.providers_config:
|
|
30
|
+
self.providers_config: list = config["provider"]
|
|
32
31
|
self.provider_settings: dict = config["provider_settings"]
|
|
33
32
|
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
|
34
33
|
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
|
@@ -36,15 +35,15 @@ class ProviderManager:
|
|
|
36
35
|
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
|
37
36
|
self.default_persona_name = persona_mgr.default_persona
|
|
38
37
|
|
|
39
|
-
self.provider_insts:
|
|
38
|
+
self.provider_insts: list[Provider] = []
|
|
40
39
|
"""加载的 Provider 的实例"""
|
|
41
|
-
self.stt_provider_insts:
|
|
40
|
+
self.stt_provider_insts: list[STTProvider] = []
|
|
42
41
|
"""加载的 Speech To Text Provider 的实例"""
|
|
43
|
-
self.tts_provider_insts:
|
|
42
|
+
self.tts_provider_insts: list[TTSProvider] = []
|
|
44
43
|
"""加载的 Text To Speech Provider 的实例"""
|
|
45
|
-
self.embedding_provider_insts:
|
|
44
|
+
self.embedding_provider_insts: list[EmbeddingProvider] = []
|
|
46
45
|
"""加载的 Embedding Provider 的实例"""
|
|
47
|
-
self.rerank_provider_insts:
|
|
46
|
+
self.rerank_provider_insts: list[RerankProvider] = []
|
|
48
47
|
"""加载的 Rerank Provider 的实例"""
|
|
49
48
|
self.inst_map: dict[
|
|
50
49
|
str,
|
|
@@ -175,7 +174,11 @@ class ProviderManager:
|
|
|
175
174
|
async def initialize(self):
|
|
176
175
|
# 逐个初始化提供商
|
|
177
176
|
for provider_config in self.providers_config:
|
|
178
|
-
|
|
177
|
+
try:
|
|
178
|
+
await self.load_provider(provider_config)
|
|
179
|
+
except Exception as e:
|
|
180
|
+
logger.error(traceback.format_exc())
|
|
181
|
+
logger.error(e)
|
|
179
182
|
|
|
180
183
|
# 设置默认提供商
|
|
181
184
|
selected_provider_id = sp.get(
|
|
@@ -256,6 +259,10 @@ class ProviderManager:
|
|
|
256
259
|
from .sources.whisper_selfhosted_source import (
|
|
257
260
|
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
|
258
261
|
)
|
|
262
|
+
case "xinference_stt":
|
|
263
|
+
from .sources.xinference_stt_provider import (
|
|
264
|
+
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
|
265
|
+
)
|
|
259
266
|
case "openai_tts_api":
|
|
260
267
|
from .sources.openai_tts_api_source import (
|
|
261
268
|
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
|
@@ -308,6 +315,10 @@ class ProviderManager:
|
|
|
308
315
|
from .sources.vllm_rerank_source import (
|
|
309
316
|
VLLMRerankProvider as VLLMRerankProvider,
|
|
310
317
|
)
|
|
318
|
+
case "xinference_rerank":
|
|
319
|
+
from .sources.xinference_rerank_source import (
|
|
320
|
+
XinferenceRerankProvider as XinferenceRerankProvider,
|
|
321
|
+
)
|
|
311
322
|
except (ImportError, ModuleNotFoundError) as e:
|
|
312
323
|
logger.critical(
|
|
313
324
|
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
|
@@ -404,10 +415,12 @@ class ProviderManager:
|
|
|
404
415
|
|
|
405
416
|
self.inst_map[provider_config["id"]] = inst
|
|
406
417
|
except Exception as e:
|
|
407
|
-
logger.error(traceback.format_exc())
|
|
408
418
|
logger.error(
|
|
409
419
|
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
|
410
420
|
)
|
|
421
|
+
raise Exception(
|
|
422
|
+
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
|
423
|
+
)
|
|
411
424
|
|
|
412
425
|
async def reload(self, provider_config: dict):
|
|
413
426
|
await self.terminate_provider(provider_config["id"])
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import abc
|
|
2
|
+
import asyncio
|
|
2
3
|
from typing import List
|
|
3
4
|
from typing import AsyncGenerator
|
|
4
5
|
from astrbot.core.agent.tool import ToolSet
|
|
@@ -203,6 +204,72 @@ class EmbeddingProvider(AbstractProvider):
|
|
|
203
204
|
"""获取向量的维度"""
|
|
204
205
|
...
|
|
205
206
|
|
|
207
|
+
async def get_embeddings_batch(
|
|
208
|
+
self,
|
|
209
|
+
texts: list[str],
|
|
210
|
+
batch_size: int = 16,
|
|
211
|
+
tasks_limit: int = 3,
|
|
212
|
+
max_retries: int = 3,
|
|
213
|
+
progress_callback=None,
|
|
214
|
+
) -> list[list[float]]:
|
|
215
|
+
"""批量获取文本的向量,分批处理以节省内存
|
|
216
|
+
|
|
217
|
+
Args:
|
|
218
|
+
texts: 文本列表
|
|
219
|
+
batch_size: 每批处理的文本数量
|
|
220
|
+
tasks_limit: 并发任务数量限制
|
|
221
|
+
max_retries: 失败时的最大重试次数
|
|
222
|
+
progress_callback: 进度回调函数,接收参数 (current, total)
|
|
223
|
+
|
|
224
|
+
Returns:
|
|
225
|
+
向量列表
|
|
226
|
+
"""
|
|
227
|
+
semaphore = asyncio.Semaphore(tasks_limit)
|
|
228
|
+
all_embeddings: list[list[float]] = []
|
|
229
|
+
failed_batches: list[tuple[int, list[str]]] = []
|
|
230
|
+
completed_count = 0
|
|
231
|
+
total_count = len(texts)
|
|
232
|
+
|
|
233
|
+
async def process_batch(batch_idx: int, batch_texts: list[str]):
|
|
234
|
+
nonlocal completed_count
|
|
235
|
+
async with semaphore:
|
|
236
|
+
for attempt in range(max_retries):
|
|
237
|
+
try:
|
|
238
|
+
batch_embeddings = await self.get_embeddings(batch_texts)
|
|
239
|
+
all_embeddings.extend(batch_embeddings)
|
|
240
|
+
completed_count += len(batch_texts)
|
|
241
|
+
if progress_callback:
|
|
242
|
+
await progress_callback(completed_count, total_count)
|
|
243
|
+
return
|
|
244
|
+
except Exception as e:
|
|
245
|
+
if attempt == max_retries - 1:
|
|
246
|
+
# 最后一次重试失败,记录失败的批次
|
|
247
|
+
failed_batches.append((batch_idx, batch_texts))
|
|
248
|
+
raise Exception(
|
|
249
|
+
f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {str(e)}"
|
|
250
|
+
)
|
|
251
|
+
# 等待一段时间后重试,使用指数退避
|
|
252
|
+
await asyncio.sleep(2**attempt)
|
|
253
|
+
|
|
254
|
+
tasks = []
|
|
255
|
+
for i in range(0, len(texts), batch_size):
|
|
256
|
+
batch_texts = texts[i : i + batch_size]
|
|
257
|
+
batch_idx = i // batch_size
|
|
258
|
+
tasks.append(process_batch(batch_idx, batch_texts))
|
|
259
|
+
|
|
260
|
+
# 收集所有任务的结果,包括失败的任务
|
|
261
|
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
262
|
+
|
|
263
|
+
# 检查是否有失败的任务
|
|
264
|
+
errors = [r for r in results if isinstance(r, Exception)]
|
|
265
|
+
if errors:
|
|
266
|
+
error_msg = (
|
|
267
|
+
f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
|
|
268
|
+
)
|
|
269
|
+
raise Exception(error_msg)
|
|
270
|
+
|
|
271
|
+
return all_embeddings
|
|
272
|
+
|
|
206
273
|
|
|
207
274
|
class RerankProvider(AbstractProvider):
|
|
208
275
|
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
@@ -10,7 +10,7 @@ from anthropic.types import Message
|
|
|
10
10
|
from astrbot.core.utils.io import download_image_by_url
|
|
11
11
|
from astrbot.api.provider import Provider
|
|
12
12
|
from astrbot import logger
|
|
13
|
-
from astrbot.core.provider.func_tool_manager import
|
|
13
|
+
from astrbot.core.provider.func_tool_manager import ToolSet
|
|
14
14
|
from ..register import register_provider_adapter
|
|
15
15
|
from astrbot.core.provider.entities import LLMResponse
|
|
16
16
|
from typing import AsyncGenerator
|
|
@@ -104,7 +104,7 @@ class ProviderAnthropic(Provider):
|
|
|
104
104
|
|
|
105
105
|
return system_prompt, new_messages
|
|
106
106
|
|
|
107
|
-
async def _query(self, payloads: dict, tools:
|
|
107
|
+
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
|
108
108
|
if tools:
|
|
109
109
|
if tool_list := tools.get_func_desc_anthropic_style():
|
|
110
110
|
payloads["tools"] = tool_list
|
|
@@ -135,7 +135,7 @@ class ProviderAnthropic(Provider):
|
|
|
135
135
|
return llm_response
|
|
136
136
|
|
|
137
137
|
async def _query_stream(
|
|
138
|
-
self, payloads: dict, tools:
|
|
138
|
+
self, payloads: dict, tools: ToolSet | None
|
|
139
139
|
) -> AsyncGenerator[LLMResponse, None]:
|
|
140
140
|
if tools:
|
|
141
141
|
if tool_list := tools.get_func_desc_anthropic_style():
|
|
@@ -326,7 +326,7 @@ class ProviderAnthropic(Provider):
|
|
|
326
326
|
async for llm_response in self._query_stream(payloads, func_tool):
|
|
327
327
|
yield llm_response
|
|
328
328
|
|
|
329
|
-
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
|
329
|
+
async def assemble_context(self, text: str, image_urls: List[str] | None = None):
|
|
330
330
|
"""组装上下文,支持文本和图片"""
|
|
331
331
|
if not image_urls:
|
|
332
332
|
return {"role": "user", "content": text}
|
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import asyncio
|
|
3
3
|
import functools
|
|
4
|
-
from typing import List
|
|
5
4
|
from .. import Provider, Personality
|
|
6
5
|
from ..entities import LLMResponse
|
|
7
|
-
from ..func_tool_manager import FuncCall
|
|
8
6
|
from ..register import register_provider_adapter
|
|
9
7
|
from astrbot.core.message.message_event_result import MessageChain
|
|
10
8
|
from .openai_source import ProviderOpenAIOfficial
|
|
11
9
|
from astrbot.core import logger, sp
|
|
12
10
|
from dashscope import Application
|
|
11
|
+
from dashscope.app.application_response import ApplicationResponse
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
|
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|
|
62
61
|
async def text_chat(
|
|
63
62
|
self,
|
|
64
63
|
prompt: str,
|
|
65
|
-
session_id
|
|
66
|
-
image_urls
|
|
67
|
-
func_tool
|
|
68
|
-
contexts
|
|
69
|
-
system_prompt
|
|
64
|
+
session_id=None,
|
|
65
|
+
image_urls=[],
|
|
66
|
+
func_tool=None,
|
|
67
|
+
contexts=None,
|
|
68
|
+
system_prompt=None,
|
|
70
69
|
model=None,
|
|
71
70
|
**kwargs,
|
|
72
71
|
) -> LLMResponse:
|
|
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|
|
122
121
|
)
|
|
123
122
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
|
124
123
|
|
|
124
|
+
assert isinstance(response, ApplicationResponse)
|
|
125
|
+
|
|
125
126
|
logger.debug(f"dashscope resp: {response}")
|
|
126
127
|
|
|
127
128
|
if response.status_code != 200:
|
|
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
|
|
135
136
|
),
|
|
136
137
|
)
|
|
137
138
|
|
|
138
|
-
output_text = response.output.get("text", "")
|
|
139
|
+
output_text = response.output.get("text", "") or ""
|
|
139
140
|
# RAG 引用脚标格式化
|
|
140
141
|
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
|
141
142
|
if self.output_reference and response.output.get("doc_references", None):
|
|
142
143
|
ref_str = ""
|
|
143
|
-
for ref in response.output.get("doc_references", []):
|
|
144
|
+
for ref in response.output.get("doc_references", []) or []:
|
|
144
145
|
ref_title = (
|
|
145
146
|
ref.get("title", "")
|
|
146
147
|
if ref.get("title")
|