AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +16 -4
- astrbot/api/all.py +2 -1
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -34
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +8 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__init__.py +1 -0
- astrbot/cli/__main__.py +18 -197
- astrbot/cli/commands/__init__.py +6 -0
- astrbot/cli/commands/cmd_conf.py +209 -0
- astrbot/cli/commands/cmd_init.py +56 -0
- astrbot/cli/commands/cmd_plug.py +245 -0
- astrbot/cli/commands/cmd_run.py +62 -0
- astrbot/cli/utils/__init__.py +18 -0
- astrbot/cli/utils/basic.py +76 -0
- astrbot/cli/utils/plugin.py +246 -0
- astrbot/cli/utils/version_comparator.py +90 -0
- astrbot/core/__init__.py +17 -19
- astrbot/core/agent/agent.py +14 -0
- astrbot/core/agent/handoff.py +38 -0
- astrbot/core/agent/hooks.py +30 -0
- astrbot/core/agent/mcp_client.py +385 -0
- astrbot/core/agent/message.py +175 -0
- astrbot/core/agent/response.py +14 -0
- astrbot/core/agent/run_context.py +22 -0
- astrbot/core/agent/runners/__init__.py +3 -0
- astrbot/core/agent/runners/base.py +65 -0
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
- astrbot/core/agent/tool.py +285 -0
- astrbot/core/agent/tool_executor.py +17 -0
- astrbot/core/astr_agent_context.py +19 -0
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/astrbot_config_mgr.py +275 -0
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +60 -20
- astrbot/core/config/default.py +1972 -453
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/conversation_mgr.py +285 -75
- astrbot/core/core_lifecycle.py +167 -62
- astrbot/core/db/__init__.py +305 -102
- astrbot/core/db/migration/helper.py +69 -0
- astrbot/core/db/migration/migra_3_to_4.py +357 -0
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/migration/shared_preferences_v3.py +48 -0
- astrbot/core/db/migration/sqlite_v3.py +497 -0
- astrbot/core/db/po.py +259 -55
- astrbot/core/db/sqlite.py +773 -528
- astrbot/core/db/vec_db/base.py +73 -0
- astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
- astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
- astrbot/core/event_bus.py +26 -22
- astrbot/core/exceptions.py +9 -0
- astrbot/core/file_token_service.py +98 -0
- astrbot/core/initial_loader.py +19 -10
- astrbot/core/knowledge_base/chunking/__init__.py +9 -0
- astrbot/core/knowledge_base/chunking/base.py +25 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
- astrbot/core/knowledge_base/chunking/recursive.py +161 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
- astrbot/core/knowledge_base/kb_helper.py +642 -0
- astrbot/core/knowledge_base/kb_mgr.py +330 -0
- astrbot/core/knowledge_base/models.py +120 -0
- astrbot/core/knowledge_base/parsers/__init__.py +13 -0
- astrbot/core/knowledge_base/parsers/base.py +51 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +276 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
- astrbot/core/log.py +21 -15
- astrbot/core/message/components.py +413 -287
- astrbot/core/message/message_event_result.py +35 -24
- astrbot/core/persona_mgr.py +192 -0
- astrbot/core/pipeline/__init__.py +14 -14
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +7 -1
- astrbot/core/pipeline/context_utils.py +107 -0
- astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
- astrbot/core/pipeline/process_stage/stage.py +21 -15
- astrbot/core/pipeline/process_stage/utils.py +125 -0
- astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
- astrbot/core/pipeline/respond/stage.py +142 -101
- astrbot/core/pipeline/result_decorate/stage.py +124 -57
- astrbot/core/pipeline/scheduler.py +21 -16
- astrbot/core/pipeline/session_status_check/stage.py +37 -0
- astrbot/core/pipeline/stage.py +11 -76
- astrbot/core/pipeline/waking_check/stage.py +69 -33
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +107 -129
- astrbot/core/platform/astrbot_message.py +32 -12
- astrbot/core/platform/manager.py +62 -18
- astrbot/core/platform/message_session.py +30 -0
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +9 -4
- astrbot/core/platform/register.py +12 -7
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
- astrbot/core/platform/sources/discord/client.py +129 -0
- astrbot/core/platform/sources/discord/components.py +139 -0
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
- astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
- astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
- astrbot/core/platform/sources/lark/lark_event.py +39 -13
- astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
- astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
- astrbot/core/platform/sources/satori/satori_event.py +432 -0
- astrbot/core/platform/sources/slack/client.py +164 -0
- astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
- astrbot/core/platform/sources/slack/slack_event.py +253 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
- astrbot/core/platform/sources/telegram/tg_event.py +136 -36
- astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
- astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
- astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
- astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
- astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
- astrbot/core/platform_message_history_mgr.py +49 -0
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +154 -98
- astrbot/core/provider/func_tool_manager.py +446 -458
- astrbot/core/provider/manager.py +345 -207
- astrbot/core/provider/provider.py +188 -73
- astrbot/core/provider/register.py +9 -7
- astrbot/core/provider/sources/anthropic_source.py +295 -115
- astrbot/core/provider/sources/azure_tts_source.py +224 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/dashscope_tts.py +138 -14
- astrbot/core/provider/sources/edge_tts_source.py +24 -19
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
- astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
- astrbot/core/provider/sources/gemini_source.py +310 -132
- astrbot/core/provider/sources/gemini_tts_source.py +81 -0
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
- astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
- astrbot/core/provider/sources/openai_embedding_source.py +40 -0
- astrbot/core/provider/sources/openai_source.py +241 -145
- astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
- astrbot/core/provider/sources/volcengine_tts.py +115 -0
- astrbot/core/provider/sources/whisper_api_source.py +18 -13
- astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/provider/sources/zhipu_source.py +6 -73
- astrbot/core/star/__init__.py +43 -11
- astrbot/core/star/config.py +17 -18
- astrbot/core/star/context.py +362 -138
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +111 -35
- astrbot/core/star/filter/command_group.py +46 -34
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +45 -12
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -15
- astrbot/core/star/register/star.py +41 -13
- astrbot/core/star/register/star_handler.py +236 -86
- astrbot/core/star/session_llm_manager.py +280 -0
- astrbot/core/star/session_plugin_manager.py +170 -0
- astrbot/core/star/star.py +36 -43
- astrbot/core/star/star_handler.py +47 -85
- astrbot/core/star/star_manager.py +442 -260
- astrbot/core/star/star_tools.py +167 -45
- astrbot/core/star/updator.py +17 -20
- astrbot/core/umop_config_router.py +106 -0
- astrbot/core/updator.py +38 -13
- astrbot/core/utils/astrbot_path.py +39 -0
- astrbot/core/utils/command_parser.py +1 -1
- astrbot/core/utils/io.py +119 -60
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +11 -10
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/path_util.py +63 -62
- astrbot/core/utils/pip_installer.py +37 -15
- astrbot/core/utils/session_lock.py +29 -0
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +174 -34
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/local_strategy.py +386 -238
- astrbot/core/utils/t2i/network_strategy.py +109 -49
- astrbot/core/utils/t2i/renderer.py +29 -14
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +111 -0
- astrbot/core/utils/tencent_record_helper.py +115 -1
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +112 -65
- astrbot/dashboard/routes/__init__.py +20 -13
- astrbot/dashboard/routes/auth.py +20 -9
- astrbot/dashboard/routes/chat.py +297 -141
- astrbot/dashboard/routes/config.py +652 -55
- astrbot/dashboard/routes/conversation.py +107 -37
- astrbot/dashboard/routes/file.py +26 -0
- astrbot/dashboard/routes/knowledge_base.py +1244 -0
- astrbot/dashboard/routes/log.py +27 -2
- astrbot/dashboard/routes/persona.py +202 -0
- astrbot/dashboard/routes/plugin.py +197 -139
- astrbot/dashboard/routes/route.py +27 -7
- astrbot/dashboard/routes/session_management.py +354 -0
- astrbot/dashboard/routes/stat.py +85 -18
- astrbot/dashboard/routes/static_file.py +5 -2
- astrbot/dashboard/routes/t2i.py +233 -0
- astrbot/dashboard/routes/tools.py +184 -120
- astrbot/dashboard/routes/update.py +59 -36
- astrbot/dashboard/server.py +96 -36
- astrbot/dashboard/utils.py +165 -0
- astrbot-4.7.0.dist-info/METADATA +294 -0
- astrbot-4.7.0.dist-info/RECORD +274 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
- astrbot/core/db/plugin/sqlite_impl.py +0 -112
- astrbot/core/db/sqlite_init.sql +0 -50
- astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
- astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
- astrbot/core/platform/sources/gewechat/client.py +0 -806
- astrbot/core/platform/sources/gewechat/downloader.py +0 -55
- astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
- astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
- astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
- astrbot/core/provider/sources/dashscope_source.py +0 -203
- astrbot/core/provider/sources/dify_source.py +0 -281
- astrbot/core/provider/sources/llmtuner_source.py +0 -132
- astrbot/core/rag/embedding/openai_source.py +0 -20
- astrbot/core/rag/knowledge_db_mgr.py +0 -94
- astrbot/core/rag/store/__init__.py +0 -9
- astrbot/core/rag/store/chroma_db.py +0 -42
- astrbot/core/utils/dify_api_client.py +0 -152
- astrbot-3.5.6.dist-info/METADATA +0 -249
- astrbot-3.5.6.dist-info/RECORD +0 -158
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import hashlib
|
|
3
|
+
import json
|
|
4
|
+
import re
|
|
5
|
+
import secrets
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from xml.sax.saxutils import escape
|
|
10
|
+
|
|
11
|
+
from httpx import AsyncClient, Timeout
|
|
12
|
+
|
|
13
|
+
from astrbot.core.config.default import VERSION
|
|
14
|
+
|
|
15
|
+
from ..entities import ProviderType
|
|
16
|
+
from ..provider import TTSProvider
|
|
17
|
+
from ..register import register_provider_adapter
|
|
18
|
+
|
|
19
|
+
# Directory that receives the synthesized audio files; created eagerly at
# import time so the providers below can write into it.
TEMP_DIR = Path("data") / "temp" / "azure_tts"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class OTTSProvider:
    """Client for an "OTTS" text-to-speech endpoint that uses signed requests.

    Use the instance as an async context manager: the HTTP client is created
    in ``__aenter__`` and closed in ``__aexit__``.
    """

    def __init__(self, config: dict):
        """Store the endpoint settings.

        Args:
            config: Mapping providing ``OTTS_SKEY`` (secret mixed into the
                request signature), ``OTTS_URL`` (synthesis endpoint) and
                ``OTTS_AUTH_TIME`` (endpoint returning the server timestamp).

        Raises:
            KeyError: If any of the three keys is missing.
        """
        self.skey = config["OTTS_SKEY"]
        self.api_url = config["OTTS_URL"]
        self.auth_time_url = config["OTTS_AUTH_TIME"]
        # Difference (server - local) in seconds, refreshed by _sync_time().
        self.time_offset = 0
        # Local time of the last successful sync; 0 means "never synced".
        self.last_sync_time = 0
        self.timeout = Timeout(10.0)
        self.retry_count = 3
        self.client = None

    async def __aenter__(self):
        """Create the HTTP client and return this provider."""
        self.client = AsyncClient(timeout=self.timeout)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client and drop the reference to it."""
        client, self.client = self.client, None
        if client:
            # Dropping the reference first guarantees that a closed client is
            # never reused by a later call made outside the context manager.
            await client.aclose()

    async def _sync_time(self):
        """Refresh ``time_offset`` from the server's ``timestamp`` field.

        Any failure is tolerated as long as the last successful sync happened
        within the past hour; the previous offset is then kept.

        Raises:
            RuntimeError: If syncing fails and the last successful sync is
                older than 3600 seconds (or never happened).
        """
        try:
            response = await self.client.get(self.auth_time_url)
            response.raise_for_status()
            server_time = int(response.json()["timestamp"])
            local_time = int(time.time())
            self.time_offset = server_time - local_time
            self.last_sync_time = local_time
        except Exception as e:
            if time.time() - self.last_sync_time > 3600:
                raise RuntimeError("时间同步失败") from e

    async def _generate_signature(self) -> str:
        """Build the ``sign`` query value for one request.

        Returns:
            ``"{timestamp}-{nonce}-0-{md5}"`` where ``md5`` is the hex digest of
            ``"{path}-{timestamp}-{nonce}-0-{skey}"`` and ``path`` is the API
            URL with scheme and host removed (``"/"`` when nothing remains).
        """
        await self._sync_time()
        timestamp = int(time.time()) + self.time_offset
        nonce = "".join(
            secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
        )
        path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
        digest = hashlib.md5(
            f"{path}-{timestamp}-{nonce}-0-{self.skey}".encode()
        ).hexdigest()
        return f"{timestamp}-{nonce}-0-{digest}"

    async def get_audio(self, text: str, voice_params: dict) -> str:
        """Synthesize ``text`` and save the audio to a new file in ``TEMP_DIR``.

        Args:
            text: Text to synthesize.
            voice_params: Mapping with the keys ``voice``, ``style``, ``role``,
                ``rate`` and ``volume``; the values are sent as form fields.

        Returns:
            Absolute path of the written ``.wav`` file.

        Raises:
            RuntimeError: If time sync fails, or if every one of the
                ``retry_count`` attempts fails.
            KeyError: If ``voice_params`` lacks one of the keys (surfaced
                wrapped in the ``RuntimeError`` of the last attempt).
        """
        file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav"
        signature = await self._generate_signature()
        for attempt in range(self.retry_count):
            try:
                response = await self.client.post(
                    f"{self.api_url}?sign={signature}",
                    data={
                        "text": text,
                        "voice": voice_params["voice"],
                        "style": voice_params["style"],
                        "role": voice_params["role"],
                        "rate": voice_params["rate"],
                        "volume": voice_params["volume"],
                    },
                    headers={
                        "User-Agent": f"AstrBot/{VERSION}",
                        "UAK": "AstrBot/AzureTTS",
                    },
                )
                response.raise_for_status()
                file_path.parent.mkdir(parents=True, exist_ok=True)
                with file_path.open("wb") as f:
                    async for chunk in response.aiter_bytes(4096):
                        f.write(chunk)
                return str(file_path.resolve())
            except Exception as e:
                # Do not leave a truncated audio file behind for a failed attempt.
                try:
                    file_path.unlink(missing_ok=True)
                except OSError:
                    pass
                if attempt == self.retry_count - 1:
                    raise RuntimeError(f"OTTS请求失败: {e!s}") from e
                # Linear back-off: 0.5s, 1.0s, ...
                await asyncio.sleep(0.5 * (attempt + 1))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class AzureNativeProvider(TTSProvider):
    """TTS provider that calls the Azure Speech REST endpoint directly.

    Use the instance as an async context manager: the HTTP client is created
    in ``__aenter__`` and closed in ``__aexit__``.
    """

    def __init__(self, provider_config: dict, provider_settings: dict):
        """Read and validate the Azure settings.

        Args:
            provider_config: Provider configuration. Keys read here:
                ``azure_tts_subscription_key``, ``azure_tts_region`` and the
                voice settings ``azure_tts_voice`` / ``_style`` / ``_role`` /
                ``_rate`` / ``_volume``.
            provider_settings: Forwarded unchanged to the base class.

        Raises:
            ValueError: If the subscription key is not exactly 32
                alphanumeric characters (after stripping whitespace).
        """
        super().__init__(provider_config, provider_settings)
        self.subscription_key = provider_config.get(
            "azure_tts_subscription_key",
            "",
        ).strip()
        if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
            raise ValueError("无效的Azure订阅密钥")
        self.region = provider_config.get("azure_tts_region", "eastus").strip()
        self.endpoint = (
            f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
        )
        self.client = None
        self.token = None
        # Local time after which the cached token is refreshed.
        self.token_expire = 0
        self.voice_params = {
            "voice": provider_config.get("azure_tts_voice", "zh-CN-YunxiaNeural"),
            "style": provider_config.get("azure_tts_style", "cheerful"),
            "role": provider_config.get("azure_tts_role", "Boy"),
            "rate": provider_config.get("azure_tts_rate", "1"),
            "volume": provider_config.get("azure_tts_volume", "100"),
        }

    @staticmethod
    def _ssml_attr(value) -> str:
        """Return ``value`` escaped for use inside a quoted SSML attribute.

        ``escape`` alone leaves quotes untouched, which would terminate the
        single-quoted attributes used below; non-string config values (for
        example a numeric rate) are converted to text first.
        """
        return escape(str(value), {"'": "&apos;", '"': "&quot;"})

    async def __aenter__(self):
        """Create the HTTP client with the SSML request headers."""
        self.client = AsyncClient(
            headers={
                "User-Agent": f"AstrBot/{VERSION}",
                "Content-Type": "application/ssml+xml",
                "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
            },
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client and drop the reference to it."""
        client, self.client = self.client, None
        if client:
            await client.aclose()

    async def _refresh_token(self):
        """Fetch a new access token using the subscription key.

        The token is cached for 540 seconds.

        Raises:
            httpx.HTTPStatusError: If the token endpoint answers with an
                error status (raised by ``raise_for_status``).
        """
        token_url = (
            f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
        )
        response = await self.client.post(
            token_url,
            headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
        )
        response.raise_for_status()
        self.token = response.text
        self.token_expire = time.time() + 540

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and save the audio to a new file in ``TEMP_DIR``.

        Args:
            text: Text to synthesize; it is XML-escaped before being embedded
                in the SSML document.

        Returns:
            Absolute path of the written ``.wav`` file.

        Raises:
            httpx.HTTPStatusError: If the token or synthesis request answers
                with an error status.
        """
        if not self.token or time.time() > self.token_expire:
            await self._refresh_token()
        file_path = TEMP_DIR / f"azure-{uuid.uuid4()}.wav"
        voice, style, role, rate, volume = (
            self._ssml_attr(self.voice_params[key])
            for key in ("voice", "style", "role", "rate", "volume")
        )
        ssml = (
            "<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis'\n"
            "xmlns:mstts='http://www.w3.org/2001/mstts' xml:lang='zh-CN'>\n"
            f"<voice name='{voice}'>\n"
            f"<mstts:express-as style='{style}'\n"
            f"role='{role}'>\n"
            f"<prosody rate='{rate}'\n"
            f"volume='{volume}'>\n"
            f"{escape(text)}\n"
            "</prosody>\n"
            "</mstts:express-as>\n"
            "</voice>\n"
            "</speak>"
        )
        response = await self.client.post(
            self.endpoint,
            content=ssml,
            headers={
                "Authorization": f"Bearer {self.token}",
                "User-Agent": f"AstrBot/{VERSION}",
            },
        )
        response.raise_for_status()
        file_path.parent.mkdir(parents=True, exist_ok=True)
        with file_path.open("wb") as f:
            for chunk in response.iter_bytes(4096):
                f.write(chunk)
        return str(file_path.resolve())
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH)
class AzureTTSProvider(TTSProvider):
    """Registered Azure TTS adapter that delegates to a concrete backend.

    The value of ``azure_tts_subscription_key`` selects the backend: a
    32-character alphanumeric key selects ``AzureNativeProvider``, while a
    value of the form ``other[{...json...}]`` selects ``OTTSProvider``.
    """

    def __init__(self, provider_config: dict, provider_settings: dict):
        """Pick the backend from ``azure_tts_subscription_key``.

        Raises:
            ValueError: If the key matches neither supported format or the
                embedded OTTS JSON is invalid.
        """
        super().__init__(provider_config, provider_settings)
        key_value = provider_config.get("azure_tts_subscription_key", "")
        self.provider = self._parse_provider(key_value, provider_config)

    def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
        """Build the backend described by ``key_value``.

        Args:
            key_value: Raw subscription-key setting. Surrounding whitespace is
                ignored, matching what ``AzureNativeProvider`` does itself.
            config: Full provider configuration, passed to
                ``AzureNativeProvider`` when the native backend is selected.

        Returns:
            An ``AzureNativeProvider`` or an ``OTTSProvider`` instance. The
            latter is a plain class, not a ``TTSProvider`` subclass.

        Raises:
            ValueError: On a malformed ``other[...]`` value, invalid or
                non-object JSON, missing OTTS keys, or an invalid key format.
        """
        key_value = key_value.strip()
        if key_value.lower().startswith("other["):
            try:
                # The prefix check above is case-insensitive, so the pattern
                # must be as well; otherwise "Other[...]" is rejected here.
                match = re.match(
                    r"other\[(.*)\]", key_value, re.DOTALL | re.IGNORECASE
                )
                if not match:
                    raise ValueError("无效的other[...]格式,应形如 other[{...}]")
                json_str = match.group(1).strip()
                otts_config = json.loads(json_str)
                if not isinstance(otts_config, dict):
                    # e.g. other[[1, 2]] parses fine but has no keys to check.
                    raise ValueError("OTTS配置必须是JSON对象")
                required = {"OTTS_SKEY", "OTTS_URL", "OTTS_AUTH_TIME"}
                if missing := required - otts_config.keys():
                    # Sorted so the message is deterministic.
                    raise ValueError(f"缺少OTTS参数: {', '.join(sorted(missing))}")
                return OTTSProvider(otts_config)
            except json.JSONDecodeError as e:
                error_msg = (
                    f"JSON解析失败,请检查格式(错误位置:行 {e.lineno} 列 {e.colno})\n"
                    f"错误详情: {e.msg}\n"
                    f"错误上下文: {json_str[max(0, e.pos - 30) : e.pos + 30]}"
                )
                raise ValueError(error_msg) from e
            except KeyError as e:
                raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
        if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
            return AzureNativeProvider(config, self.provider_settings)
        raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` with the selected backend.

        For the OTTS backend the voice settings are read from the provider
        configuration without defaults and passed along with the text.

        Returns:
            The path returned by the backend's ``get_audio``.
        """
        if isinstance(self.provider, OTTSProvider):
            voice_params = {
                name: self.provider_config.get(f"azure_tts_{name}")
                for name in ("voice", "style", "role", "rate", "volume")
            }
            async with self.provider as provider:
                return await provider.get_audio(text, voice_params)
        async with self.provider as provider:
            return await provider.get_audio(text)
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import aiohttp
|
|
4
|
+
|
|
5
|
+
from astrbot import logger
|
|
6
|
+
|
|
7
|
+
from ..entities import ProviderType, RerankResult
|
|
8
|
+
from ..provider import RerankProvider
|
|
9
|
+
from ..register import register_provider_adapter
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BailianRerankError(Exception):
    """Base class for errors raised by the Bailian rerank service adapter."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BailianAPIError(BailianRerankError):
    """The Bailian API answered with an error."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BailianNetworkError(BailianRerankError):
    """A network request to the Bailian API failed."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@register_provider_adapter(
|
|
31
|
+
"bailian_rerank", "阿里云百炼文本排序适配器", provider_type=ProviderType.RERANK
|
|
32
|
+
)
|
|
33
|
+
class BailianRerankProvider(RerankProvider):
|
|
34
|
+
"""阿里云百炼文本重排序适配器."""
|
|
35
|
+
|
|
36
|
+
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
37
|
+
super().__init__(provider_config, provider_settings)
|
|
38
|
+
self.provider_config = provider_config
|
|
39
|
+
self.provider_settings = provider_settings
|
|
40
|
+
|
|
41
|
+
# API配置
|
|
42
|
+
self.api_key = provider_config.get("rerank_api_key") or os.getenv(
|
|
43
|
+
"DASHSCOPE_API_KEY", ""
|
|
44
|
+
)
|
|
45
|
+
if not self.api_key:
|
|
46
|
+
raise ValueError("阿里云百炼 API Key 不能为空。")
|
|
47
|
+
|
|
48
|
+
self.model = provider_config.get("rerank_model", "qwen3-rerank")
|
|
49
|
+
self.timeout = provider_config.get("timeout", 30)
|
|
50
|
+
self.return_documents = provider_config.get("return_documents", False)
|
|
51
|
+
self.instruct = provider_config.get("instruct", "")
|
|
52
|
+
|
|
53
|
+
self.base_url = provider_config.get(
|
|
54
|
+
"rerank_api_base",
|
|
55
|
+
"https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# 设置HTTP客户端
|
|
59
|
+
headers = {
|
|
60
|
+
"Authorization": f"Bearer {self.api_key}",
|
|
61
|
+
"Content-Type": "application/json",
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
self.client = aiohttp.ClientSession(
|
|
65
|
+
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# 设置模型名称
|
|
69
|
+
self.set_model(self.model)
|
|
70
|
+
|
|
71
|
+
logger.info(f"AstrBot 百炼 Rerank 初始化完成。模型: {self.model}")
|
|
72
|
+
|
|
73
|
+
def _build_payload(
|
|
74
|
+
self, query: str, documents: list[str], top_n: int | None
|
|
75
|
+
) -> dict:
|
|
76
|
+
"""构建请求载荷
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
query: 查询文本
|
|
80
|
+
documents: 文档列表
|
|
81
|
+
top_n: 返回前N个结果,如果为None则返回所有结果
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
请求载荷字典
|
|
85
|
+
"""
|
|
86
|
+
base = {"model": self.model, "input": {"query": query, "documents": documents}}
|
|
87
|
+
|
|
88
|
+
params = {
|
|
89
|
+
k: v
|
|
90
|
+
for k, v in [
|
|
91
|
+
("top_n", top_n if top_n is not None and top_n > 0 else None),
|
|
92
|
+
("return_documents", True if self.return_documents else None),
|
|
93
|
+
(
|
|
94
|
+
"instruct",
|
|
95
|
+
self.instruct
|
|
96
|
+
if self.instruct and self.model == "qwen3-rerank"
|
|
97
|
+
else None,
|
|
98
|
+
),
|
|
99
|
+
]
|
|
100
|
+
if v is not None
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
if params:
|
|
104
|
+
base["parameters"] = params
|
|
105
|
+
|
|
106
|
+
return base
|
|
107
|
+
|
|
108
|
+
def _parse_results(self, data: dict) -> list[RerankResult]:
|
|
109
|
+
"""解析API响应结果
|
|
110
|
+
|
|
111
|
+
Args:
|
|
112
|
+
data: API响应数据
|
|
113
|
+
|
|
114
|
+
Returns:
|
|
115
|
+
重排序结果列表
|
|
116
|
+
|
|
117
|
+
Raises:
|
|
118
|
+
BailianAPIError: API返回错误
|
|
119
|
+
KeyError: 结果缺少必要字段
|
|
120
|
+
"""
|
|
121
|
+
# 检查响应状态
|
|
122
|
+
if data.get("code", "200") != "200":
|
|
123
|
+
raise BailianAPIError(
|
|
124
|
+
f"百炼 API 错误: {data.get('code')} – {data.get('message', '')}"
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
results = data.get("output", {}).get("results", [])
|
|
128
|
+
if not results:
|
|
129
|
+
logger.warning(f"百炼 Rerank 返回空结果: {data}")
|
|
130
|
+
return []
|
|
131
|
+
|
|
132
|
+
# 转换为RerankResult对象,使用.get()避免KeyError
|
|
133
|
+
rerank_results = []
|
|
134
|
+
for idx, result in enumerate(results):
|
|
135
|
+
try:
|
|
136
|
+
index = result.get("index", idx)
|
|
137
|
+
relevance_score = result.get("relevance_score", 0.0)
|
|
138
|
+
|
|
139
|
+
if relevance_score is None:
|
|
140
|
+
logger.warning(f"结果 {idx} 缺少 relevance_score,使用默认值 0.0")
|
|
141
|
+
relevance_score = 0.0
|
|
142
|
+
|
|
143
|
+
rerank_result = RerankResult(
|
|
144
|
+
index=index, relevance_score=relevance_score
|
|
145
|
+
)
|
|
146
|
+
rerank_results.append(rerank_result)
|
|
147
|
+
except Exception as e:
|
|
148
|
+
logger.warning(f"解析结果 {idx} 时出错: {e}, result={result}")
|
|
149
|
+
continue
|
|
150
|
+
|
|
151
|
+
return rerank_results
|
|
152
|
+
|
|
153
|
+
def _log_usage(self, data: dict) -> None:
|
|
154
|
+
"""记录使用量信息
|
|
155
|
+
|
|
156
|
+
Args:
|
|
157
|
+
data: API响应数据
|
|
158
|
+
"""
|
|
159
|
+
tokens = data.get("usage", {}).get("total_tokens", 0)
|
|
160
|
+
if tokens > 0:
|
|
161
|
+
logger.debug(f"百炼 Rerank 消耗 Token: {tokens}")
|
|
162
|
+
|
|
163
|
+
async def rerank(
    self,
    query: str,
    documents: list[str],
    top_n: int | None = None,
) -> list[RerankResult]:
    """Rerank ``documents`` against ``query`` through the rerank HTTP API.

    Args:
        query: Query text. A blank query short-circuits to an empty result.
        documents: Candidate documents. Only the first 500 are sent; any
            excess is dropped with a warning.
        top_n: Forwarded unchanged to ``_build_payload``.

    Returns:
        The list produced by ``_parse_results``, or an empty list when
        ``documents`` is empty or ``query`` is blank.

    Raises:
        BailianNetworkError: An ``aiohttp.ClientError`` was raised while
            sending the request or reading the response.
        BailianRerankError: Any other failure. Errors that already are
            ``BailianRerankError`` propagate unchanged.
    """
    if not documents:
        logger.warning("文档列表为空,返回空结果")
        return []

    if not query.strip():
        logger.warning("查询文本为空,返回空结果")
        return []

    # Enforce the document-count limit before building the request.
    doc_count = len(documents)
    if doc_count > 500:
        logger.warning(
            f"文档数量({doc_count})超过限制(500),将截断前500个文档"
        )
        documents = documents[:500]

    try:
        request_body = self._build_payload(query, documents, top_n)

        logger.debug(
            f"百炼 Rerank 请求: query='{query[:50]}...', 文档数量={len(documents)}"
        )

        async with self.client.post(self.base_url, json=request_body) as response:
            response.raise_for_status()
            response_body = await response.json()

            # Parse the ranking first, then record the reported usage.
            ranked = self._parse_results(response_body)
            self._log_usage(response_body)

            logger.debug(f"百炼 Rerank 成功返回 {len(ranked)} 个结果")

            return ranked

    except aiohttp.ClientError as exc:
        logger.error(f"百炼 Rerank 网络请求失败: {exc}")
        raise BailianNetworkError(f"网络请求失败: {exc}") from exc
    except BailianRerankError:
        # Already a domain error; do not wrap it a second time.
        raise
    except Exception as exc:
        logger.error(f"百炼 Rerank 处理失败: {exc}")
        raise BailianRerankError(f"重排序失败: {exc}") from exc
|
|
226
|
+
|
|
227
|
+
async def terminate(self) -> None:
    """Close the HTTP client session, if one is open.

    Errors raised while closing are logged rather than propagated, and
    ``self.client`` is reset to ``None`` whether or not closing succeeded.
    """
    session = self.client
    if not session:
        return

    logger.info("关闭 百炼 Rerank 客户端会话")
    try:
        await session.close()
    except Exception as exc:
        logger.error(f"关闭 百炼 Rerank 客户端时出错: {exc}")
    finally:
        # Drop the reference even when close() failed.
        self.client = None
|
|
@@ -1,14 +1,31 @@
|
|
|
1
|
-
import dashscope
|
|
2
|
-
import uuid
|
|
3
1
|
import asyncio
|
|
4
|
-
|
|
5
|
-
|
|
2
|
+
import base64
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
import aiohttp
|
|
8
|
+
import dashscope
|
|
9
|
+
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from dashscope.aigc.multimodal_conversation import MultiModalConversation
|
|
13
|
+
except (
|
|
14
|
+
ImportError
|
|
15
|
+
): # pragma: no cover - older dashscope versions without Qwen TTS support
|
|
16
|
+
MultiModalConversation = None
|
|
17
|
+
|
|
18
|
+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
19
|
+
|
|
6
20
|
from ..entities import ProviderType
|
|
21
|
+
from ..provider import TTSProvider
|
|
7
22
|
from ..register import register_provider_adapter
|
|
8
23
|
|
|
9
24
|
|
|
10
25
|
@register_provider_adapter(
|
|
11
|
-
"dashscope_tts",
|
|
26
|
+
"dashscope_tts",
|
|
27
|
+
"Dashscope TTS API",
|
|
28
|
+
provider_type=ProviderType.TEXT_TO_SPEECH,
|
|
12
29
|
)
|
|
13
30
|
class ProviderDashscopeTTSAPI(TTSProvider):
|
|
14
31
|
def __init__(
|
|
@@ -19,20 +36,127 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|
|
19
36
|
super().__init__(provider_config, provider_settings)
|
|
20
37
|
self.chosen_api_key: str = provider_config.get("api_key", "")
|
|
21
38
|
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
|
22
|
-
self.set_model(provider_config.get("model"
|
|
39
|
+
self.set_model(provider_config.get("model"))
|
|
23
40
|
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
|
24
41
|
dashscope.api_key = self.chosen_api_key
|
|
25
42
|
|
|
26
43
|
async def get_audio(self, text: str) -> str:
    """Synthesize ``text`` to speech and return the path of the audio file.

    The configured model decides the backend: models recognised by
    ``_is_qwen_tts_model`` use the Qwen TTS path, everything else uses the
    CosyVoice path. The audio is written to the ``temp`` directory under
    the AstrBot data path.

    Args:
        text: Text to synthesize.

    Returns:
        Path of the newly written audio file.

    Raises:
        RuntimeError: No model is configured, or synthesis produced no
            audio content.
    """
    model_name = self.get_model()
    if not model_name:
        raise RuntimeError("Dashscope TTS model is not configured.")

    output_dir = os.path.join(get_astrbot_data_path(), "temp")
    os.makedirs(output_dir, exist_ok=True)

    # Choose the backend coroutine once, then await it.
    synthesize = (
        self._synthesize_with_qwen_tts
        if self._is_qwen_tts_model(model_name)
        else self._synthesize_with_cosyvoice
    )
    payload, suffix = await synthesize(model_name, text)

    if not payload:
        raise RuntimeError(
            "Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable.",
        )

    # A random UUID keeps concurrent requests from overwriting each other.
    file_name = f"dashscope_tts_{uuid.uuid4()}{suffix}"
    target = os.path.join(output_dir, file_name)
    with open(target, "wb") as out:
        out.write(payload)
    return target
|
|
65
|
+
|
|
66
|
+
def _call_qwen_tts(self, model: str, text: str):
    """Invoke ``MultiModalConversation.call`` for a Qwen TTS model.

    This is a plain (non-async) call; within this class it is run from
    ``_synthesize_with_qwen_tts`` in an executor.

    Args:
        model: Name of the Qwen TTS model.
        text: Text to synthesize.

    Returns:
        Whatever ``MultiModalConversation.call`` returns.

    Raises:
        RuntimeError: The installed dashscope SDK does not provide
            ``MultiModalConversation``.
    """
    if MultiModalConversation is None:
        raise RuntimeError(
            "dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models.",
        )

    voice = self.voice
    if not voice:
        # No voice configured: warn and fall back to the default voice.
        logging.warning(
            "No voice specified for Qwen TTS model, using default 'Cherry'.",
        )
        voice = "Cherry"

    return MultiModalConversation.call(
        model=model,
        text=text,
        api_key=self.chosen_api_key,
        voice=voice,
    )
|
|
83
|
+
|
|
84
|
+
async def _synthesize_with_qwen_tts(
|
|
85
|
+
self,
|
|
86
|
+
model: str,
|
|
87
|
+
text: str,
|
|
88
|
+
) -> tuple[bytes | None, str]:
|
|
89
|
+
loop = asyncio.get_event_loop()
|
|
90
|
+
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
|
91
|
+
audio_bytes = await self._extract_audio_from_response(response)
|
|
92
|
+
if not audio_bytes:
|
|
93
|
+
raise RuntimeError(
|
|
94
|
+
f"Audio synthesis failed for model '{model}'. {response}",
|
|
95
|
+
)
|
|
96
|
+
ext = ".wav"
|
|
97
|
+
return audio_bytes, ext
|
|
98
|
+
|
|
99
|
+
async def _extract_audio_from_response(self, response) -> bytes | None:
|
|
100
|
+
output = getattr(response, "output", None)
|
|
101
|
+
audio_obj = getattr(output, "audio", None) if output is not None else None
|
|
102
|
+
if not audio_obj:
|
|
103
|
+
return None
|
|
104
|
+
|
|
105
|
+
data_b64 = getattr(audio_obj, "data", None)
|
|
106
|
+
if data_b64:
|
|
107
|
+
try:
|
|
108
|
+
return base64.b64decode(data_b64)
|
|
109
|
+
except (ValueError, TypeError):
|
|
110
|
+
logging.exception("Failed to decode base64 audio data.")
|
|
111
|
+
return None
|
|
112
|
+
|
|
113
|
+
url = getattr(audio_obj, "url", None)
|
|
114
|
+
if url:
|
|
115
|
+
return await self._download_audio_from_url(url)
|
|
116
|
+
return None
|
|
117
|
+
|
|
118
|
+
async def _download_audio_from_url(self, url: str) -> bytes | None:
|
|
119
|
+
if not url:
|
|
120
|
+
return None
|
|
121
|
+
timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
|
|
122
|
+
try:
|
|
123
|
+
async with (
|
|
124
|
+
aiohttp.ClientSession() as session,
|
|
125
|
+
session.get(
|
|
126
|
+
url,
|
|
127
|
+
timeout=aiohttp.ClientTimeout(total=timeout),
|
|
128
|
+
) as response,
|
|
129
|
+
):
|
|
130
|
+
return await response.read()
|
|
131
|
+
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
|
|
132
|
+
logging.exception(f"Failed to download audio from URL {url}: {e}")
|
|
133
|
+
return None
|
|
134
|
+
|
|
135
|
+
async def _synthesize_with_cosyvoice(
    self,
    model: str,
    text: str,
) -> tuple[bytes | None, str]:
    """Synthesize ``text`` through ``SpeechSynthesizer``.

    The synchronous ``synthesizer.call`` is run in a worker thread so the
    event loop is not blocked.

    Args:
        model: Name of the model passed to ``SpeechSynthesizer``.
        text: Text to synthesize.

    Returns:
        A ``(audio_bytes, extension)`` tuple; the extension is ``".wav"``.
        ``audio_bytes`` may be empty or ``None`` when the SDK returned no
        audio and no dict response to report.

    Raises:
        RuntimeError: The SDK returned no audio and ``get_response()``
            yielded a non-empty dict.
    """
    synthesizer = SpeechSynthesizer(
        model=model,
        voice=self.voice,
        format=AudioFormat.WAV_24000HZ_MONO_16BIT,
    )
    # asyncio.to_thread replaces get_event_loop().run_in_executor(None, ...):
    # calling get_event_loop() inside a coroutine is discouraged, and
    # to_thread uses the same default executor.
    audio_bytes = await asyncio.to_thread(
        synthesizer.call,
        text,
        self.timeout_ms,
    )
    if not audio_bytes:
        resp = synthesizer.get_response()
        if resp and isinstance(resp, dict):
            raise RuntimeError(
                f"Audio synthesis failed for model '{model}'. {resp}".strip(),
            )
    return audio_bytes, ".wav"
|
|
159
|
+
|
|
160
|
+
def _is_qwen_tts_model(self, model: str) -> bool:
|
|
161
|
+
model_lower = model.lower()
|
|
162
|
+
return "tts" in model_lower and model_lower.startswith("qwen")
|