AstrBot 4.5.1__py3-none-any.whl → 4.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +10 -11
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -36
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +7 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__main__.py +5 -5
- astrbot/cli/commands/__init__.py +3 -3
- astrbot/cli/commands/cmd_conf.py +19 -16
- astrbot/cli/commands/cmd_init.py +3 -2
- astrbot/cli/commands/cmd_plug.py +8 -10
- astrbot/cli/commands/cmd_run.py +5 -6
- astrbot/cli/utils/__init__.py +6 -6
- astrbot/cli/utils/basic.py +14 -14
- astrbot/cli/utils/plugin.py +24 -15
- astrbot/cli/utils/version_comparator.py +10 -12
- astrbot/core/__init__.py +8 -6
- astrbot/core/agent/agent.py +3 -2
- astrbot/core/agent/handoff.py +6 -2
- astrbot/core/agent/hooks.py +9 -6
- astrbot/core/agent/mcp_client.py +50 -15
- astrbot/core/agent/message.py +168 -0
- astrbot/core/agent/response.py +2 -1
- astrbot/core/agent/run_context.py +2 -3
- astrbot/core/agent/runners/base.py +10 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
- astrbot/core/agent/tool.py +60 -41
- astrbot/core/agent/tool_executor.py +9 -3
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astrbot_config_mgr.py +29 -9
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +28 -26
- astrbot/core/config/default.py +4 -6
- astrbot/core/conversation_mgr.py +105 -36
- astrbot/core/core_lifecycle.py +68 -54
- astrbot/core/db/__init__.py +33 -18
- astrbot/core/db/migration/helper.py +12 -10
- astrbot/core/db/migration/migra_3_to_4.py +53 -34
- astrbot/core/db/migration/migra_45_to_46.py +1 -1
- astrbot/core/db/migration/shared_preferences_v3.py +2 -1
- astrbot/core/db/migration/sqlite_v3.py +26 -23
- astrbot/core/db/po.py +27 -18
- astrbot/core/db/sqlite.py +74 -45
- astrbot/core/db/vec_db/base.py +10 -14
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
- astrbot/core/event_bus.py +8 -6
- astrbot/core/file_token_service.py +6 -5
- astrbot/core/initial_loader.py +7 -5
- astrbot/core/knowledge_base/chunking/__init__.py +1 -3
- astrbot/core/knowledge_base/chunking/base.py +1 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
- astrbot/core/knowledge_base/chunking/recursive.py +16 -10
- astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
- astrbot/core/knowledge_base/kb_helper.py +30 -17
- astrbot/core/knowledge_base/kb_mgr.py +6 -7
- astrbot/core/knowledge_base/models.py +10 -4
- astrbot/core/knowledge_base/parsers/__init__.py +3 -5
- astrbot/core/knowledge_base/parsers/base.py +1 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
- astrbot/core/knowledge_base/parsers/util.py +1 -1
- astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
- astrbot/core/knowledge_base/retrieval/manager.py +17 -14
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
- astrbot/core/log.py +21 -13
- astrbot/core/message/components.py +123 -217
- astrbot/core/message/message_event_result.py +24 -24
- astrbot/core/persona_mgr.py +20 -11
- astrbot/core/pipeline/__init__.py +7 -7
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +4 -1
- astrbot/core/pipeline/context_utils.py +77 -7
- astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
- astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
- astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
- astrbot/core/pipeline/process_stage/stage.py +13 -10
- astrbot/core/pipeline/process_stage/utils.py +6 -5
- astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
- astrbot/core/pipeline/respond/stage.py +23 -20
- astrbot/core/pipeline/result_decorate/stage.py +31 -23
- astrbot/core/pipeline/scheduler.py +12 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -8
- astrbot/core/pipeline/stage.py +10 -4
- astrbot/core/pipeline/waking_check/stage.py +24 -18
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +76 -110
- astrbot/core/platform/astrbot_message.py +11 -13
- astrbot/core/platform/manager.py +16 -15
- astrbot/core/platform/message_session.py +5 -3
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +4 -4
- astrbot/core/platform/register.py +8 -8
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
- astrbot/core/platform/sources/discord/client.py +9 -6
- astrbot/core/platform/sources/discord/components.py +18 -14
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
- astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
- astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
- astrbot/core/platform/sources/lark/lark_event.py +21 -14
- astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
- astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
- astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
- astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
- astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
- astrbot/core/platform/sources/satori/satori_event.py +34 -25
- astrbot/core/platform/sources/slack/client.py +11 -9
- astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
- astrbot/core/platform/sources/slack/slack_event.py +34 -24
- astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
- astrbot/core/platform/sources/telegram/tg_event.py +32 -18
- astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
- astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
- astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
- astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
- astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
- astrbot/core/platform_message_history_mgr.py +5 -3
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +61 -75
- astrbot/core/provider/func_tool_manager.py +59 -55
- astrbot/core/provider/manager.py +32 -22
- astrbot/core/provider/provider.py +72 -46
- astrbot/core/provider/register.py +7 -7
- astrbot/core/provider/sources/anthropic_source.py +48 -30
- astrbot/core/provider/sources/azure_tts_source.py +17 -13
- astrbot/core/provider/sources/coze_api_client.py +27 -17
- astrbot/core/provider/sources/coze_source.py +104 -87
- astrbot/core/provider/sources/dashscope_source.py +18 -11
- astrbot/core/provider/sources/dashscope_tts.py +36 -23
- astrbot/core/provider/sources/dify_source.py +25 -20
- astrbot/core/provider/sources/edge_tts_source.py +21 -17
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
- astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
- astrbot/core/provider/sources/gemini_source.py +72 -58
- astrbot/core/provider/sources/gemini_tts_source.py +8 -6
- astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
- astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
- astrbot/core/provider/sources/openai_embedding_source.py +6 -8
- astrbot/core/provider/sources/openai_source.py +77 -69
- astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
- astrbot/core/provider/sources/volcengine_tts.py +38 -31
- astrbot/core/provider/sources/whisper_api_source.py +14 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
- astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
- astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
- astrbot/core/star/__init__.py +16 -11
- astrbot/core/star/config.py +10 -15
- astrbot/core/star/context.py +97 -75
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +30 -28
- astrbot/core/star/filter/command_group.py +27 -24
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +4 -2
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -19
- astrbot/core/star/register/star.py +6 -2
- astrbot/core/star/register/star_handler.py +96 -73
- astrbot/core/star/session_llm_manager.py +48 -14
- astrbot/core/star/session_plugin_manager.py +29 -15
- astrbot/core/star/star.py +1 -2
- astrbot/core/star/star_handler.py +13 -8
- astrbot/core/star/star_manager.py +151 -59
- astrbot/core/star/star_tools.py +44 -37
- astrbot/core/star/updator.py +10 -10
- astrbot/core/umop_config_router.py +10 -4
- astrbot/core/updator.py +13 -5
- astrbot/core/utils/astrbot_path.py +3 -5
- astrbot/core/utils/dify_api_client.py +33 -15
- astrbot/core/utils/io.py +66 -42
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +7 -7
- astrbot/core/utils/path_util.py +15 -16
- astrbot/core/utils/pip_installer.py +5 -5
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +45 -20
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/network_strategy.py +35 -26
- astrbot/core/utils/t2i/renderer.py +11 -5
- astrbot/core/utils/t2i/template_manager.py +14 -15
- astrbot/core/utils/tencent_record_helper.py +19 -13
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +43 -40
- astrbot/dashboard/routes/__init__.py +18 -18
- astrbot/dashboard/routes/auth.py +10 -8
- astrbot/dashboard/routes/chat.py +30 -21
- astrbot/dashboard/routes/config.py +92 -75
- astrbot/dashboard/routes/conversation.py +46 -39
- astrbot/dashboard/routes/file.py +4 -2
- astrbot/dashboard/routes/knowledge_base.py +47 -40
- astrbot/dashboard/routes/log.py +9 -4
- astrbot/dashboard/routes/persona.py +19 -16
- astrbot/dashboard/routes/plugin.py +69 -55
- astrbot/dashboard/routes/route.py +3 -1
- astrbot/dashboard/routes/session_management.py +130 -116
- astrbot/dashboard/routes/stat.py +34 -34
- astrbot/dashboard/routes/t2i.py +15 -12
- astrbot/dashboard/routes/tools.py +47 -52
- astrbot/dashboard/routes/update.py +32 -28
- astrbot/dashboard/server.py +30 -26
- astrbot/dashboard/utils.py +8 -4
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/METADATA +2 -1
- astrbot-4.5.2.dist-info/RECORD +261 -0
- astrbot-4.5.1.dist-info/RECORD +0 -260
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0

astrbot/core/provider/sources/minimax_tts_api_source.py
@@ -1,17 +1,22 @@
 import json
 import os
 import uuid
+from collections.abc import AsyncIterator
+
 import aiohttp
-
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+
 from astrbot.api import logger
+from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+
 from ..entities import ProviderType
 from ..provider import TTSProvider
 from ..register import register_provider_adapter


 @register_provider_adapter(
-    "minimax_tts_api",
+    "minimax_tts_api",
+    "MiniMax TTS API",
+    provider_type=ProviderType.TEXT_TO_SPEECH,
 )
 class ProviderMiniMaxTTSAPI(TTSProvider):
     def __init__(
@@ -22,19 +27,21 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
         super().__init__(provider_config, provider_settings)
         self.chosen_api_key: str = provider_config.get("api_key", "")
         self.api_base: str = provider_config.get(
-            "api_base",
+            "api_base",
+            "https://api.minimax.chat/v1/t2a_v2",
         )
         self.group_id: str = provider_config.get("minimax-group-id", "")
         self.set_model(provider_config.get("model", ""))
         self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
         self.is_timber_weight: bool = provider_config.get(
-            "minimax-is-timber-weight",
+            "minimax-is-timber-weight",
+            False,
         )
-        self.timber_weight:
+        self.timber_weight: list[dict[str, str | int]] = json.loads(
             provider_config.get(
                 "minimax-timber-weight",
                 '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
-            )
+            ),
         )

         self.voice_setting: dict = {
@@ -47,7 +54,8 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
             "emotion": provider_config.get("minimax-voice-emotion", "neutral"),
             "latex_read": provider_config.get("minimax-voice-latex", False),
             "english_normalization": provider_config.get(
-                "minimax-voice-english-normalization",
+                "minimax-voice-english-normalization",
+                False,
             ),
         }

@@ -66,7 +74,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):

     def _build_tts_stream_body(self, text: str):
         """构建流式请求体"""
-        dict_body:
+        dict_body: dict[str, object] = {
             "model": self.model_name,
             "text": text,
             "stream": True,
@@ -82,44 +90,46 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
     async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
         """进行流式请求"""
         try:
-            async with
-
+            async with (
+                aiohttp.ClientSession() as session,
+                session.post(
                     self.concat_base_url,
                     headers=self.headers,
                     data=self._build_tts_stream_body(text),
                     timeout=aiohttp.ClientTimeout(total=60),
-            ) as response
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        audio = data.get("data", {}).get("audio")
-                        if audio is not None:
-                            yield audio
-                    except json.JSONDecodeError:
-                        logger.warning(
-                            "Failed to parse JSON data from SSE message"
-                        )
+                ) as response,
+            ):
+                response.raise_for_status()
+
+                buffer = b""
+                while True:
+                    chunk = await response.content.read(8192)
+                    if not chunk:
+                        break
+
+                    buffer += chunk
+
+                    while b"\n\n" in buffer:
+                        try:
+                            message, buffer = buffer.split(b"\n\n", 1)
+                            if message.startswith(b"data: "):
+                                try:
+                                    data = json.loads(message[6:])
+                                    if "extra_info" in data:
                                         continue
-
-
+                                    audio = data.get("data", {}).get("audio")
+                                    if audio is not None:
+                                        yield audio
+                                except json.JSONDecodeError:
+                                    logger.warning(
+                                        "Failed to parse JSON data from SSE message",
+                                    )
+                                    continue
+                        except ValueError:
+                            buffer = buffer[-1024:]

         except aiohttp.ClientError as e:
-            raise Exception(f"MiniMax TTS API请求失败: {
+            raise Exception(f"MiniMax TTS API请求失败: {e!s}")

     async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
         """解码数据流到 audio 比特流"""
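
The rewritten `_call_tts_stream` above reads the response in 8 KB chunks, splits the buffer into SSE messages on blank lines, skips `extra_info` events, and ignores malformed JSON. Below is a minimal standalone sketch of that parsing pattern; the `iter_sse_audio` helper is hypothetical (not part of the package) and is written synchronously for clarity:

import json
from collections.abc import Iterable, Iterator


def iter_sse_audio(chunks: Iterable[bytes]) -> Iterator[str]:
    """Yield the "audio" field of `data: `-prefixed SSE messages.

    Mirrors the buffering logic in the hunk above: messages end at a blank
    line, events carrying "extra_info" are skipped, and malformed or partial
    payloads are ignored.
    """
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        while b"\n\n" in buffer:
            message, buffer = buffer.split(b"\n\n", 1)
            if not message.startswith(b"data: "):
                continue
            try:
                data = json.loads(message[6:])  # strip the b"data: " prefix
            except json.JSONDecodeError:
                continue  # incomplete or malformed SSE payload
            if "extra_info" in data:
                continue  # trailing metadata event, no audio
            audio = data.get("data", {}).get("audio")
            if audio is not None:
                yield audio


# Example: two chunks carrying one audio message and one metadata event.
parts = [b'data: {"data": {"audio": "abc123"}}\n\n', b'data: {"extra_info": {}}\n\n']
print(list(iter_sse_audio(parts)))  # ['abc123']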

astrbot/core/provider/sources/openai_embedding_source.py
@@ -1,7 +1,8 @@
 from openai import AsyncOpenAI
+
+from ..entities import ProviderType
 from ..provider import EmbeddingProvider
 from ..register import register_provider_adapter
-from ..entities import ProviderType


 @register_provider_adapter(
@@ -17,23 +18,20 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
         self.client = AsyncOpenAI(
             api_key=provider_config.get("embedding_api_key"),
             base_url=provider_config.get(
-                "embedding_api_base",
+                "embedding_api_base",
+                "https://api.openai.com/v1",
             ),
             timeout=int(provider_config.get("timeout", 20)),
         )
         self.model = provider_config.get("embedding_model", "text-embedding-3-small")

     async def get_embedding(self, text: str) -> list[float]:
-        """
-        获取文本的嵌入
-        """
+        """获取文本的嵌入"""
         embedding = await self.client.embeddings.create(input=text, model=self.model)
         return embedding.data[0].embedding

     async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
-        """
-        批量获取文本的嵌入
-        """
+        """批量获取文本的嵌入"""
         embeddings = await self.client.embeddings.create(input=texts, model=self.model)
         return [item.embedding for item in embeddings.data]
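
For reference, a minimal standalone call that mirrors what `get_embeddings` above does with the defaults shown in the hunk; the API key and model name are placeholders:

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(
        api_key="sk-...",  # placeholder
        base_url="https://api.openai.com/v1",  # default used in the hunk above
        timeout=20,
    )
    resp = await client.embeddings.create(
        input=["hello", "world"], model="text-embedding-3-small"
    )
    vectors = [item.embedding for item in resp.data]
    print(len(vectors), len(vectors[0]))


asyncio.run(main())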

astrbot/core/provider/sources/openai_source.py
@@ -1,29 +1,31 @@
+import asyncio
 import base64
+import inspect
 import json
 import os
-import inspect
 import random
-import
-import astrbot.core.message.components as Comp
-
-from openai import AsyncOpenAI, AsyncAzureOpenAI
-from openai.types.chat.chat_completion import ChatCompletion
+from collections.abc import AsyncGenerator

+from openai import AsyncAzureOpenAI, AsyncOpenAI
 from openai._exceptions import NotFoundError, UnprocessableEntityError
 from openai.lib.streaming.chat._completions import ChatCompletionStreamState
-from
-from astrbot.core.message.message_event_result import MessageChain
+from openai.types.chat.chat_completion import ChatCompletion

-
+import astrbot.core.message.components as Comp
 from astrbot import logger
-from astrbot.
-from
-from
+from astrbot.api.provider import Provider
+from astrbot.core.agent.message import Message
+from astrbot.core.agent.tool import ToolSet
+from astrbot.core.message.message_event_result import MessageChain
 from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
+from astrbot.core.utils.io import download_image_by_url
+
+from ..register import register_provider_adapter


 @register_provider_adapter(
-    "openai_chat_completion",
+    "openai_chat_completion",
+    "OpenAI API Chat Completion 提供商适配器",
 )
 class ProviderOpenAIOfficial(Provider):
     def __init__(
@@ -38,7 +40,7 @@ class ProviderOpenAIOfficial(Provider):
             default_persona,
         )
         self.chosen_api_key = None
-        self.api_keys:
+        self.api_keys: list = super().get_keys()
         self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
         self.timeout = provider_config.get("timeout", 120)
         if isinstance(self.timeout, str):
@@ -61,7 +63,7 @@ class ProviderOpenAIOfficial(Provider):
         )

         self.default_params = inspect.signature(
-            self.client.chat.completions.create
+            self.client.chat.completions.create,
         ).parameters.keys()

         model_config = provider_config.get("model_config", {})
@@ -101,12 +103,12 @@ class ProviderOpenAIOfficial(Provider):
         except NotFoundError as e:
             raise Exception(f"获取模型列表失败:{e}")

-    async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
+    async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
         if tools:
             model = payloads.get("model", "").lower()
             omit_empty_param_field = "gemini" in model
             tool_list = tools.get_func_desc_openai_style(
-                omit_empty_parameter_field=omit_empty_param_field
+                omit_empty_parameter_field=omit_empty_param_field,
             )
             if tool_list:
                 payloads["tools"] = tool_list
@@ -114,7 +116,7 @@ class ProviderOpenAIOfficial(Provider):
         # 不在默认参数中的参数放在 extra_body 中
         extra_body = {}
         to_del = []
-        for key in payloads
+        for key in payloads:
             if key not in self.default_params:
                 extra_body[key] = payloads[key]
                 to_del.append(key)
@@ -133,12 +135,14 @@ class ProviderOpenAIOfficial(Provider):
             del payloads["tools"]

         completion = await self.client.chat.completions.create(
-            **payloads,
+            **payloads,
+            stream=False,
+            extra_body=extra_body,
         )

         if not isinstance(completion, ChatCompletion):
             raise Exception(
-                f"API 返回的 completion 类型错误:{type(completion)}: {completion}。"
+                f"API 返回的 completion 类型错误:{type(completion)}: {completion}。",
             )

         logger.debug(f"completion: {completion}")
@@ -148,14 +152,16 @@ class ProviderOpenAIOfficial(Provider):
         return llm_response

     async def _query_stream(
-        self,
+        self,
+        payloads: dict,
+        tools: ToolSet | None,
     ) -> AsyncGenerator[LLMResponse, None]:
         """流式查询API,逐步返回结果"""
         if tools:
             model = payloads.get("model", "").lower()
             omit_empty_param_field = "gemini" in model
             tool_list = tools.get_func_desc_openai_style(
-                omit_empty_parameter_field=omit_empty_param_field
+                omit_empty_parameter_field=omit_empty_param_field,
             )
             if tool_list:
                 payloads["tools"] = tool_list
@@ -169,7 +175,7 @@ class ProviderOpenAIOfficial(Provider):
         extra_body.update(custom_extra_body)

         to_del = []
-        for key in payloads
+        for key in payloads:
             if key not in self.default_params:
                 extra_body[key] = payloads[key]
                 to_del.append(key)
@@ -177,7 +183,9 @@ class ProviderOpenAIOfficial(Provider):
             del payloads[key]

         stream = await self.client.chat.completions.create(
-            **payloads,
+            **payloads,
+            stream=True,
+            extra_body=extra_body,
         )

         llm_response = LLMResponse("assistant", is_chunk=True)
@@ -196,7 +204,7 @@ class ProviderOpenAIOfficial(Provider):
             if delta.content:
                 completion_text = delta.content
                 llm_response.result_chain = MessageChain(
-                    chain=[Comp.Plain(completion_text)]
+                    chain=[Comp.Plain(completion_text)],
                 )
                 yield llm_response

@@ -205,7 +213,9 @@ class ProviderOpenAIOfficial(Provider):

         yield llm_response

-    async def parse_openai_completion(
+    async def parse_openai_completion(
+        self, completion: ChatCompletion, tools: ToolSet | None
+    ) -> LLMResponse:
         """解析 OpenAI 的 ChatCompletion 响应"""
         llm_response = LLMResponse("assistant")

@@ -218,7 +228,7 @@ class ProviderOpenAIOfficial(Provider):
             completion_text = str(choice.message.content).strip()
             llm_response.result_chain = MessageChain().message(completion_text)

-        if choice.message.tool_calls:
+        if choice.message.tool_calls and tools is not None:
             # tools call (function calling)
             args_ls = []
             func_name_ls = []
@@ -247,7 +257,7 @@ class ProviderOpenAIOfficial(Provider):

         if choice.finish_reason == "content_filter":
             raise Exception(
-                "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
+                "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
             )

         if llm_response.completion_text is None and not llm_response.tools_call_args:
@@ -260,9 +270,9 @@ class ProviderOpenAIOfficial(Provider):

     async def _prepare_chat_payload(
         self,
-        prompt: str,
+        prompt: str | None,
         image_urls: list[str] | None = None,
-        contexts: list | None = None,
+        contexts: list[dict] | list[Message] | None = None,
         system_prompt: str | None = None,
         tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
         model: str | None = None,
@@ -271,8 +281,12 @@ class ProviderOpenAIOfficial(Provider):
         """准备聊天所需的有效载荷和上下文"""
         if contexts is None:
             contexts = []
-        new_record =
-
+        new_record = None
+        if prompt is not None:
+            new_record = await self.assemble_context(prompt, image_urls)
+        context_query = self._ensure_message_to_dicts(contexts)
+        if new_record:
+            context_query.append(new_record)
         if system_prompt:
             context_query.insert(0, {"role": "system", "content": system_prompt})

@@ -303,16 +317,16 @@ class ProviderOpenAIOfficial(Provider):
         e: Exception,
         payloads: dict,
         context_query: list,
-        func_tool: ToolSet,
+        func_tool: ToolSet | None,
         chosen_key: str,
-        available_api_keys:
+        available_api_keys: list[str],
         retry_cnt: int,
         max_retries: int,
     ) -> tuple:
         """处理API错误并尝试恢复"""
         if "429" in str(e):
             logger.warning(
-                f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
+                f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}",
             )
             # 最后一次不等待
             if retry_cnt < max_retries - 1:
@@ -328,11 +342,10 @@ class ProviderOpenAIOfficial(Provider):
                 context_query,
                 func_tool,
             )
-
-
-        elif "maximum context length" in str(e):
+            raise e
+        if "maximum context length" in str(e):
             logger.warning(
-                f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
+                f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}",
             )
             await self.pop_record(context_query)
             payloads["messages"] = context_query
@@ -344,7 +357,7 @@ class ProviderOpenAIOfficial(Provider):
                 context_query,
                 func_tool,
             )
-
+        if "The model is not a VLM" in str(e):  # siliconcloud
             # 尝试删除所有 image
             new_contexts = await self._remove_image_from_context(context_query)
             payloads["messages"] = new_contexts
@@ -357,36 +370,34 @@ class ProviderOpenAIOfficial(Provider):
                 context_query,
                 func_tool,
             )
-
+        if (
             "Function calling is not enabled" in str(e)
             or ("tool" in str(e).lower() and "support" in str(e).lower())
             or ("function" in str(e).lower() and "support" in str(e).lower())
         ):
             # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
             logger.info(
-                f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
+                f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。",
             )
-
-            del payloads["tools"]
+            payloads.pop("tools", None)
             return False, chosen_key, available_api_keys, payloads, context_query, None
-
-        logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
+        logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")

-
-
+        if "tool" in str(e).lower() and "support" in str(e).lower():
+            logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")

-
-
-
-
-
-
+        if "Connection error." in str(e):
+            proxy = os.environ.get("http_proxy", None)
+            if proxy:
+                logger.error(
+                    f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}",
+                )

-
+        raise e

     async def text_chat(
         self,
-        prompt,
+        prompt=None,
         session_id=None,
         image_urls=None,
         func_tool=None,
@@ -455,7 +466,7 @@ class ProviderOpenAIOfficial(Provider):

     async def text_chat_stream(
         self,
-        prompt
+        prompt=None,
         session_id=None,
         image_urls=None,
         func_tool=None,
@@ -522,10 +533,8 @@ class ProviderOpenAIOfficial(Provider):
             raise Exception("未知错误")
         raise last_exception

-    async def _remove_image_from_context(self, contexts:
-        """
-        从上下文中删除所有带有 image 的记录
-        """
+    async def _remove_image_from_context(self, contexts: list):
+        """从上下文中删除所有带有 image 的记录"""
         new_contexts = []

         for context in contexts:
@@ -546,14 +555,16 @@ class ProviderOpenAIOfficial(Provider):
     def get_current_key(self) -> str:
         return self.client.api_key

-    def get_keys(self) ->
+    def get_keys(self) -> list[str]:
         return self.api_keys

     def set_key(self, key):
         self.client.api_key = key

     async def assemble_context(
-        self,
+        self,
+        text: str,
+        image_urls: list[str] | None = None,
     ) -> dict:
         """组装成符合 OpenAI 格式的 role 为 user 的消息段"""
         if image_urls:
@@ -577,16 +588,13 @@ class ProviderOpenAIOfficial(Provider):
                 {
                     "type": "image_url",
                     "image_url": {"url": image_data},
-                }
+                },
             )
             return user_content
-
-        return {"role": "user", "content": text}
+        return {"role": "user", "content": text}

     async def encode_image_bs64(self, image_url: str) -> str:
-        """
-        将图片转换为 base64
-        """
+        """将图片转换为 base64"""
         if image_url.startswith("base64://"):
             return image_url.replace("base64://", "data:image/jpeg;base64,")
         with open(image_url, "rb") as f:
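
Several hunks above route payload keys that `chat.completions.create` does not accept into `extra_body` and pass `stream=` explicitly. A small standalone sketch of that filtering pattern follows; the `split_extra_body` helper is hypothetical and not part of AstrBot's API:

import inspect

from openai import AsyncOpenAI


def split_extra_body(client: AsyncOpenAI, payloads: dict) -> tuple[dict, dict]:
    """Separate keys that chat.completions.create does not accept.

    Mirrors the pattern in the hunks above: anything outside the method
    signature is routed through extra_body so OpenAI-compatible backends
    still receive it.
    """
    known = inspect.signature(client.chat.completions.create).parameters.keys()
    extra_body = {k: v for k, v in payloads.items() if k not in known}
    kept = {k: v for k, v in payloads.items() if k in known}
    return kept, extra_body


# Usage (inside an async function):
#   kept, extra_body = split_extra_body(client, payloads)
#   completion = await client.chat.completions.create(
#       **kept, stream=False, extra_body=extra_body
#   )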

astrbot/core/provider/sources/openai_tts_api_source.py
@@ -1,14 +1,19 @@
 import os
 import uuid
-
-from
+
+from openai import NOT_GIVEN, AsyncOpenAI
+
+from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+
 from ..entities import ProviderType
+from ..provider import TTSProvider
 from ..register import register_provider_adapter
-from astrbot.core.utils.astrbot_path import get_astrbot_data_path


 @register_provider_adapter(
-    "openai_tts_api",
+    "openai_tts_api",
+    "OpenAI TTS API",
+    provider_type=ProviderType.TEXT_TO_SPEECH,
 )
 class ProviderOpenAITTSAPI(TTSProvider):
     def __init__(
@@ -26,7 +31,7 @@ class ProviderOpenAITTSAPI(TTSProvider):

         self.client = AsyncOpenAI(
             api_key=self.chosen_api_key,
-            base_url=provider_config.get("api_base"
+            base_url=provider_config.get("api_base"),
             timeout=timeout,
         )

@@ -36,7 +41,10 @@ class ProviderOpenAITTSAPI(TTSProvider):
         temp_dir = os.path.join(get_astrbot_data_path(), "temp")
         path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
         async with self.client.audio.speech.with_streaming_response.create(
-            model=self.model_name,
+            model=self.model_name,
+            voice=self.voice,
+            response_format="wav",
+            input=text,
         ) as response:
             with open(path, "wb") as f:
                 async for chunk in response.iter_bytes(chunk_size=1024):
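
A minimal standalone sketch of the streaming-synthesis pattern used in the hunk above; the API key, model, and voice names here are placeholders:

import asyncio

from openai import AsyncOpenAI


async def synthesize(text: str, path: str) -> None:
    client = AsyncOpenAI(api_key="sk-...")  # placeholder key
    # Same streaming pattern as the hunk above; model and voice are placeholders.
    async with client.audio.speech.with_streaming_response.create(
        model="tts-1",
        voice="alloy",
        response_format="wav",
        input=text,
    ) as response:
        with open(path, "wb") as f:
            async for chunk in response.iter_bytes(chunk_size=1024):
                f.write(chunk)


asyncio.run(synthesize("hello", "out.wav"))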

astrbot/core/provider/sources/sensevoice_selfhosted_source.py
@@ -1,22 +1,24 @@
-"""
-Author: diudiu62
+"""Author: diudiu62
 Date: 2025-02-24 18:04:18
 LastEditTime: 2025-02-25 14:06:30
 """

 import asyncio
-from datetime import datetime
 import os
 import re
+from datetime import datetime
+
 from funasr_onnx import SenseVoiceSmall
 from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
-
-from ..entities import ProviderType
-from astrbot.core.utils.io import download_file
-from ..register import register_provider_adapter
+
 from astrbot.core import logger
+from astrbot.core.utils.io import download_file
 from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav

+from ..entities import ProviderType
+from ..provider import STTProvider
+from ..register import register_provider_adapter
+

 @register_provider_adapter(
     "sensevoice_stt_selfhost",
@@ -30,7 +32,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
         provider_settings: dict,
     ) -> None:
         super().__init__(provider_config, provider_settings)
-        self.set_model(provider_config.get("stt_model"
+        self.set_model(provider_config.get("stt_model"))
         self.model = None
         self.is_emotion = provider_config.get("is_emotion", False)

@@ -39,7 +41,8 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):

         # 将模型加载放到线程池中执行
         self.model = await asyncio.get_event_loop().run_in_executor(
-            None,
+            None,
+            lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
        )

         logger.info("SenseVoice 模型加载完成。")
@@ -55,8 +58,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):

         if silk_header in file_header:
             return True
-
-        return False
+        return False

     async def get_text(self, audio_url: str) -> str:
         try: