AstrBot: astrbot-4.8.0-py3-none-any.whl → astrbot-4.9.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -1
- astrbot/core/agent/tool.py +7 -2
- astrbot/core/astr_agent_tool_exec.py +5 -1
- astrbot/core/config/astrbot_config.py +4 -0
- astrbot/core/config/default.py +72 -1
- astrbot/core/config/i18n_utils.py +1 -0
- astrbot/core/core_lifecycle.py +1 -1
- astrbot/core/db/__init__.py +2 -3
- astrbot/core/db/migration/migra_3_to_4.py +2 -0
- astrbot/core/db/migration/sqlite_v3.py +6 -4
- astrbot/core/db/po.py +16 -15
- astrbot/core/db/sqlite.py +4 -3
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +2 -0
- astrbot/core/event_bus.py +6 -1
- astrbot/core/knowledge_base/retrieval/manager.py +5 -1
- astrbot/core/log.py +2 -1
- astrbot/core/message/components.py +9 -3
- astrbot/core/persona_mgr.py +2 -2
- astrbot/core/pipeline/content_safety_check/stage.py +1 -1
- astrbot/core/pipeline/context_utils.py +2 -1
- astrbot/core/pipeline/process_stage/method/star_request.py +1 -2
- astrbot/core/pipeline/process_stage/stage.py +1 -1
- astrbot/core/pipeline/respond/stage.py +8 -2
- astrbot/core/pipeline/result_decorate/stage.py +89 -22
- astrbot/core/pipeline/scheduler.py +5 -1
- astrbot/core/pipeline/waking_check/stage.py +10 -0
- astrbot/core/platform/astr_message_event.py +5 -3
- astrbot/core/platform/astrbot_message.py +2 -2
- astrbot/core/platform/manager.py +4 -0
- astrbot/core/platform/platform.py +11 -3
- astrbot/core/platform/platform_metadata.py +1 -1
- astrbot/core/platform/register.py +1 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +8 -6
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +9 -5
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +24 -16
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +5 -2
- astrbot/core/platform/sources/discord/client.py +16 -4
- astrbot/core/platform/sources/discord/components.py +2 -2
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +52 -24
- astrbot/core/platform/sources/discord/discord_platform_event.py +29 -8
- astrbot/core/platform/sources/lark/lark_adapter.py +183 -20
- astrbot/core/platform/sources/lark/lark_event.py +39 -4
- astrbot/core/platform/sources/lark/server.py +206 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +2 -3
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +62 -18
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +13 -7
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +5 -3
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/slack/client.py +9 -2
- astrbot/core/platform/sources/slack/slack_adapter.py +15 -9
- astrbot/core/platform/sources/slack/slack_event.py +8 -7
- astrbot/core/platform/sources/telegram/tg_adapter.py +1 -1
- astrbot/core/platform/sources/telegram/tg_event.py +23 -27
- astrbot/core/platform/sources/webchat/webchat_adapter.py +2 -2
- astrbot/core/platform/sources/webchat/webchat_event.py +2 -2
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +26 -9
- astrbot/core/platform/sources/wecom/wecom_adapter.py +25 -28
- astrbot/core/platform/sources/wecom/wecom_event.py +2 -2
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +3 -3
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +30 -25
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +10 -7
- astrbot/core/provider/func_tool_manager.py +3 -3
- astrbot/core/provider/manager.py +130 -74
- astrbot/core/provider/provider.py +12 -1
- astrbot/core/provider/sources/azure_tts_source.py +31 -9
- astrbot/core/provider/sources/bailian_rerank_source.py +4 -0
- astrbot/core/provider/sources/dashscope_tts.py +3 -2
- astrbot/core/provider/sources/edge_tts_source.py +1 -1
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +5 -4
- astrbot/core/provider/sources/gemini_embedding_source.py +15 -5
- astrbot/core/provider/sources/gemini_source.py +12 -10
- astrbot/core/provider/sources/minimax_tts_api_source.py +4 -2
- astrbot/core/provider/sources/openai_embedding_source.py +2 -2
- astrbot/core/provider/sources/openai_source.py +4 -0
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +5 -2
- astrbot/core/provider/sources/vllm_rerank_source.py +1 -0
- astrbot/core/provider/sources/whisper_api_source.py +1 -1
- astrbot/core/provider/sources/whisper_selfhosted_source.py +6 -2
- astrbot/core/provider/sources/xinference_rerank_source.py +10 -2
- astrbot/core/star/context.py +2 -2
- astrbot/core/star/register/star_handler.py +22 -5
- astrbot/core/star/star_handler.py +85 -4
- astrbot/core/updator.py +3 -3
- astrbot/core/utils/io.py +1 -1
- astrbot/core/utils/session_waiter.py +17 -10
- astrbot/core/utils/shared_preferences.py +32 -0
- astrbot/core/utils/t2i/__init__.py +2 -2
- astrbot/core/utils/t2i/local_strategy.py +25 -31
- astrbot/core/utils/tencent_record_helper.py +1 -1
- astrbot/core/utils/version_comparator.py +6 -3
- astrbot/core/utils/webhook_utils.py +19 -0
- astrbot/dashboard/routes/chat.py +14 -9
- astrbot/dashboard/routes/config.py +10 -20
- astrbot/dashboard/routes/conversation.py +91 -1
- astrbot/dashboard/routes/knowledge_base.py +253 -78
- astrbot/dashboard/routes/log.py +13 -8
- astrbot/dashboard/routes/platform.py +1 -1
- astrbot/dashboard/routes/plugin.py +113 -52
- astrbot/dashboard/routes/route.py +2 -0
- astrbot/dashboard/server.py +6 -3
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/METADATA +9 -1
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/RECORD +106 -105
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/WHEEL +0 -0
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import uuid
|
|
3
|
+
from typing import cast
|
|
3
4
|
|
|
4
5
|
from wechatpy import WeChatClient
|
|
5
6
|
from wechatpy.replies import ImageReply, TextReply, VoiceReply
|
|
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|
|
85
86
|
|
|
86
87
|
async def send(self, message: MessageChain):
|
|
87
88
|
message_obj = self.message_obj
|
|
88
|
-
active_send_mode = message_obj.raw_message.get(
|
|
89
|
+
active_send_mode = cast(dict, message_obj.raw_message).get(
|
|
90
|
+
"active_send_mode", False
|
|
91
|
+
)
|
|
89
92
|
for comp in message.chain:
|
|
90
93
|
if isinstance(comp, Plain):
|
|
91
94
|
# Split long text messages if needed
|
|
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|
|
96
99
|
else:
|
|
97
100
|
reply = TextReply(
|
|
98
101
|
content=chunk,
|
|
99
|
-
message=self.message_obj.raw_message["message"],
|
|
102
|
+
message=cast(dict, self.message_obj.raw_message)["message"],
|
|
100
103
|
)
|
|
101
104
|
xml = reply.render()
|
|
102
|
-
future = self.message_obj.raw_message["future"]
|
|
105
|
+
future = cast(dict, self.message_obj.raw_message)["future"]
|
|
103
106
|
assert isinstance(future, asyncio.Future)
|
|
104
107
|
future.set_result(xml)
|
|
105
108
|
await asyncio.sleep(0.5) # Avoid sending too fast
|
|
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|
|
125
128
|
else:
|
|
126
129
|
reply = ImageReply(
|
|
127
130
|
media_id=response["media_id"],
|
|
128
|
-
message=self.message_obj.raw_message["message"],
|
|
131
|
+
message=cast(dict, self.message_obj.raw_message)["message"],
|
|
129
132
|
)
|
|
130
133
|
xml = reply.render()
|
|
131
|
-
future = self.message_obj.raw_message["future"]
|
|
134
|
+
future = cast(dict, self.message_obj.raw_message)["future"]
|
|
132
135
|
assert isinstance(future, asyncio.Future)
|
|
133
136
|
future.set_result(xml)
|
|
134
137
|
|
|
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|
|
160
163
|
else:
|
|
161
164
|
reply = VoiceReply(
|
|
162
165
|
media_id=response["media_id"],
|
|
163
|
-
message=self.message_obj.raw_message["message"],
|
|
166
|
+
message=cast(dict, self.message_obj.raw_message)["message"],
|
|
164
167
|
)
|
|
165
168
|
xml = reply.render()
|
|
166
|
-
future = self.message_obj.raw_message["future"]
|
|
169
|
+
future = cast(dict, self.message_obj.raw_message)["future"]
|
|
167
170
|
assert isinstance(future, asyncio.Future)
|
|
168
171
|
future.set_result(xml)
|
|
169
172
|
|
|
@@ -4,7 +4,7 @@ import asyncio
|
|
|
4
4
|
import copy
|
|
5
5
|
import json
|
|
6
6
|
import os
|
|
7
|
-
from collections.abc import Awaitable, Callable
|
|
7
|
+
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
8
8
|
from typing import Any
|
|
9
9
|
|
|
10
10
|
import aiohttp
|
|
@@ -118,7 +118,7 @@ class FunctionToolManager:
|
|
|
118
118
|
name: str,
|
|
119
119
|
func_args: list[dict],
|
|
120
120
|
desc: str,
|
|
121
|
-
handler: Callable[..., Awaitable[Any]],
|
|
121
|
+
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
|
122
122
|
) -> FuncTool:
|
|
123
123
|
params = {
|
|
124
124
|
"type": "object", # hard-coded here
|
|
@@ -140,7 +140,7 @@ class FunctionToolManager:
|
|
|
140
140
|
name: str,
|
|
141
141
|
func_args: list,
|
|
142
142
|
desc: str,
|
|
143
|
-
handler: Callable[..., Awaitable[Any]],
|
|
143
|
+
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
|
144
144
|
) -> None:
|
|
145
145
|
"""添加函数调用工具
|
|
146
146
|
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import traceback
|
|
3
|
+
from typing import Protocol, runtime_checkable
|
|
3
4
|
|
|
4
5
|
from astrbot.core import astrbot_config, logger, sp
|
|
5
6
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
@@ -10,6 +11,7 @@ from .entities import ProviderType
|
|
|
10
11
|
from .provider import (
|
|
11
12
|
EmbeddingProvider,
|
|
12
13
|
Provider,
|
|
14
|
+
Providers,
|
|
13
15
|
RerankProvider,
|
|
14
16
|
STTProvider,
|
|
15
17
|
TTSProvider,
|
|
@@ -17,6 +19,11 @@ from .provider import (
|
|
|
17
19
|
from .register import llm_tools, provider_cls_map
|
|
18
20
|
|
|
19
21
|
|
|
22
|
+
@runtime_checkable
|
|
23
|
+
class HasInitialize(Protocol):
|
|
24
|
+
async def initialize(self) -> None: ...
|
|
25
|
+
|
|
26
|
+
|
|
20
27
|
class ProviderManager:
|
|
21
28
|
def __init__(
|
|
22
29
|
self,
|
|
@@ -48,7 +55,7 @@ class ProviderManager:
|
|
|
48
55
|
"""加载的 Rerank Provider 的实例"""
|
|
49
56
|
self.inst_map: dict[
|
|
50
57
|
str,
|
|
51
|
-
|
|
58
|
+
Providers,
|
|
52
59
|
] = {}
|
|
53
60
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
|
54
61
|
self.llm_tools = llm_tools
|
|
@@ -123,15 +130,13 @@ class ProviderManager:
|
|
|
123
130
|
self.curr_provider_inst = prov
|
|
124
131
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
|
125
132
|
|
|
126
|
-
async def get_provider_by_id(self, provider_id: str) ->
|
|
133
|
+
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
|
127
134
|
"""根据提供商 ID 获取提供商实例"""
|
|
128
135
|
return self.inst_map.get(provider_id)
|
|
129
136
|
|
|
130
137
|
def get_using_provider(
|
|
131
|
-
self,
|
|
132
|
-
|
|
133
|
-
umo=None,
|
|
134
|
-
) -> Provider | STTProvider | TTSProvider | None:
|
|
138
|
+
self, provider_type: ProviderType, umo=None
|
|
139
|
+
) -> Providers | None:
|
|
135
140
|
"""获取正在使用的提供商实例。
|
|
136
141
|
|
|
137
142
|
Args:
|
|
@@ -191,7 +196,6 @@ class ProviderManager:
|
|
|
191
196
|
logger.error(traceback.format_exc())
|
|
192
197
|
logger.error(e)
|
|
193
198
|
|
|
194
|
-
# 设置默认提供商
|
|
195
199
|
selected_provider_id = sp.get(
|
|
196
200
|
"curr_provider",
|
|
197
201
|
self.provider_settings.get("default_provider_id"),
|
|
@@ -210,15 +214,37 @@ class ProviderManager:
|
|
|
210
214
|
scope="global",
|
|
211
215
|
scope_id="global",
|
|
212
216
|
)
|
|
213
|
-
|
|
217
|
+
|
|
218
|
+
temp_provider = (
|
|
219
|
+
self.inst_map.get(selected_provider_id)
|
|
220
|
+
if isinstance(selected_provider_id, str)
|
|
221
|
+
else None
|
|
222
|
+
)
|
|
223
|
+
self.curr_provider_inst = (
|
|
224
|
+
temp_provider if isinstance(temp_provider, Provider) else None
|
|
225
|
+
)
|
|
214
226
|
if not self.curr_provider_inst and self.provider_insts:
|
|
215
227
|
self.curr_provider_inst = self.provider_insts[0]
|
|
216
228
|
|
|
217
|
-
|
|
229
|
+
temp_stt = (
|
|
230
|
+
self.inst_map.get(selected_stt_provider_id)
|
|
231
|
+
if isinstance(selected_stt_provider_id, str)
|
|
232
|
+
else None
|
|
233
|
+
)
|
|
234
|
+
self.curr_stt_provider_inst = (
|
|
235
|
+
temp_stt if isinstance(temp_stt, STTProvider) else None
|
|
236
|
+
)
|
|
218
237
|
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
|
219
238
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
|
220
239
|
|
|
221
|
-
|
|
240
|
+
temp_tts = (
|
|
241
|
+
self.inst_map.get(selected_tts_provider_id)
|
|
242
|
+
if isinstance(selected_tts_provider_id, str)
|
|
243
|
+
else None
|
|
244
|
+
)
|
|
245
|
+
self.curr_tts_provider_inst = (
|
|
246
|
+
temp_tts if isinstance(temp_tts, TTSProvider) else None
|
|
247
|
+
)
|
|
222
248
|
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
|
223
249
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
|
224
250
|
|
|
@@ -358,73 +384,103 @@ class ProviderManager:
|
|
|
358
384
|
|
|
359
385
|
provider_metadata.id = provider_config["id"]
|
|
360
386
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
self.
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
387
|
+
match provider_metadata.provider_type:
|
|
388
|
+
case ProviderType.SPEECH_TO_TEXT:
|
|
389
|
+
# STT 任务
|
|
390
|
+
if not issubclass(cls_type, STTProvider):
|
|
391
|
+
raise TypeError(
|
|
392
|
+
f"Provider class {cls_type} is not a subclass of STTProvider"
|
|
393
|
+
)
|
|
394
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
395
|
+
|
|
396
|
+
if isinstance(inst, HasInitialize):
|
|
397
|
+
await inst.initialize()
|
|
398
|
+
|
|
399
|
+
self.stt_provider_insts.append(inst)
|
|
400
|
+
if (
|
|
401
|
+
self.provider_stt_settings.get("provider_id")
|
|
402
|
+
== provider_config["id"]
|
|
403
|
+
):
|
|
404
|
+
self.curr_stt_provider_inst = inst
|
|
405
|
+
logger.info(
|
|
406
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
|
407
|
+
)
|
|
408
|
+
if not self.curr_stt_provider_inst:
|
|
409
|
+
self.curr_stt_provider_inst = inst
|
|
410
|
+
|
|
411
|
+
case ProviderType.TEXT_TO_SPEECH:
|
|
412
|
+
# TTS 任务
|
|
413
|
+
if not issubclass(cls_type, TTSProvider):
|
|
414
|
+
raise TypeError(
|
|
415
|
+
f"Provider class {cls_type} is not a subclass of TTSProvider"
|
|
416
|
+
)
|
|
417
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
418
|
+
|
|
419
|
+
if isinstance(inst, HasInitialize):
|
|
420
|
+
await inst.initialize()
|
|
421
|
+
|
|
422
|
+
self.tts_provider_insts.append(inst)
|
|
423
|
+
if (
|
|
424
|
+
self.provider_settings.get("provider_id")
|
|
425
|
+
== provider_config["id"]
|
|
426
|
+
):
|
|
427
|
+
self.curr_tts_provider_inst = inst
|
|
428
|
+
logger.info(
|
|
429
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
|
430
|
+
)
|
|
431
|
+
if not self.curr_tts_provider_inst:
|
|
432
|
+
self.curr_tts_provider_inst = inst
|
|
433
|
+
|
|
434
|
+
case ProviderType.CHAT_COMPLETION:
|
|
435
|
+
# 文本生成任务
|
|
436
|
+
if not issubclass(cls_type, Provider):
|
|
437
|
+
raise TypeError(
|
|
438
|
+
f"Provider class {cls_type} is not a subclass of Provider"
|
|
439
|
+
)
|
|
440
|
+
inst = cls_type(
|
|
441
|
+
provider_config,
|
|
442
|
+
self.provider_settings,
|
|
392
443
|
)
|
|
393
|
-
if not self.curr_tts_provider_inst:
|
|
394
|
-
self.curr_tts_provider_inst = inst
|
|
395
|
-
|
|
396
|
-
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
|
397
|
-
# 文本生成任务
|
|
398
|
-
inst = cls_type(
|
|
399
|
-
provider_config,
|
|
400
|
-
self.provider_settings,
|
|
401
|
-
)
|
|
402
444
|
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
445
|
+
if isinstance(inst, HasInitialize):
|
|
446
|
+
await inst.initialize()
|
|
447
|
+
|
|
448
|
+
self.provider_insts.append(inst)
|
|
449
|
+
if (
|
|
450
|
+
self.provider_settings.get("default_provider_id")
|
|
451
|
+
== provider_config["id"]
|
|
452
|
+
):
|
|
453
|
+
self.curr_provider_inst = inst
|
|
454
|
+
logger.info(
|
|
455
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
|
456
|
+
)
|
|
457
|
+
if not self.curr_provider_inst:
|
|
458
|
+
self.curr_provider_inst = inst
|
|
459
|
+
|
|
460
|
+
case ProviderType.EMBEDDING:
|
|
461
|
+
if not issubclass(cls_type, EmbeddingProvider):
|
|
462
|
+
raise TypeError(
|
|
463
|
+
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
|
|
464
|
+
)
|
|
465
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
466
|
+
if isinstance(inst, HasInitialize):
|
|
467
|
+
await inst.initialize()
|
|
468
|
+
self.embedding_provider_insts.append(inst)
|
|
469
|
+
case ProviderType.RERANK:
|
|
470
|
+
if not issubclass(cls_type, RerankProvider):
|
|
471
|
+
raise TypeError(
|
|
472
|
+
f"Provider class {cls_type} is not a subclass of RerankProvider"
|
|
473
|
+
)
|
|
474
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
475
|
+
if isinstance(inst, HasInitialize):
|
|
476
|
+
await inst.initialize()
|
|
477
|
+
self.rerank_provider_insts.append(inst)
|
|
478
|
+
case _:
|
|
479
|
+
# 未知供应商抛出异常,确保inst初始化
|
|
480
|
+
# Should be unreachable
|
|
481
|
+
raise Exception(
|
|
482
|
+
f"未知的提供商类型:{provider_metadata.provider_type}"
|
|
414
483
|
)
|
|
415
|
-
if not self.curr_provider_inst:
|
|
416
|
-
self.curr_provider_inst = inst
|
|
417
|
-
|
|
418
|
-
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
|
419
|
-
inst = cls_type(provider_config, self.provider_settings)
|
|
420
|
-
if getattr(inst, "initialize", None):
|
|
421
|
-
await inst.initialize()
|
|
422
|
-
self.embedding_provider_insts.append(inst)
|
|
423
|
-
elif provider_metadata.provider_type == ProviderType.RERANK:
|
|
424
|
-
inst = cls_type(provider_config, self.provider_settings)
|
|
425
|
-
if getattr(inst, "initialize", None):
|
|
426
|
-
await inst.initialize()
|
|
427
|
-
self.rerank_provider_insts.append(inst)
|
|
428
484
|
|
|
429
485
|
self.inst_map[provider_config["id"]] = inst
|
|
430
486
|
except Exception as e:
|
|
@@ -2,6 +2,7 @@ import abc
|
|
|
2
2
|
import asyncio
|
|
3
3
|
import os
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import TypeAlias, Union
|
|
5
6
|
|
|
6
7
|
from astrbot.core.agent.message import Message
|
|
7
8
|
from astrbot.core.agent.tool import ToolSet
|
|
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
|
|
|
14
15
|
from astrbot.core.provider.register import provider_cls_map
|
|
15
16
|
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
|
16
17
|
|
|
18
|
+
Providers: TypeAlias = Union[
|
|
19
|
+
"Provider",
|
|
20
|
+
"STTProvider",
|
|
21
|
+
"TTSProvider",
|
|
22
|
+
"EmbeddingProvider",
|
|
23
|
+
"RerankProvider",
|
|
24
|
+
]
|
|
25
|
+
|
|
17
26
|
|
|
18
27
|
class AbstractProvider(abc.ABC):
|
|
19
28
|
"""Provider Abstract Class"""
|
|
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
|
|
|
142
151
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
|
143
152
|
|
|
144
153
|
"""
|
|
145
|
-
|
|
154
|
+
if False: # pragma: no cover - make this an async generator for typing
|
|
155
|
+
yield None # type: ignore
|
|
156
|
+
raise NotImplementedError()
|
|
146
157
|
|
|
147
158
|
async def pop_record(self, context: list):
|
|
148
159
|
"""弹出 context 第一条非系统提示词对话记录"""
|
|
@@ -29,15 +29,24 @@ class OTTSProvider:
|
|
|
29
29
|
self.last_sync_time = 0
|
|
30
30
|
self.timeout = Timeout(10.0)
|
|
31
31
|
self.retry_count = 3
|
|
32
|
-
self.
|
|
32
|
+
self._client: AsyncClient | None = None
|
|
33
|
+
|
|
34
|
+
@property
|
|
35
|
+
def client(self) -> AsyncClient:
|
|
36
|
+
if self._client is None:
|
|
37
|
+
raise RuntimeError(
|
|
38
|
+
"Client not initialized. Please use 'async with' context."
|
|
39
|
+
)
|
|
40
|
+
return self._client
|
|
33
41
|
|
|
34
42
|
async def __aenter__(self):
|
|
35
|
-
self.
|
|
43
|
+
self._client = AsyncClient(timeout=self.timeout)
|
|
36
44
|
return self
|
|
37
45
|
|
|
38
46
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
39
|
-
if self.
|
|
40
|
-
await self.
|
|
47
|
+
if self._client:
|
|
48
|
+
await self._client.aclose()
|
|
49
|
+
self._client = None
|
|
41
50
|
|
|
42
51
|
async def _sync_time(self):
|
|
43
52
|
try:
|
|
@@ -90,6 +99,7 @@ class OTTSProvider:
|
|
|
90
99
|
if attempt == self.retry_count - 1:
|
|
91
100
|
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
|
92
101
|
await asyncio.sleep(0.5 * (attempt + 1))
|
|
102
|
+
raise RuntimeError("OTTS未返回音频文件")
|
|
93
103
|
|
|
94
104
|
|
|
95
105
|
class AzureNativeProvider(TTSProvider):
|
|
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
|
|
|
105
115
|
self.endpoint = (
|
|
106
116
|
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
|
107
117
|
)
|
|
108
|
-
self.
|
|
118
|
+
self._client: AsyncClient | None = None
|
|
109
119
|
self.token = None
|
|
110
120
|
self.token_expire = 0
|
|
111
121
|
self.voice_params = {
|
|
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
|
|
|
116
126
|
"volume": provider_config.get("azure_tts_volume", "100"),
|
|
117
127
|
}
|
|
118
128
|
|
|
129
|
+
@property
|
|
130
|
+
def client(self) -> AsyncClient:
|
|
131
|
+
if self._client is None:
|
|
132
|
+
raise RuntimeError(
|
|
133
|
+
"Client not initialized. Please use 'async with' context."
|
|
134
|
+
)
|
|
135
|
+
return self._client
|
|
136
|
+
|
|
119
137
|
async def __aenter__(self):
|
|
120
|
-
self.
|
|
138
|
+
self._client = AsyncClient(
|
|
121
139
|
headers={
|
|
122
140
|
"User-Agent": f"AstrBot/{VERSION}",
|
|
123
141
|
"Content-Type": "application/ssml+xml",
|
|
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
|
|
|
127
145
|
return self
|
|
128
146
|
|
|
129
147
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
130
|
-
if self.
|
|
131
|
-
await self.
|
|
148
|
+
if self._client:
|
|
149
|
+
await self._client.aclose()
|
|
150
|
+
self._client = None
|
|
132
151
|
|
|
133
152
|
async def _refresh_token(self):
|
|
134
153
|
token_url = (
|
|
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
|
|
|
181
200
|
key_value = provider_config.get("azure_tts_subscription_key", "")
|
|
182
201
|
self.provider = self._parse_provider(key_value, provider_config)
|
|
183
202
|
|
|
184
|
-
def _parse_provider(
|
|
203
|
+
def _parse_provider(
|
|
204
|
+
self, key_value: str, config: dict
|
|
205
|
+
) -> OTTSProvider | AzureNativeProvider:
|
|
185
206
|
if key_value.lower().startswith("other["):
|
|
207
|
+
json_str = ""
|
|
186
208
|
try:
|
|
187
209
|
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
|
188
210
|
if not match:
|
|
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|
|
36
36
|
super().__init__(provider_config, provider_settings)
|
|
37
37
|
self.chosen_api_key: str = provider_config.get("api_key", "")
|
|
38
38
|
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
|
39
|
-
self.set_model(provider_config
|
|
39
|
+
self.set_model(provider_config["model"])
|
|
40
40
|
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
|
41
41
|
dashscope.api_key = self.chosen_api_key
|
|
42
42
|
|
|
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|
|
71
71
|
|
|
72
72
|
kwargs = {
|
|
73
73
|
"model": model,
|
|
74
|
-
"
|
|
74
|
+
"messages": None,
|
|
75
75
|
"api_key": self.chosen_api_key,
|
|
76
76
|
"voice": self.voice or "Cherry",
|
|
77
|
+
"text": text,
|
|
77
78
|
}
|
|
78
79
|
if not self.voice:
|
|
79
80
|
logging.warning(
|
|
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
|
|
|
67
67
|
from pyffmpeg import FFmpeg
|
|
68
68
|
|
|
69
69
|
ff = FFmpeg()
|
|
70
|
-
ff.convert(
|
|
70
|
+
ff.convert(input_file=mp3_path, output_file=wav_path)
|
|
71
71
|
except Exception as e:
|
|
72
72
|
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
|
73
73
|
# use ffmpeg command line
|
|
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|
|
59
59
|
self.headers = {
|
|
60
60
|
"Authorization": f"Bearer {self.chosen_api_key}",
|
|
61
61
|
}
|
|
62
|
-
self.set_model(provider_config
|
|
62
|
+
self.set_model(provider_config["model"])
|
|
63
63
|
|
|
64
|
-
async def _get_reference_id_by_character(self, character: str) -> str:
|
|
64
|
+
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
|
65
65
|
"""获取角色的reference_id
|
|
66
66
|
|
|
67
67
|
Args:
|
|
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|
|
109
109
|
pattern = r"^[a-fA-F0-9]{32}$"
|
|
110
110
|
return bool(re.match(pattern, reference_id.strip()))
|
|
111
111
|
|
|
112
|
-
async def _generate_request(self, text: str) ->
|
|
112
|
+
async def _generate_request(self, text: str) -> ServeTTSRequest:
|
|
113
113
|
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
|
114
114
|
if self.reference_id and self.reference_id.strip():
|
|
115
115
|
# 验证reference_id格式
|
|
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|
|
146
146
|
async for chunk in response.aiter_bytes():
|
|
147
147
|
f.write(chunk)
|
|
148
148
|
return path
|
|
149
|
-
|
|
149
|
+
body = await response.aread()
|
|
150
|
+
text = body.decode("utf-8", errors="replace")
|
|
150
151
|
raise Exception(f"Fish Audio API请求失败: {text}")
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
1
3
|
from google import genai
|
|
2
4
|
from google.genai import types
|
|
3
5
|
from google.genai.errors import APIError
|
|
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|
|
18
20
|
self.provider_config = provider_config
|
|
19
21
|
self.provider_settings = provider_settings
|
|
20
22
|
|
|
21
|
-
api_key: str = provider_config
|
|
22
|
-
api_base: str = provider_config
|
|
23
|
+
api_key: str = provider_config["embedding_api_key"]
|
|
24
|
+
api_base: str = provider_config["embedding_api_base"]
|
|
23
25
|
timeout: int = int(provider_config.get("timeout", 20))
|
|
24
26
|
|
|
25
27
|
http_options = types.HttpOptions(timeout=timeout * 1000)
|
|
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|
|
41
43
|
model=self.model,
|
|
42
44
|
contents=text,
|
|
43
45
|
)
|
|
46
|
+
assert result.embeddings is not None
|
|
47
|
+
assert result.embeddings[0].values is not None
|
|
44
48
|
return result.embeddings[0].values
|
|
45
49
|
except APIError as e:
|
|
46
50
|
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
|
47
51
|
|
|
48
|
-
async def get_embeddings(self,
|
|
52
|
+
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
|
49
53
|
"""批量获取文本的嵌入"""
|
|
50
54
|
try:
|
|
51
55
|
result = await self.client.models.embed_content(
|
|
52
56
|
model=self.model,
|
|
53
|
-
contents=
|
|
57
|
+
contents=cast(types.ContentListUnion, text),
|
|
54
58
|
)
|
|
55
|
-
|
|
59
|
+
assert result.embeddings is not None
|
|
60
|
+
|
|
61
|
+
embeddings: list[list[float]] = []
|
|
62
|
+
for embedding in result.embeddings:
|
|
63
|
+
assert embedding.values is not None
|
|
64
|
+
embeddings.append(embedding.values)
|
|
65
|
+
return embeddings
|
|
56
66
|
except APIError as e:
|
|
57
67
|
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
|
58
68
|
|