AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +16 -4
- astrbot/api/all.py +2 -1
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -34
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +8 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__init__.py +1 -0
- astrbot/cli/__main__.py +18 -197
- astrbot/cli/commands/__init__.py +6 -0
- astrbot/cli/commands/cmd_conf.py +209 -0
- astrbot/cli/commands/cmd_init.py +56 -0
- astrbot/cli/commands/cmd_plug.py +245 -0
- astrbot/cli/commands/cmd_run.py +62 -0
- astrbot/cli/utils/__init__.py +18 -0
- astrbot/cli/utils/basic.py +76 -0
- astrbot/cli/utils/plugin.py +246 -0
- astrbot/cli/utils/version_comparator.py +90 -0
- astrbot/core/__init__.py +17 -19
- astrbot/core/agent/agent.py +14 -0
- astrbot/core/agent/handoff.py +38 -0
- astrbot/core/agent/hooks.py +30 -0
- astrbot/core/agent/mcp_client.py +385 -0
- astrbot/core/agent/message.py +175 -0
- astrbot/core/agent/response.py +14 -0
- astrbot/core/agent/run_context.py +22 -0
- astrbot/core/agent/runners/__init__.py +3 -0
- astrbot/core/agent/runners/base.py +65 -0
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
- astrbot/core/agent/tool.py +285 -0
- astrbot/core/agent/tool_executor.py +17 -0
- astrbot/core/astr_agent_context.py +19 -0
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/astrbot_config_mgr.py +275 -0
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +60 -20
- astrbot/core/config/default.py +1972 -453
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/conversation_mgr.py +285 -75
- astrbot/core/core_lifecycle.py +167 -62
- astrbot/core/db/__init__.py +305 -102
- astrbot/core/db/migration/helper.py +69 -0
- astrbot/core/db/migration/migra_3_to_4.py +357 -0
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/migration/shared_preferences_v3.py +48 -0
- astrbot/core/db/migration/sqlite_v3.py +497 -0
- astrbot/core/db/po.py +259 -55
- astrbot/core/db/sqlite.py +773 -528
- astrbot/core/db/vec_db/base.py +73 -0
- astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
- astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
- astrbot/core/event_bus.py +26 -22
- astrbot/core/exceptions.py +9 -0
- astrbot/core/file_token_service.py +98 -0
- astrbot/core/initial_loader.py +19 -10
- astrbot/core/knowledge_base/chunking/__init__.py +9 -0
- astrbot/core/knowledge_base/chunking/base.py +25 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
- astrbot/core/knowledge_base/chunking/recursive.py +161 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
- astrbot/core/knowledge_base/kb_helper.py +642 -0
- astrbot/core/knowledge_base/kb_mgr.py +330 -0
- astrbot/core/knowledge_base/models.py +120 -0
- astrbot/core/knowledge_base/parsers/__init__.py +13 -0
- astrbot/core/knowledge_base/parsers/base.py +51 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +276 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
- astrbot/core/log.py +21 -15
- astrbot/core/message/components.py +413 -287
- astrbot/core/message/message_event_result.py +35 -24
- astrbot/core/persona_mgr.py +192 -0
- astrbot/core/pipeline/__init__.py +14 -14
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +7 -1
- astrbot/core/pipeline/context_utils.py +107 -0
- astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
- astrbot/core/pipeline/process_stage/stage.py +21 -15
- astrbot/core/pipeline/process_stage/utils.py +125 -0
- astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
- astrbot/core/pipeline/respond/stage.py +142 -101
- astrbot/core/pipeline/result_decorate/stage.py +124 -57
- astrbot/core/pipeline/scheduler.py +21 -16
- astrbot/core/pipeline/session_status_check/stage.py +37 -0
- astrbot/core/pipeline/stage.py +11 -76
- astrbot/core/pipeline/waking_check/stage.py +69 -33
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +107 -129
- astrbot/core/platform/astrbot_message.py +32 -12
- astrbot/core/platform/manager.py +62 -18
- astrbot/core/platform/message_session.py +30 -0
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +9 -4
- astrbot/core/platform/register.py +12 -7
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
- astrbot/core/platform/sources/discord/client.py +129 -0
- astrbot/core/platform/sources/discord/components.py +139 -0
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
- astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
- astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
- astrbot/core/platform/sources/lark/lark_event.py +39 -13
- astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
- astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
- astrbot/core/platform/sources/satori/satori_event.py +432 -0
- astrbot/core/platform/sources/slack/client.py +164 -0
- astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
- astrbot/core/platform/sources/slack/slack_event.py +253 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
- astrbot/core/platform/sources/telegram/tg_event.py +136 -36
- astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
- astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
- astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
- astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
- astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
- astrbot/core/platform_message_history_mgr.py +49 -0
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +154 -98
- astrbot/core/provider/func_tool_manager.py +446 -458
- astrbot/core/provider/manager.py +345 -207
- astrbot/core/provider/provider.py +188 -73
- astrbot/core/provider/register.py +9 -7
- astrbot/core/provider/sources/anthropic_source.py +295 -115
- astrbot/core/provider/sources/azure_tts_source.py +224 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/dashscope_tts.py +138 -14
- astrbot/core/provider/sources/edge_tts_source.py +24 -19
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
- astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
- astrbot/core/provider/sources/gemini_source.py +310 -132
- astrbot/core/provider/sources/gemini_tts_source.py +81 -0
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
- astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
- astrbot/core/provider/sources/openai_embedding_source.py +40 -0
- astrbot/core/provider/sources/openai_source.py +241 -145
- astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
- astrbot/core/provider/sources/volcengine_tts.py +115 -0
- astrbot/core/provider/sources/whisper_api_source.py +18 -13
- astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/provider/sources/zhipu_source.py +6 -73
- astrbot/core/star/__init__.py +43 -11
- astrbot/core/star/config.py +17 -18
- astrbot/core/star/context.py +362 -138
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +111 -35
- astrbot/core/star/filter/command_group.py +46 -34
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +45 -12
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -15
- astrbot/core/star/register/star.py +41 -13
- astrbot/core/star/register/star_handler.py +236 -86
- astrbot/core/star/session_llm_manager.py +280 -0
- astrbot/core/star/session_plugin_manager.py +170 -0
- astrbot/core/star/star.py +36 -43
- astrbot/core/star/star_handler.py +47 -85
- astrbot/core/star/star_manager.py +442 -260
- astrbot/core/star/star_tools.py +167 -45
- astrbot/core/star/updator.py +17 -20
- astrbot/core/umop_config_router.py +106 -0
- astrbot/core/updator.py +38 -13
- astrbot/core/utils/astrbot_path.py +39 -0
- astrbot/core/utils/command_parser.py +1 -1
- astrbot/core/utils/io.py +119 -60
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +11 -10
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/path_util.py +63 -62
- astrbot/core/utils/pip_installer.py +37 -15
- astrbot/core/utils/session_lock.py +29 -0
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +174 -34
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/local_strategy.py +386 -238
- astrbot/core/utils/t2i/network_strategy.py +109 -49
- astrbot/core/utils/t2i/renderer.py +29 -14
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +111 -0
- astrbot/core/utils/tencent_record_helper.py +115 -1
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +112 -65
- astrbot/dashboard/routes/__init__.py +20 -13
- astrbot/dashboard/routes/auth.py +20 -9
- astrbot/dashboard/routes/chat.py +297 -141
- astrbot/dashboard/routes/config.py +652 -55
- astrbot/dashboard/routes/conversation.py +107 -37
- astrbot/dashboard/routes/file.py +26 -0
- astrbot/dashboard/routes/knowledge_base.py +1244 -0
- astrbot/dashboard/routes/log.py +27 -2
- astrbot/dashboard/routes/persona.py +202 -0
- astrbot/dashboard/routes/plugin.py +197 -139
- astrbot/dashboard/routes/route.py +27 -7
- astrbot/dashboard/routes/session_management.py +354 -0
- astrbot/dashboard/routes/stat.py +85 -18
- astrbot/dashboard/routes/static_file.py +5 -2
- astrbot/dashboard/routes/t2i.py +233 -0
- astrbot/dashboard/routes/tools.py +184 -120
- astrbot/dashboard/routes/update.py +59 -36
- astrbot/dashboard/server.py +96 -36
- astrbot/dashboard/utils.py +165 -0
- astrbot-4.7.0.dist-info/METADATA +294 -0
- astrbot-4.7.0.dist-info/RECORD +274 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
- astrbot/core/db/plugin/sqlite_impl.py +0 -112
- astrbot/core/db/sqlite_init.sql +0 -50
- astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
- astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
- astrbot/core/platform/sources/gewechat/client.py +0 -806
- astrbot/core/platform/sources/gewechat/downloader.py +0 -55
- astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
- astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
- astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
- astrbot/core/provider/sources/dashscope_source.py +0 -203
- astrbot/core/provider/sources/dify_source.py +0 -281
- astrbot/core/provider/sources/llmtuner_source.py +0 -132
- astrbot/core/rag/embedding/openai_source.py +0 -20
- astrbot/core/rag/knowledge_db_mgr.py +0 -94
- astrbot/core/rag/store/__init__.py +0 -9
- astrbot/core/rag/store/chroma_db.py +0 -42
- astrbot/core/utils/dify_api_client.py +0 -152
- astrbot-3.5.6.dist-info/METADATA +0 -249
- astrbot-3.5.6.dist-info/RECORD +0 -158
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import uuid
|
|
3
|
+
import wave
|
|
4
|
+
|
|
5
|
+
from google import genai
|
|
6
|
+
from google.genai import types
|
|
7
|
+
|
|
8
|
+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
9
|
+
|
|
10
|
+
from ..entities import ProviderType
|
|
11
|
+
from ..provider import TTSProvider
|
|
12
|
+
from ..register import register_provider_adapter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@register_provider_adapter(
    "gemini_tts",
    "Gemini TTS API",
    provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGeminiTTSAPI(TTSProvider):
    """Text-to-speech provider backed by the Google Gemini API.

    Speech is requested through ``generate_content`` with
    ``response_modalities=["AUDIO"]`` and a prebuilt voice, and the returned
    audio bytes are written to a WAV file under ``<astrbot data>/temp``.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config, provider_settings)
        api_key: str = provider_config.get("gemini_tts_api_key", "")
        api_base: str | None = provider_config.get("gemini_tts_api_base")
        # The config timeout is given in seconds; HttpOptions takes milliseconds.
        timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
        http_options = types.HttpOptions(timeout=timeout * 1000)

        if api_base:
            # Strip a trailing slash from a user-supplied custom endpoint.
            api_base = api_base.removesuffix("/")
            http_options.base_url = api_base

        # Only the async surface (``.aio``) of the client is used.
        self.client = genai.Client(api_key=api_key, http_options=http_options).aio
        self.model: str = provider_config.get(
            "gemini_tts_model",
            "gemini-2.5-flash-preview-tts",
        )
        # Optional text prepended to every prompt as "<prefix>: <text>".
        self.prefix: str | None = provider_config.get(
            "gemini_tts_prefix",
        )
        self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and return the path of the generated WAV file.

        Args:
            text: Text to speak. If a prefix is configured, the prompt sent to
                the model is ``"<prefix>: <text>"``.

        Returns:
            Path of a newly created ``gemini_tts_<uuid>.wav`` in the temp dir.

        Raises:
            Exception: if the response contains no inline audio data.
        """
        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        # The temp directory may not exist yet; wave.open() would fail on it.
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
        prompt = f"{self.prefix}: {text}" if self.prefix else text
        response = await self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name=self.voice_name,
                        ),
                    ),
                ),
            ),
        )

        # Every link of the path to the audio bytes is Optional; check each one
        # (this also narrows the types for the type checker).
        if (
            not response.candidates
            or not response.candidates[0].content
            or not response.candidates[0].content.parts
            or not response.candidates[0].content.parts[0].inline_data
            or not response.candidates[0].content.parts[0].inline_data.data
        ):
            raise Exception("No audio content returned from Gemini TTS API.")

        audio_bytes = response.candidates[0].content.parts[0].inline_data.data

        # The payload is written as raw frames of 16-bit mono PCM at 24 kHz.
        with wave.open(path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(24000)
            wf.writeframes(audio_bytes)

        return path
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from ..register import register_provider_adapter
|
|
2
|
+
from .openai_source import ProviderOpenAIOfficial
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@register_provider_adapter(
    "groq_chat_completion", "Groq Chat Completion Provider Adapter"
)
class ProviderGroq(ProviderOpenAIOfficial):
    """Chat-completion provider for Groq.

    Reuses the OpenAI-compatible implementation of ``ProviderOpenAIOfficial``
    unchanged; the only override is the reasoning field name.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config, provider_settings)
        # NOTE(review): presumably the response field from which the parent
        # class reads model reasoning; Groq names it "reasoning". Verify
        # against ProviderOpenAIOfficial.
        self.reasoning_key = "reasoning"
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
|
|
5
|
+
import aiohttp
|
|
6
|
+
|
|
7
|
+
from astrbot import logger
|
|
8
|
+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
9
|
+
|
|
10
|
+
from ..entities import ProviderType
|
|
11
|
+
from ..provider import TTSProvider
|
|
12
|
+
from ..register import register_provider_adapter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@register_provider_adapter(
    provider_type_name="gsv_tts_selfhost",
    desc="GPT-SoVITS TTS(本地加载)",
    provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGSVTTS(TTSProvider):
    """TTS provider for a self-hosted GPT-SoVITS HTTP API.

    All calls are HTTP GETs with query parameters:
    ``/set_gpt_weights`` and ``/set_sovits_weights`` select model weights
    during ``initialize()``, and ``/tts`` synthesizes audio, which is saved
    as a ``.wav`` file under ``<astrbot data>/temp``.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config, provider_settings)

        self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
            "/",
        )
        self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
        self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")

        # Default query parameters for /tts requests, with the "gsv_" key
        # prefix removed.
        self.default_params: dict = {
            key.removeprefix("gsv_"): self._to_query_value(value)
            for key, value in provider_config.get("gsv_default_parms", {}).items()
        }
        self.timeout = provider_config.get("timeout", 60)
        # Created in initialize(), closed in terminate().
        self._session: aiohttp.ClientSession | None = None

    @staticmethod
    def _to_query_value(value) -> str:
        """Render one config value as a query-string value.

        Non-string values (booleans, numbers, None) are stringified and
        lowercased, so ``True`` becomes ``"true"``. Strings are kept verbatim
        because they can carry case-sensitive data such as reference-audio
        paths or prompt text.
        """
        if isinstance(value, str):
            return value
        return str(value).lower()

    async def initialize(self):
        """Asynchronous setup, called from ProviderManager.

        Opens the HTTP session, then pushes the configured model weights.
        ``_set_model_weights`` logs and swallows its own errors, so a failed
        weight switch does not abort initialization.
        """
        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        )
        try:
            await self._set_model_weights()
            logger.info("[GSV TTS] 初始化完成")
        except Exception as e:
            logger.error(f"[GSV TTS] 初始化失败:{e}")
            raise

    def get_session(self) -> aiohttp.ClientSession:
        """Return the open HTTP session.

        Raises:
            RuntimeError: if ``initialize()`` has not run or the session is closed.
        """
        if not self._session or self._session.closed:
            raise RuntimeError(
                "[GSV TTS] Provider HTTP session is not ready or closed.",
            )
        return self._session

    async def _make_request(
        self,
        endpoint: str,
        params=None,
        retries: int = 3,
    ) -> bytes | None:
        """GET ``endpoint`` with ``params`` and return the raw response body.

        Any failure (non-200 status or transport error) is retried, for at
        most ``retries`` attempts in total, with a 1 s pause between attempts.
        The last error is re-raised. Returns ``None`` only when ``retries`` is
        not positive.

        Raises:
            RuntimeError: if the HTTP session is not ready. This is not
                retried, since retrying cannot fix it.
        """
        # Resolve the session once, up front: a missing or closed session
        # cannot recover between retries, so fail fast instead of sleeping.
        session = self.get_session()
        for attempt in range(retries):
            logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
            try:
                async with session.get(endpoint, params=params) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(
                            f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}",
                        )
                    return await response.read()
            except Exception as e:
                if attempt < retries - 1:
                    logger.warning(
                        f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中...",
                    )
                    await asyncio.sleep(1)
                else:
                    logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
                    raise
        return None

    async def _set_model_weights(self):
        """Point the server at the configured GPT / SoVITS weight files.

        An empty path leaves the server's built-in model in place. This is
        best effort: errors are logged, not raised.
        """
        try:
            if self.gpt_weights_path:
                await self._make_request(
                    f"{self.api_base}/set_gpt_weights",
                    {"weights_path": self.gpt_weights_path},
                )
                logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
            else:
                logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")

            if self.sovits_weights_path:
                await self._make_request(
                    f"{self.api_base}/set_sovits_weights",
                    {"weights_path": self.sovits_weights_path},
                )
                logger.info(
                    f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}",
                )
            else:
                logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
        except aiohttp.ClientError as e:
            logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
        except Exception as e:
            logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` via ``/tts`` and return the saved WAV path.

        Emotion control is not implemented yet (see ``build_synthesis_params``).

        Raises:
            ValueError: if ``text`` is empty or whitespace only.
            Exception: if the request ultimately fails.
        """
        if not text.strip():
            raise ValueError("[GSV TTS] TTS 文本不能为空")

        endpoint = f"{self.api_base}/tts"

        params = self.build_synthesis_params(text)

        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")

        logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")

        result = await self._make_request(endpoint, params)
        if isinstance(result, bytes):
            with open(path, "wb") as f:
                f.write(result)
            return path
        raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")

    def build_synthesis_params(self, text: str) -> dict:
        """Build the query parameters for one synthesis request.

        Currently this is the configured defaults plus ``text``. Semantic
        controls such as emotion or speaker could be added here later.
        """
        params = self.default_params.copy()
        params["text"] = text
        # TODO: add emotion analysis here, e.g. params["emotion"] = detect_emotion(text)
        return params

    async def terminate(self):
        """Release resources (close the HTTP session); called from ProviderManager."""
        if self._session and not self._session.closed:
            await self._session.close()
            logger.info("[GSV TTS] Session 已关闭")
|
|
@@ -1,13 +1,20 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import urllib.parse
|
|
1
3
|
import uuid
|
|
4
|
+
|
|
2
5
|
import aiohttp
|
|
3
|
-
|
|
4
|
-
from
|
|
6
|
+
|
|
7
|
+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
8
|
+
|
|
5
9
|
from ..entities import ProviderType
|
|
10
|
+
from ..provider import TTSProvider
|
|
6
11
|
from ..register import register_provider_adapter
|
|
7
12
|
|
|
8
13
|
|
|
9
14
|
@register_provider_adapter(
|
|
10
|
-
"gsvi_tts_api",
|
|
15
|
+
"gsvi_tts_api",
|
|
16
|
+
"GSVI TTS API",
|
|
17
|
+
provider_type=ProviderType.TEXT_TO_SPEECH,
|
|
11
18
|
)
|
|
12
19
|
class ProviderGSVITTS(TTSProvider):
|
|
13
20
|
def __init__(
|
|
@@ -17,13 +24,13 @@ class ProviderGSVITTS(TTSProvider):
|
|
|
17
24
|
) -> None:
|
|
18
25
|
super().__init__(provider_config, provider_settings)
|
|
19
26
|
self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000")
|
|
20
|
-
|
|
21
|
-
self.api_base = self.api_base[:-1]
|
|
27
|
+
self.api_base = self.api_base.removesuffix("/")
|
|
22
28
|
self.character = provider_config.get("character")
|
|
23
29
|
self.emotion = provider_config.get("emotion")
|
|
24
30
|
|
|
25
31
|
async def get_audio(self, text: str) -> str:
|
|
26
|
-
|
|
32
|
+
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
|
33
|
+
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
|
|
27
34
|
params = {"text": text}
|
|
28
35
|
|
|
29
36
|
if self.character:
|
|
@@ -46,7 +53,7 @@ class ProviderGSVITTS(TTSProvider):
|
|
|
46
53
|
else:
|
|
47
54
|
error_text = await response.text()
|
|
48
55
|
raise Exception(
|
|
49
|
-
f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}"
|
|
56
|
+
f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
|
|
50
57
|
)
|
|
51
58
|
|
|
52
59
|
return path
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
|
|
6
|
+
import aiohttp
|
|
7
|
+
|
|
8
|
+
from astrbot.api import logger
|
|
9
|
+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
10
|
+
|
|
11
|
+
from ..entities import ProviderType
|
|
12
|
+
from ..provider import TTSProvider
|
|
13
|
+
from ..register import register_provider_adapter
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@register_provider_adapter(
    "minimax_tts_api",
    "MiniMax TTS API",
    provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderMiniMaxTTSAPI(TTSProvider):
    """TTS provider for the MiniMax T2A v2 streaming HTTP API.

    The response is a stream of ``data: <json>`` events separated by blank
    lines. Each event may carry a hex-encoded audio chunk. The chunks are
    decoded, concatenated and saved as an MP3 file under
    ``<astrbot data>/temp``.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config, provider_settings)
        self.chosen_api_key: str = provider_config.get("api_key", "")
        self.api_base: str = provider_config.get(
            "api_base",
            "https://api.minimax.chat/v1/t2a_v2",
        )
        self.group_id: str = provider_config.get("minimax-group-id", "")
        self.set_model(provider_config.get("model", ""))
        self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
        # When enabled, voices are mixed via "timber_weights" instead of a
        # single voice_id.
        self.is_timber_weight: bool = provider_config.get(
            "minimax-is-timber-weight",
            False,
        )
        # Stored in the config as a JSON string.
        self.timber_weight: list[dict[str, str | int]] = json.loads(
            provider_config.get(
                "minimax-timber-weight",
                '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
            ),
        )

        self.voice_setting: dict = {
            "speed": provider_config.get("minimax-voice-speed", 1.0),
            "vol": provider_config.get("minimax-voice-vol", 1.0),
            "pitch": provider_config.get("minimax-voice-pitch", 0),
            "voice_id": ""
            if self.is_timber_weight
            else provider_config.get("minimax-voice-id", ""),
            "emotion": provider_config.get("minimax-voice-emotion", "neutral"),
            "latex_read": provider_config.get("minimax-voice-latex", False),
            "english_normalization": provider_config.get(
                "minimax-voice-english-normalization",
                False,
            ),
        }

        self.audio_setting: dict = {
            "sample_rate": 32000,
            "bitrate": 128000,
            "format": "mp3",
        }

        self.concat_base_url: str = f"{self.api_base}?GroupId={self.group_id}"
        self.headers = {
            "Authorization": f"Bearer {self.chosen_api_key}",
            "accept": "application/json, text/plain, */*",
            "content-type": "application/json",
        }

    def _build_tts_stream_body(self, text: str):
        """Build the JSON request body for a streaming synthesis request."""
        dict_body: dict[str, object] = {
            "model": self.model_name,
            "text": text,
            "stream": True,
            "language_boost": self.lang_boost,
            "voice_setting": self.voice_setting,
            "audio_setting": self.audio_setting,
        }
        if self.is_timber_weight:
            dict_body["timber_weights"] = self.timber_weight

        return json.dumps(dict_body)

    @staticmethod
    def _parse_sse_message(message: bytes) -> str | None:
        """Extract the hex-encoded audio payload from one ``data:`` event.

        Returns ``None`` for:
        - non-``data:`` events;
        - unparsable JSON (a warning is logged);
        - events carrying ``extra_info``;
        - events without audio (including ``"data": null``).

        Raises:
            Exception: if the event reports a non-zero ``base_resp.status_code``.
        """
        if not message.startswith(b"data: "):
            return None
        try:
            data = json.loads(message[6:])
        except json.JSONDecodeError:
            logger.warning(
                "Failed to parse JSON data from SSE message",
            )
            return None
        if not isinstance(data, dict):
            return None

        # Surface API-level errors instead of silently producing empty audio.
        base_resp = data.get("base_resp")
        if isinstance(base_resp, dict) and base_resp.get("status_code", 0) != 0:
            raise Exception(
                f"MiniMax TTS API请求失败: status_code={base_resp.get('status_code')}, "
                f"status_msg={base_resp.get('status_msg')}",
            )

        # Events with extra_info are skipped.
        # NOTE(review): presumably the final summary event, which repeats
        # the audio; confirm against the MiniMax API docs.
        if "extra_info" in data:
            return None

        payload = data.get("data")
        if not isinstance(payload, dict):
            return None
        return payload.get("audio")

    async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
        """Send a streaming request and yield hex-encoded audio chunks.

        Raises:
            Exception: on HTTP/transport errors (wrapping ``aiohttp.ClientError``)
                or when the API reports an error status.
        """
        try:
            async with (
                aiohttp.ClientSession() as session,
                session.post(
                    self.concat_base_url,
                    headers=self.headers,
                    data=self._build_tts_stream_body(text),
                    timeout=aiohttp.ClientTimeout(total=60),
                ) as response,
            ):
                response.raise_for_status()

                buffer = b""
                while True:
                    chunk = await response.content.read(8192)
                    if not chunk:
                        break

                    buffer += chunk

                    # An event is complete once its blank-line terminator arrived.
                    while b"\n\n" in buffer:
                        message, buffer = buffer.split(b"\n\n", 1)
                        audio = self._parse_sse_message(message)
                        if audio is not None:
                            yield audio

                # Flush a final event that was not followed by a blank line.
                audio = self._parse_sse_message(buffer.strip())
                if audio is not None:
                    yield audio

        except aiohttp.ClientError as e:
            raise Exception(f"MiniMax TTS API请求失败: {e!s}") from e

    async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
        """Decode a stream of hex-encoded chunks into one audio byte string."""
        chunks = []
        async for chunk in audio_stream:
            if chunk.strip():
                chunks.append(bytes.fromhex(chunk.strip()))
        return b"".join(chunks)

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and return the path of the saved MP3 file.

        Raises:
            Exception: if the request fails, the API reports an error, or no
                audio data was received.
        """
        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")

        # Feed the async generator straight into the decoder.
        audio = await self._audio_play(self._call_tts_stream(text))
        if not audio:
            # Never hand back an empty file as a successful synthesis.
            raise Exception("MiniMax TTS API 未返回任何音频数据")

        with open(path, "wb") as file:
            file.write(audio)

        return path
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from openai import AsyncOpenAI
|
|
2
|
+
|
|
3
|
+
from ..entities import ProviderType
|
|
4
|
+
from ..provider import EmbeddingProvider
|
|
5
|
+
from ..register import register_provider_adapter
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@register_provider_adapter(
    "openai_embedding",
    "OpenAI API Embedding 提供商适配器",
    provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider for OpenAI and OpenAI-compatible ``/embeddings`` APIs."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings
        self.client = AsyncOpenAI(
            api_key=provider_config.get("embedding_api_key"),
            base_url=provider_config.get(
                "embedding_api_base",
                "https://api.openai.com/v1",
            ),
            timeout=int(provider_config.get("timeout", 20)),
        )
        self.model = provider_config.get("embedding_model", "text-embedding-3-small")

    async def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for a single ``text``."""
        embedding = await self.client.embeddings.create(input=text, model=self.model)
        return embedding.data[0].embedding

    async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Return embeddings for ``texts`` in one batched request.

        The result follows the order of ``response.data``, which is assumed to
        match the input order. An empty ``texts`` returns ``[]`` without
        making a network call.
        """
        if not texts:
            return []
        embeddings = await self.client.embeddings.create(input=texts, model=self.model)
        return [item.embedding for item in embeddings.data]

    def get_dim(self) -> int:
        """Return the configured embedding dimension (default 1024).

        This value is not sent to the API, so it must match the model's actual
        output size. It is coerced to ``int`` in case the config stores it as
        a string.
        """
        return int(self.provider_config.get("embedding_dimensions", 1024))
|