AstrBot 4.7.4__py3-none-any.whl → 4.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -1
- astrbot/core/agent/tool.py +7 -2
- astrbot/core/astr_agent_run_util.py +15 -1
- astrbot/core/astr_agent_tool_exec.py +5 -1
- astrbot/core/config/astrbot_config.py +4 -0
- astrbot/core/config/default.py +116 -1
- astrbot/core/core_lifecycle.py +1 -1
- astrbot/core/db/__init__.py +32 -4
- astrbot/core/db/migration/migra_3_to_4.py +2 -0
- astrbot/core/db/migration/sqlite_v3.py +6 -4
- astrbot/core/db/po.py +16 -15
- astrbot/core/db/sqlite.py +56 -1
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +2 -0
- astrbot/core/event_bus.py +6 -1
- astrbot/core/knowledge_base/retrieval/manager.py +5 -1
- astrbot/core/log.py +2 -1
- astrbot/core/message/components.py +9 -3
- astrbot/core/persona_mgr.py +2 -2
- astrbot/core/pipeline/content_safety_check/stage.py +1 -1
- astrbot/core/pipeline/context_utils.py +2 -1
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +1 -1
- astrbot/core/pipeline/process_stage/method/star_request.py +1 -2
- astrbot/core/pipeline/process_stage/stage.py +1 -1
- astrbot/core/pipeline/respond/stage.py +4 -2
- astrbot/core/pipeline/result_decorate/stage.py +68 -21
- astrbot/core/pipeline/scheduler.py +5 -1
- astrbot/core/pipeline/waking_check/stage.py +10 -0
- astrbot/core/platform/astr_message_event.py +5 -3
- astrbot/core/platform/astrbot_message.py +2 -2
- astrbot/core/platform/manager.py +71 -9
- astrbot/core/platform/platform.py +109 -4
- astrbot/core/platform/platform_metadata.py +1 -1
- astrbot/core/platform/register.py +1 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +8 -6
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +13 -8
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +28 -22
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +5 -2
- astrbot/core/platform/sources/discord/client.py +16 -4
- astrbot/core/platform/sources/discord/components.py +2 -2
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +53 -26
- astrbot/core/platform/sources/discord/discord_platform_event.py +29 -8
- astrbot/core/platform/sources/lark/lark_adapter.py +178 -22
- astrbot/core/platform/sources/lark/lark_event.py +39 -4
- astrbot/core/platform/sources/lark/server.py +206 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +3 -5
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +64 -18
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +14 -10
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -11
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +15 -2
- astrbot/core/platform/sources/satori/satori_adapter.py +1 -2
- astrbot/core/platform/sources/slack/client.py +58 -40
- astrbot/core/platform/sources/slack/slack_adapter.py +36 -16
- astrbot/core/platform/sources/slack/slack_event.py +11 -10
- astrbot/core/platform/sources/telegram/tg_adapter.py +2 -3
- astrbot/core/platform/sources/telegram/tg_event.py +23 -27
- astrbot/core/platform/sources/webchat/webchat_adapter.py +97 -31
- astrbot/core/platform/sources/webchat/webchat_event.py +35 -35
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +27 -11
- astrbot/core/platform/sources/wecom/wecom_adapter.py +75 -36
- astrbot/core/platform/sources/wecom/wecom_event.py +3 -3
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +26 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +3 -3
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +27 -5
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +81 -35
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +11 -8
- astrbot/core/platform_message_history_mgr.py +3 -3
- astrbot/core/provider/func_tool_manager.py +3 -3
- astrbot/core/provider/manager.py +130 -74
- astrbot/core/provider/provider.py +12 -1
- astrbot/core/provider/sources/azure_tts_source.py +31 -9
- astrbot/core/provider/sources/bailian_rerank_source.py +4 -0
- astrbot/core/provider/sources/dashscope_tts.py +3 -2
- astrbot/core/provider/sources/edge_tts_source.py +1 -1
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +5 -4
- astrbot/core/provider/sources/gemini_embedding_source.py +15 -5
- astrbot/core/provider/sources/gemini_source.py +12 -10
- astrbot/core/provider/sources/minimax_tts_api_source.py +4 -2
- astrbot/core/provider/sources/openai_embedding_source.py +2 -2
- astrbot/core/provider/sources/openai_source.py +4 -0
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +5 -2
- astrbot/core/provider/sources/vllm_rerank_source.py +1 -0
- astrbot/core/provider/sources/whisper_api_source.py +44 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +6 -2
- astrbot/core/provider/sources/xinference_rerank_source.py +10 -2
- astrbot/core/star/context.py +2 -2
- astrbot/core/star/register/star_handler.py +22 -5
- astrbot/core/star/star_handler.py +85 -4
- astrbot/core/updator.py +3 -3
- astrbot/core/utils/io.py +1 -1
- astrbot/core/utils/session_waiter.py +17 -10
- astrbot/core/utils/shared_preferences.py +32 -0
- astrbot/core/utils/t2i/__init__.py +2 -2
- astrbot/core/utils/t2i/local_strategy.py +25 -31
- astrbot/core/utils/tencent_record_helper.py +2 -2
- astrbot/core/utils/version_comparator.py +6 -3
- astrbot/core/utils/webhook_utils.py +66 -0
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/chat.py +311 -76
- astrbot/dashboard/routes/config.py +14 -5
- astrbot/dashboard/routes/knowledge_base.py +254 -79
- astrbot/dashboard/routes/log.py +13 -8
- astrbot/dashboard/routes/platform.py +100 -0
- astrbot/dashboard/routes/plugin.py +108 -51
- astrbot/dashboard/routes/route.py +2 -0
- astrbot/dashboard/server.py +9 -4
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/METADATA +50 -37
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/RECORD +111 -108
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/WHEEL +0 -0
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/provider/manager.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import traceback
|
|
3
|
+
from typing import Protocol, runtime_checkable
|
|
3
4
|
|
|
4
5
|
from astrbot.core import astrbot_config, logger, sp
|
|
5
6
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
@@ -10,6 +11,7 @@ from .entities import ProviderType
|
|
|
10
11
|
from .provider import (
|
|
11
12
|
EmbeddingProvider,
|
|
12
13
|
Provider,
|
|
14
|
+
Providers,
|
|
13
15
|
RerankProvider,
|
|
14
16
|
STTProvider,
|
|
15
17
|
TTSProvider,
|
|
@@ -17,6 +19,11 @@ from .provider import (
|
|
|
17
19
|
from .register import llm_tools, provider_cls_map
|
|
18
20
|
|
|
19
21
|
|
|
22
|
+
@runtime_checkable
|
|
23
|
+
class HasInitialize(Protocol):
|
|
24
|
+
async def initialize(self) -> None: ...
|
|
25
|
+
|
|
26
|
+
|
|
20
27
|
class ProviderManager:
|
|
21
28
|
def __init__(
|
|
22
29
|
self,
|
|
@@ -48,7 +55,7 @@ class ProviderManager:
|
|
|
48
55
|
"""加载的 Rerank Provider 的实例"""
|
|
49
56
|
self.inst_map: dict[
|
|
50
57
|
str,
|
|
51
|
-
|
|
58
|
+
Providers,
|
|
52
59
|
] = {}
|
|
53
60
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
|
54
61
|
self.llm_tools = llm_tools
|
|
@@ -123,15 +130,13 @@ class ProviderManager:
|
|
|
123
130
|
self.curr_provider_inst = prov
|
|
124
131
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
|
125
132
|
|
|
126
|
-
async def get_provider_by_id(self, provider_id: str) ->
|
|
133
|
+
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
|
127
134
|
"""根据提供商 ID 获取提供商实例"""
|
|
128
135
|
return self.inst_map.get(provider_id)
|
|
129
136
|
|
|
130
137
|
def get_using_provider(
|
|
131
|
-
self,
|
|
132
|
-
|
|
133
|
-
umo=None,
|
|
134
|
-
) -> Provider | STTProvider | TTSProvider | None:
|
|
138
|
+
self, provider_type: ProviderType, umo=None
|
|
139
|
+
) -> Providers | None:
|
|
135
140
|
"""获取正在使用的提供商实例。
|
|
136
141
|
|
|
137
142
|
Args:
|
|
@@ -191,7 +196,6 @@ class ProviderManager:
|
|
|
191
196
|
logger.error(traceback.format_exc())
|
|
192
197
|
logger.error(e)
|
|
193
198
|
|
|
194
|
-
# 设置默认提供商
|
|
195
199
|
selected_provider_id = sp.get(
|
|
196
200
|
"curr_provider",
|
|
197
201
|
self.provider_settings.get("default_provider_id"),
|
|
@@ -210,15 +214,37 @@ class ProviderManager:
|
|
|
210
214
|
scope="global",
|
|
211
215
|
scope_id="global",
|
|
212
216
|
)
|
|
213
|
-
|
|
217
|
+
|
|
218
|
+
temp_provider = (
|
|
219
|
+
self.inst_map.get(selected_provider_id)
|
|
220
|
+
if isinstance(selected_provider_id, str)
|
|
221
|
+
else None
|
|
222
|
+
)
|
|
223
|
+
self.curr_provider_inst = (
|
|
224
|
+
temp_provider if isinstance(temp_provider, Provider) else None
|
|
225
|
+
)
|
|
214
226
|
if not self.curr_provider_inst and self.provider_insts:
|
|
215
227
|
self.curr_provider_inst = self.provider_insts[0]
|
|
216
228
|
|
|
217
|
-
|
|
229
|
+
temp_stt = (
|
|
230
|
+
self.inst_map.get(selected_stt_provider_id)
|
|
231
|
+
if isinstance(selected_stt_provider_id, str)
|
|
232
|
+
else None
|
|
233
|
+
)
|
|
234
|
+
self.curr_stt_provider_inst = (
|
|
235
|
+
temp_stt if isinstance(temp_stt, STTProvider) else None
|
|
236
|
+
)
|
|
218
237
|
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
|
219
238
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
|
220
239
|
|
|
221
|
-
|
|
240
|
+
temp_tts = (
|
|
241
|
+
self.inst_map.get(selected_tts_provider_id)
|
|
242
|
+
if isinstance(selected_tts_provider_id, str)
|
|
243
|
+
else None
|
|
244
|
+
)
|
|
245
|
+
self.curr_tts_provider_inst = (
|
|
246
|
+
temp_tts if isinstance(temp_tts, TTSProvider) else None
|
|
247
|
+
)
|
|
222
248
|
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
|
223
249
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
|
224
250
|
|
|
@@ -358,73 +384,103 @@ class ProviderManager:
|
|
|
358
384
|
|
|
359
385
|
provider_metadata.id = provider_config["id"]
|
|
360
386
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
self.
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
387
|
+
match provider_metadata.provider_type:
|
|
388
|
+
case ProviderType.SPEECH_TO_TEXT:
|
|
389
|
+
# STT 任务
|
|
390
|
+
if not issubclass(cls_type, STTProvider):
|
|
391
|
+
raise TypeError(
|
|
392
|
+
f"Provider class {cls_type} is not a subclass of STTProvider"
|
|
393
|
+
)
|
|
394
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
395
|
+
|
|
396
|
+
if isinstance(inst, HasInitialize):
|
|
397
|
+
await inst.initialize()
|
|
398
|
+
|
|
399
|
+
self.stt_provider_insts.append(inst)
|
|
400
|
+
if (
|
|
401
|
+
self.provider_stt_settings.get("provider_id")
|
|
402
|
+
== provider_config["id"]
|
|
403
|
+
):
|
|
404
|
+
self.curr_stt_provider_inst = inst
|
|
405
|
+
logger.info(
|
|
406
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
|
407
|
+
)
|
|
408
|
+
if not self.curr_stt_provider_inst:
|
|
409
|
+
self.curr_stt_provider_inst = inst
|
|
410
|
+
|
|
411
|
+
case ProviderType.TEXT_TO_SPEECH:
|
|
412
|
+
# TTS 任务
|
|
413
|
+
if not issubclass(cls_type, TTSProvider):
|
|
414
|
+
raise TypeError(
|
|
415
|
+
f"Provider class {cls_type} is not a subclass of TTSProvider"
|
|
416
|
+
)
|
|
417
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
418
|
+
|
|
419
|
+
if isinstance(inst, HasInitialize):
|
|
420
|
+
await inst.initialize()
|
|
421
|
+
|
|
422
|
+
self.tts_provider_insts.append(inst)
|
|
423
|
+
if (
|
|
424
|
+
self.provider_settings.get("provider_id")
|
|
425
|
+
== provider_config["id"]
|
|
426
|
+
):
|
|
427
|
+
self.curr_tts_provider_inst = inst
|
|
428
|
+
logger.info(
|
|
429
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
|
430
|
+
)
|
|
431
|
+
if not self.curr_tts_provider_inst:
|
|
432
|
+
self.curr_tts_provider_inst = inst
|
|
433
|
+
|
|
434
|
+
case ProviderType.CHAT_COMPLETION:
|
|
435
|
+
# 文本生成任务
|
|
436
|
+
if not issubclass(cls_type, Provider):
|
|
437
|
+
raise TypeError(
|
|
438
|
+
f"Provider class {cls_type} is not a subclass of Provider"
|
|
439
|
+
)
|
|
440
|
+
inst = cls_type(
|
|
441
|
+
provider_config,
|
|
442
|
+
self.provider_settings,
|
|
392
443
|
)
|
|
393
|
-
if not self.curr_tts_provider_inst:
|
|
394
|
-
self.curr_tts_provider_inst = inst
|
|
395
|
-
|
|
396
|
-
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
|
397
|
-
# 文本生成任务
|
|
398
|
-
inst = cls_type(
|
|
399
|
-
provider_config,
|
|
400
|
-
self.provider_settings,
|
|
401
|
-
)
|
|
402
444
|
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
445
|
+
if isinstance(inst, HasInitialize):
|
|
446
|
+
await inst.initialize()
|
|
447
|
+
|
|
448
|
+
self.provider_insts.append(inst)
|
|
449
|
+
if (
|
|
450
|
+
self.provider_settings.get("default_provider_id")
|
|
451
|
+
== provider_config["id"]
|
|
452
|
+
):
|
|
453
|
+
self.curr_provider_inst = inst
|
|
454
|
+
logger.info(
|
|
455
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
|
456
|
+
)
|
|
457
|
+
if not self.curr_provider_inst:
|
|
458
|
+
self.curr_provider_inst = inst
|
|
459
|
+
|
|
460
|
+
case ProviderType.EMBEDDING:
|
|
461
|
+
if not issubclass(cls_type, EmbeddingProvider):
|
|
462
|
+
raise TypeError(
|
|
463
|
+
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
|
|
464
|
+
)
|
|
465
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
466
|
+
if isinstance(inst, HasInitialize):
|
|
467
|
+
await inst.initialize()
|
|
468
|
+
self.embedding_provider_insts.append(inst)
|
|
469
|
+
case ProviderType.RERANK:
|
|
470
|
+
if not issubclass(cls_type, RerankProvider):
|
|
471
|
+
raise TypeError(
|
|
472
|
+
f"Provider class {cls_type} is not a subclass of RerankProvider"
|
|
473
|
+
)
|
|
474
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
475
|
+
if isinstance(inst, HasInitialize):
|
|
476
|
+
await inst.initialize()
|
|
477
|
+
self.rerank_provider_insts.append(inst)
|
|
478
|
+
case _:
|
|
479
|
+
# 未知供应商抛出异常,确保inst初始化
|
|
480
|
+
# Should be unreachable
|
|
481
|
+
raise Exception(
|
|
482
|
+
f"未知的提供商类型:{provider_metadata.provider_type}"
|
|
414
483
|
)
|
|
415
|
-
if not self.curr_provider_inst:
|
|
416
|
-
self.curr_provider_inst = inst
|
|
417
|
-
|
|
418
|
-
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
|
419
|
-
inst = cls_type(provider_config, self.provider_settings)
|
|
420
|
-
if getattr(inst, "initialize", None):
|
|
421
|
-
await inst.initialize()
|
|
422
|
-
self.embedding_provider_insts.append(inst)
|
|
423
|
-
elif provider_metadata.provider_type == ProviderType.RERANK:
|
|
424
|
-
inst = cls_type(provider_config, self.provider_settings)
|
|
425
|
-
if getattr(inst, "initialize", None):
|
|
426
|
-
await inst.initialize()
|
|
427
|
-
self.rerank_provider_insts.append(inst)
|
|
428
484
|
|
|
429
485
|
self.inst_map[provider_config["id"]] = inst
|
|
430
486
|
except Exception as e:
|
|
@@ -2,6 +2,7 @@ import abc
|
|
|
2
2
|
import asyncio
|
|
3
3
|
import os
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import TypeAlias, Union
|
|
5
6
|
|
|
6
7
|
from astrbot.core.agent.message import Message
|
|
7
8
|
from astrbot.core.agent.tool import ToolSet
|
|
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
|
|
|
14
15
|
from astrbot.core.provider.register import provider_cls_map
|
|
15
16
|
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
|
16
17
|
|
|
18
|
+
Providers: TypeAlias = Union[
|
|
19
|
+
"Provider",
|
|
20
|
+
"STTProvider",
|
|
21
|
+
"TTSProvider",
|
|
22
|
+
"EmbeddingProvider",
|
|
23
|
+
"RerankProvider",
|
|
24
|
+
]
|
|
25
|
+
|
|
17
26
|
|
|
18
27
|
class AbstractProvider(abc.ABC):
|
|
19
28
|
"""Provider Abstract Class"""
|
|
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
|
|
|
142
151
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
|
143
152
|
|
|
144
153
|
"""
|
|
145
|
-
|
|
154
|
+
if False: # pragma: no cover - make this an async generator for typing
|
|
155
|
+
yield None # type: ignore
|
|
156
|
+
raise NotImplementedError()
|
|
146
157
|
|
|
147
158
|
async def pop_record(self, context: list):
|
|
148
159
|
"""弹出 context 第一条非系统提示词对话记录"""
|
|
@@ -29,15 +29,24 @@ class OTTSProvider:
|
|
|
29
29
|
self.last_sync_time = 0
|
|
30
30
|
self.timeout = Timeout(10.0)
|
|
31
31
|
self.retry_count = 3
|
|
32
|
-
self.
|
|
32
|
+
self._client: AsyncClient | None = None
|
|
33
|
+
|
|
34
|
+
@property
|
|
35
|
+
def client(self) -> AsyncClient:
|
|
36
|
+
if self._client is None:
|
|
37
|
+
raise RuntimeError(
|
|
38
|
+
"Client not initialized. Please use 'async with' context."
|
|
39
|
+
)
|
|
40
|
+
return self._client
|
|
33
41
|
|
|
34
42
|
async def __aenter__(self):
|
|
35
|
-
self.
|
|
43
|
+
self._client = AsyncClient(timeout=self.timeout)
|
|
36
44
|
return self
|
|
37
45
|
|
|
38
46
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
39
|
-
if self.
|
|
40
|
-
await self.
|
|
47
|
+
if self._client:
|
|
48
|
+
await self._client.aclose()
|
|
49
|
+
self._client = None
|
|
41
50
|
|
|
42
51
|
async def _sync_time(self):
|
|
43
52
|
try:
|
|
@@ -90,6 +99,7 @@ class OTTSProvider:
|
|
|
90
99
|
if attempt == self.retry_count - 1:
|
|
91
100
|
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
|
92
101
|
await asyncio.sleep(0.5 * (attempt + 1))
|
|
102
|
+
raise RuntimeError("OTTS未返回音频文件")
|
|
93
103
|
|
|
94
104
|
|
|
95
105
|
class AzureNativeProvider(TTSProvider):
|
|
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
|
|
|
105
115
|
self.endpoint = (
|
|
106
116
|
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
|
107
117
|
)
|
|
108
|
-
self.
|
|
118
|
+
self._client: AsyncClient | None = None
|
|
109
119
|
self.token = None
|
|
110
120
|
self.token_expire = 0
|
|
111
121
|
self.voice_params = {
|
|
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
|
|
|
116
126
|
"volume": provider_config.get("azure_tts_volume", "100"),
|
|
117
127
|
}
|
|
118
128
|
|
|
129
|
+
@property
|
|
130
|
+
def client(self) -> AsyncClient:
|
|
131
|
+
if self._client is None:
|
|
132
|
+
raise RuntimeError(
|
|
133
|
+
"Client not initialized. Please use 'async with' context."
|
|
134
|
+
)
|
|
135
|
+
return self._client
|
|
136
|
+
|
|
119
137
|
async def __aenter__(self):
|
|
120
|
-
self.
|
|
138
|
+
self._client = AsyncClient(
|
|
121
139
|
headers={
|
|
122
140
|
"User-Agent": f"AstrBot/{VERSION}",
|
|
123
141
|
"Content-Type": "application/ssml+xml",
|
|
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
|
|
|
127
145
|
return self
|
|
128
146
|
|
|
129
147
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
130
|
-
if self.
|
|
131
|
-
await self.
|
|
148
|
+
if self._client:
|
|
149
|
+
await self._client.aclose()
|
|
150
|
+
self._client = None
|
|
132
151
|
|
|
133
152
|
async def _refresh_token(self):
|
|
134
153
|
token_url = (
|
|
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
|
|
|
181
200
|
key_value = provider_config.get("azure_tts_subscription_key", "")
|
|
182
201
|
self.provider = self._parse_provider(key_value, provider_config)
|
|
183
202
|
|
|
184
|
-
def _parse_provider(
|
|
203
|
+
def _parse_provider(
|
|
204
|
+
self, key_value: str, config: dict
|
|
205
|
+
) -> OTTSProvider | AzureNativeProvider:
|
|
185
206
|
if key_value.lower().startswith("other["):
|
|
207
|
+
json_str = ""
|
|
186
208
|
try:
|
|
187
209
|
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
|
188
210
|
if not match:
|
|
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|
|
36
36
|
super().__init__(provider_config, provider_settings)
|
|
37
37
|
self.chosen_api_key: str = provider_config.get("api_key", "")
|
|
38
38
|
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
|
39
|
-
self.set_model(provider_config
|
|
39
|
+
self.set_model(provider_config["model"])
|
|
40
40
|
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
|
41
41
|
dashscope.api_key = self.chosen_api_key
|
|
42
42
|
|
|
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|
|
71
71
|
|
|
72
72
|
kwargs = {
|
|
73
73
|
"model": model,
|
|
74
|
-
"
|
|
74
|
+
"messages": None,
|
|
75
75
|
"api_key": self.chosen_api_key,
|
|
76
76
|
"voice": self.voice or "Cherry",
|
|
77
|
+
"text": text,
|
|
77
78
|
}
|
|
78
79
|
if not self.voice:
|
|
79
80
|
logging.warning(
|
|
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
|
|
|
67
67
|
from pyffmpeg import FFmpeg
|
|
68
68
|
|
|
69
69
|
ff = FFmpeg()
|
|
70
|
-
ff.convert(
|
|
70
|
+
ff.convert(input_file=mp3_path, output_file=wav_path)
|
|
71
71
|
except Exception as e:
|
|
72
72
|
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
|
73
73
|
# use ffmpeg command line
|
|
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|
|
59
59
|
self.headers = {
|
|
60
60
|
"Authorization": f"Bearer {self.chosen_api_key}",
|
|
61
61
|
}
|
|
62
|
-
self.set_model(provider_config
|
|
62
|
+
self.set_model(provider_config["model"])
|
|
63
63
|
|
|
64
|
-
async def _get_reference_id_by_character(self, character: str) -> str:
|
|
64
|
+
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
|
65
65
|
"""获取角色的reference_id
|
|
66
66
|
|
|
67
67
|
Args:
|
|
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|
|
109
109
|
pattern = r"^[a-fA-F0-9]{32}$"
|
|
110
110
|
return bool(re.match(pattern, reference_id.strip()))
|
|
111
111
|
|
|
112
|
-
async def _generate_request(self, text: str) ->
|
|
112
|
+
async def _generate_request(self, text: str) -> ServeTTSRequest:
|
|
113
113
|
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
|
114
114
|
if self.reference_id and self.reference_id.strip():
|
|
115
115
|
# 验证reference_id格式
|
|
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|
|
146
146
|
async for chunk in response.aiter_bytes():
|
|
147
147
|
f.write(chunk)
|
|
148
148
|
return path
|
|
149
|
-
|
|
149
|
+
body = await response.aread()
|
|
150
|
+
text = body.decode("utf-8", errors="replace")
|
|
150
151
|
raise Exception(f"Fish Audio API请求失败: {text}")
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import cast
|
|
2
|
+
|
|
1
3
|
from google import genai
|
|
2
4
|
from google.genai import types
|
|
3
5
|
from google.genai.errors import APIError
|
|
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|
|
18
20
|
self.provider_config = provider_config
|
|
19
21
|
self.provider_settings = provider_settings
|
|
20
22
|
|
|
21
|
-
api_key: str = provider_config
|
|
22
|
-
api_base: str = provider_config
|
|
23
|
+
api_key: str = provider_config["embedding_api_key"]
|
|
24
|
+
api_base: str = provider_config["embedding_api_base"]
|
|
23
25
|
timeout: int = int(provider_config.get("timeout", 20))
|
|
24
26
|
|
|
25
27
|
http_options = types.HttpOptions(timeout=timeout * 1000)
|
|
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|
|
41
43
|
model=self.model,
|
|
42
44
|
contents=text,
|
|
43
45
|
)
|
|
46
|
+
assert result.embeddings is not None
|
|
47
|
+
assert result.embeddings[0].values is not None
|
|
44
48
|
return result.embeddings[0].values
|
|
45
49
|
except APIError as e:
|
|
46
50
|
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
|
47
51
|
|
|
48
|
-
async def get_embeddings(self,
|
|
52
|
+
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
|
49
53
|
"""批量获取文本的嵌入"""
|
|
50
54
|
try:
|
|
51
55
|
result = await self.client.models.embed_content(
|
|
52
56
|
model=self.model,
|
|
53
|
-
contents=
|
|
57
|
+
contents=cast(types.ContentListUnion, text),
|
|
54
58
|
)
|
|
55
|
-
|
|
59
|
+
assert result.embeddings is not None
|
|
60
|
+
|
|
61
|
+
embeddings: list[list[float]] = []
|
|
62
|
+
for embedding in result.embeddings:
|
|
63
|
+
assert embedding.values is not None
|
|
64
|
+
embeddings.append(embedding.values)
|
|
65
|
+
return embeddings
|
|
56
66
|
except APIError as e:
|
|
57
67
|
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
|
58
68
|
|
|
@@ -4,6 +4,7 @@ import json
|
|
|
4
4
|
import logging
|
|
5
5
|
import random
|
|
6
6
|
from collections.abc import AsyncGenerator
|
|
7
|
+
from typing import cast
|
|
7
8
|
|
|
8
9
|
from google import genai
|
|
9
10
|
from google.genai import types
|
|
@@ -126,17 +127,17 @@ class ProviderGoogleGenAI(Provider):
|
|
|
126
127
|
) -> types.GenerateContentConfig:
|
|
127
128
|
"""准备查询配置"""
|
|
128
129
|
if not modalities:
|
|
129
|
-
modalities = ["
|
|
130
|
+
modalities = ["TEXT"]
|
|
130
131
|
|
|
131
132
|
# 流式输出不支持图片模态
|
|
132
133
|
if (
|
|
133
134
|
self.provider_settings.get("streaming_response", False)
|
|
134
|
-
and "
|
|
135
|
+
and "IMAGE" in modalities
|
|
135
136
|
):
|
|
136
137
|
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
|
137
|
-
modalities = ["
|
|
138
|
+
modalities = ["TEXT"]
|
|
138
139
|
|
|
139
|
-
tool_list = []
|
|
140
|
+
tool_list: list[types.Tool] | None = []
|
|
140
141
|
model_name = self.get_model()
|
|
141
142
|
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
|
142
143
|
native_search = self.provider_config.get("gm_native_search", False)
|
|
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
213
214
|
logprobs=payloads.get("logprobs"),
|
|
214
215
|
seed=payloads.get("seed"),
|
|
215
216
|
response_modalities=modalities,
|
|
216
|
-
tools=tool_list,
|
|
217
|
+
tools=cast(types.ToolListUnion | None, tool_list),
|
|
217
218
|
safety_settings=self.safety_settings if self.safety_settings else None,
|
|
218
219
|
thinking_config=(
|
|
219
220
|
types.ThinkingConfig(
|
|
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
257
258
|
content_cls: type[types.Content],
|
|
258
259
|
) -> None:
|
|
259
260
|
if contents and isinstance(contents[-1], content_cls):
|
|
261
|
+
assert contents[-1].parts is not None
|
|
260
262
|
contents[-1].parts.extend(part)
|
|
261
263
|
else:
|
|
262
264
|
contents.append(content_cls(parts=part))
|
|
@@ -429,9 +431,9 @@ class ProviderGoogleGenAI(Provider):
|
|
|
429
431
|
None,
|
|
430
432
|
)
|
|
431
433
|
|
|
432
|
-
modalities = ["
|
|
434
|
+
modalities = ["TEXT"]
|
|
433
435
|
if self.provider_config.get("gm_resp_image_modal", False):
|
|
434
|
-
modalities.append("
|
|
436
|
+
modalities.append("IMAGE")
|
|
435
437
|
|
|
436
438
|
conversation = self._prepare_conversation(payloads)
|
|
437
439
|
temperature = payloads.get("temperature", 0.7)
|
|
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
448
450
|
)
|
|
449
451
|
result = await self.client.models.generate_content(
|
|
450
452
|
model=self.get_model(),
|
|
451
|
-
contents=conversation,
|
|
453
|
+
contents=cast(types.ContentListUnion, conversation),
|
|
452
454
|
config=config,
|
|
453
455
|
)
|
|
454
456
|
logger.debug(f"genai result: {result}")
|
|
@@ -488,7 +490,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
488
490
|
logger.warning(
|
|
489
491
|
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
|
490
492
|
)
|
|
491
|
-
modalities = ["
|
|
493
|
+
modalities = ["TEXT"]
|
|
492
494
|
else:
|
|
493
495
|
raise
|
|
494
496
|
continue
|
|
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
|
|
|
524
526
|
)
|
|
525
527
|
result = await self.client.models.generate_content_stream(
|
|
526
528
|
model=self.get_model(),
|
|
527
|
-
contents=conversation,
|
|
529
|
+
contents=cast(types.ContentListUnion, conversation),
|
|
528
530
|
config=config,
|
|
529
531
|
)
|
|
530
532
|
break
|
|
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|
|
87
87
|
|
|
88
88
|
return json.dumps(dict_body)
|
|
89
89
|
|
|
90
|
-
async def _call_tts_stream(self, text: str) -> AsyncIterator[
|
|
90
|
+
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
|
|
91
91
|
"""进行流式请求"""
|
|
92
92
|
try:
|
|
93
93
|
async with (
|
|
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|
|
117
117
|
data = json.loads(message[6:])
|
|
118
118
|
if "extra_info" in data:
|
|
119
119
|
continue
|
|
120
|
-
audio = data.get("data", {}).get(
|
|
120
|
+
audio: str | None = data.get("data", {}).get(
|
|
121
|
+
"audio"
|
|
122
|
+
)
|
|
121
123
|
if audio is not None:
|
|
122
124
|
yield audio
|
|
123
125
|
except json.JSONDecodeError:
|
|
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
|
30
30
|
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
|
31
31
|
return embedding.data[0].embedding
|
|
32
32
|
|
|
33
|
-
async def get_embeddings(self,
|
|
33
|
+
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
|
34
34
|
"""批量获取文本的嵌入"""
|
|
35
|
-
embeddings = await self.client.embeddings.create(input=
|
|
35
|
+
embeddings = await self.client.embeddings.create(input=text, model=self.model)
|
|
36
36
|
return [item.embedding for item in embeddings.data]
|
|
37
37
|
|
|
38
38
|
def get_dim(self) -> int:
|
|
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
|
|
|
284
284
|
if isinstance(tool_call, str):
|
|
285
285
|
# workaround for #1359
|
|
286
286
|
tool_call = json.loads(tool_call)
|
|
287
|
+
if tools is None:
|
|
288
|
+
# 工具集未提供
|
|
289
|
+
# Should be unreachable
|
|
290
|
+
raise Exception("工具集未提供")
|
|
287
291
|
for tool in tools.func_list:
|
|
288
292
|
if (
|
|
289
293
|
tool_call.type == "function"
|