AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +16 -4
- astrbot/api/all.py +2 -1
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -34
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +8 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__init__.py +1 -0
- astrbot/cli/__main__.py +18 -197
- astrbot/cli/commands/__init__.py +6 -0
- astrbot/cli/commands/cmd_conf.py +209 -0
- astrbot/cli/commands/cmd_init.py +56 -0
- astrbot/cli/commands/cmd_plug.py +245 -0
- astrbot/cli/commands/cmd_run.py +62 -0
- astrbot/cli/utils/__init__.py +18 -0
- astrbot/cli/utils/basic.py +76 -0
- astrbot/cli/utils/plugin.py +246 -0
- astrbot/cli/utils/version_comparator.py +90 -0
- astrbot/core/__init__.py +17 -19
- astrbot/core/agent/agent.py +14 -0
- astrbot/core/agent/handoff.py +38 -0
- astrbot/core/agent/hooks.py +30 -0
- astrbot/core/agent/mcp_client.py +385 -0
- astrbot/core/agent/message.py +175 -0
- astrbot/core/agent/response.py +14 -0
- astrbot/core/agent/run_context.py +22 -0
- astrbot/core/agent/runners/__init__.py +3 -0
- astrbot/core/agent/runners/base.py +65 -0
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
- astrbot/core/agent/tool.py +285 -0
- astrbot/core/agent/tool_executor.py +17 -0
- astrbot/core/astr_agent_context.py +19 -0
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/astrbot_config_mgr.py +275 -0
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +60 -20
- astrbot/core/config/default.py +1972 -453
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/conversation_mgr.py +285 -75
- astrbot/core/core_lifecycle.py +167 -62
- astrbot/core/db/__init__.py +305 -102
- astrbot/core/db/migration/helper.py +69 -0
- astrbot/core/db/migration/migra_3_to_4.py +357 -0
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/migration/shared_preferences_v3.py +48 -0
- astrbot/core/db/migration/sqlite_v3.py +497 -0
- astrbot/core/db/po.py +259 -55
- astrbot/core/db/sqlite.py +773 -528
- astrbot/core/db/vec_db/base.py +73 -0
- astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
- astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
- astrbot/core/event_bus.py +26 -22
- astrbot/core/exceptions.py +9 -0
- astrbot/core/file_token_service.py +98 -0
- astrbot/core/initial_loader.py +19 -10
- astrbot/core/knowledge_base/chunking/__init__.py +9 -0
- astrbot/core/knowledge_base/chunking/base.py +25 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
- astrbot/core/knowledge_base/chunking/recursive.py +161 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
- astrbot/core/knowledge_base/kb_helper.py +642 -0
- astrbot/core/knowledge_base/kb_mgr.py +330 -0
- astrbot/core/knowledge_base/models.py +120 -0
- astrbot/core/knowledge_base/parsers/__init__.py +13 -0
- astrbot/core/knowledge_base/parsers/base.py +51 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +276 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
- astrbot/core/log.py +21 -15
- astrbot/core/message/components.py +413 -287
- astrbot/core/message/message_event_result.py +35 -24
- astrbot/core/persona_mgr.py +192 -0
- astrbot/core/pipeline/__init__.py +14 -14
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +7 -1
- astrbot/core/pipeline/context_utils.py +107 -0
- astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
- astrbot/core/pipeline/process_stage/stage.py +21 -15
- astrbot/core/pipeline/process_stage/utils.py +125 -0
- astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
- astrbot/core/pipeline/respond/stage.py +142 -101
- astrbot/core/pipeline/result_decorate/stage.py +124 -57
- astrbot/core/pipeline/scheduler.py +21 -16
- astrbot/core/pipeline/session_status_check/stage.py +37 -0
- astrbot/core/pipeline/stage.py +11 -76
- astrbot/core/pipeline/waking_check/stage.py +69 -33
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +107 -129
- astrbot/core/platform/astrbot_message.py +32 -12
- astrbot/core/platform/manager.py +62 -18
- astrbot/core/platform/message_session.py +30 -0
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +9 -4
- astrbot/core/platform/register.py +12 -7
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
- astrbot/core/platform/sources/discord/client.py +129 -0
- astrbot/core/platform/sources/discord/components.py +139 -0
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
- astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
- astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
- astrbot/core/platform/sources/lark/lark_event.py +39 -13
- astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
- astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
- astrbot/core/platform/sources/satori/satori_event.py +432 -0
- astrbot/core/platform/sources/slack/client.py +164 -0
- astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
- astrbot/core/platform/sources/slack/slack_event.py +253 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
- astrbot/core/platform/sources/telegram/tg_event.py +136 -36
- astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
- astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
- astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
- astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
- astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
- astrbot/core/platform_message_history_mgr.py +49 -0
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +154 -98
- astrbot/core/provider/func_tool_manager.py +446 -458
- astrbot/core/provider/manager.py +345 -207
- astrbot/core/provider/provider.py +188 -73
- astrbot/core/provider/register.py +9 -7
- astrbot/core/provider/sources/anthropic_source.py +295 -115
- astrbot/core/provider/sources/azure_tts_source.py +224 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/dashscope_tts.py +138 -14
- astrbot/core/provider/sources/edge_tts_source.py +24 -19
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
- astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
- astrbot/core/provider/sources/gemini_source.py +310 -132
- astrbot/core/provider/sources/gemini_tts_source.py +81 -0
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
- astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
- astrbot/core/provider/sources/openai_embedding_source.py +40 -0
- astrbot/core/provider/sources/openai_source.py +241 -145
- astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
- astrbot/core/provider/sources/volcengine_tts.py +115 -0
- astrbot/core/provider/sources/whisper_api_source.py +18 -13
- astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/provider/sources/zhipu_source.py +6 -73
- astrbot/core/star/__init__.py +43 -11
- astrbot/core/star/config.py +17 -18
- astrbot/core/star/context.py +362 -138
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +111 -35
- astrbot/core/star/filter/command_group.py +46 -34
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +45 -12
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -15
- astrbot/core/star/register/star.py +41 -13
- astrbot/core/star/register/star_handler.py +236 -86
- astrbot/core/star/session_llm_manager.py +280 -0
- astrbot/core/star/session_plugin_manager.py +170 -0
- astrbot/core/star/star.py +36 -43
- astrbot/core/star/star_handler.py +47 -85
- astrbot/core/star/star_manager.py +442 -260
- astrbot/core/star/star_tools.py +167 -45
- astrbot/core/star/updator.py +17 -20
- astrbot/core/umop_config_router.py +106 -0
- astrbot/core/updator.py +38 -13
- astrbot/core/utils/astrbot_path.py +39 -0
- astrbot/core/utils/command_parser.py +1 -1
- astrbot/core/utils/io.py +119 -60
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +11 -10
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/path_util.py +63 -62
- astrbot/core/utils/pip_installer.py +37 -15
- astrbot/core/utils/session_lock.py +29 -0
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +174 -34
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/local_strategy.py +386 -238
- astrbot/core/utils/t2i/network_strategy.py +109 -49
- astrbot/core/utils/t2i/renderer.py +29 -14
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +111 -0
- astrbot/core/utils/tencent_record_helper.py +115 -1
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +112 -65
- astrbot/dashboard/routes/__init__.py +20 -13
- astrbot/dashboard/routes/auth.py +20 -9
- astrbot/dashboard/routes/chat.py +297 -141
- astrbot/dashboard/routes/config.py +652 -55
- astrbot/dashboard/routes/conversation.py +107 -37
- astrbot/dashboard/routes/file.py +26 -0
- astrbot/dashboard/routes/knowledge_base.py +1244 -0
- astrbot/dashboard/routes/log.py +27 -2
- astrbot/dashboard/routes/persona.py +202 -0
- astrbot/dashboard/routes/plugin.py +197 -139
- astrbot/dashboard/routes/route.py +27 -7
- astrbot/dashboard/routes/session_management.py +354 -0
- astrbot/dashboard/routes/stat.py +85 -18
- astrbot/dashboard/routes/static_file.py +5 -2
- astrbot/dashboard/routes/t2i.py +233 -0
- astrbot/dashboard/routes/tools.py +184 -120
- astrbot/dashboard/routes/update.py +59 -36
- astrbot/dashboard/server.py +96 -36
- astrbot/dashboard/utils.py +165 -0
- astrbot-4.7.0.dist-info/METADATA +294 -0
- astrbot-4.7.0.dist-info/RECORD +274 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
- astrbot/core/db/plugin/sqlite_impl.py +0 -112
- astrbot/core/db/sqlite_init.sql +0 -50
- astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
- astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
- astrbot/core/platform/sources/gewechat/client.py +0 -806
- astrbot/core/platform/sources/gewechat/downloader.py +0 -55
- astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
- astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
- astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
- astrbot/core/provider/sources/dashscope_source.py +0 -203
- astrbot/core/provider/sources/dify_source.py +0 -281
- astrbot/core/provider/sources/llmtuner_source.py +0 -132
- astrbot/core/rag/embedding/openai_source.py +0 -20
- astrbot/core/rag/knowledge_db_mgr.py +0 -94
- astrbot/core/rag/store/__init__.py +0 -9
- astrbot/core/rag/store/chroma_db.py +0 -42
- astrbot/core/utils/dify_api_client.py +0 -152
- astrbot-3.5.6.dist-info/METADATA +0 -249
- astrbot-3.5.6.dist-info/RECORD +0 -158
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/provider/manager.py
CHANGED
|
@@ -1,145 +1,239 @@
|
|
|
1
|
-
import traceback
|
|
2
1
|
import asyncio
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
from .
|
|
6
|
-
from
|
|
2
|
+
import traceback
|
|
3
|
+
|
|
4
|
+
from astrbot.core import astrbot_config, logger, sp
|
|
5
|
+
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
7
6
|
from astrbot.core.db import BaseDatabase
|
|
8
|
-
|
|
9
|
-
from
|
|
7
|
+
|
|
8
|
+
from ..persona_mgr import PersonaManager
|
|
9
|
+
from .entities import ProviderType
|
|
10
|
+
from .provider import (
|
|
11
|
+
EmbeddingProvider,
|
|
12
|
+
Provider,
|
|
13
|
+
RerankProvider,
|
|
14
|
+
STTProvider,
|
|
15
|
+
TTSProvider,
|
|
16
|
+
)
|
|
17
|
+
from .register import llm_tools, provider_cls_map
|
|
10
18
|
|
|
11
19
|
|
|
12
20
|
class ProviderManager:
|
|
13
|
-
def __init__(
|
|
14
|
-
self
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
acm: AstrBotConfigManager,
|
|
24
|
+
db_helper: BaseDatabase,
|
|
25
|
+
persona_mgr: PersonaManager,
|
|
26
|
+
):
|
|
27
|
+
self.reload_lock = asyncio.Lock()
|
|
28
|
+
self.persona_mgr = persona_mgr
|
|
29
|
+
self.acm = acm
|
|
30
|
+
config = acm.confs["default"]
|
|
31
|
+
self.providers_config: list = config["provider"]
|
|
15
32
|
self.provider_settings: dict = config["provider_settings"]
|
|
16
33
|
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
|
17
34
|
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
|
18
|
-
self.persona_configs: list = config.get("persona", [])
|
|
19
|
-
self.astrbot_config = config
|
|
20
|
-
|
|
21
|
-
self.selected_provider_id = sp.get("curr_provider")
|
|
22
|
-
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
|
23
|
-
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
|
|
24
|
-
self.provider_enabled = self.provider_settings.get("enable", False)
|
|
25
|
-
self.stt_enabled = self.provider_stt_settings.get("enable", False)
|
|
26
|
-
self.tts_enabled = self.provider_tts_settings.get("enable", False)
|
|
27
|
-
|
|
28
|
-
# 人格情景管理
|
|
29
|
-
# 目前没有拆成独立的模块
|
|
30
|
-
self.default_persona_name = self.provider_settings.get(
|
|
31
|
-
"default_personality", "default"
|
|
32
|
-
)
|
|
33
|
-
self.personas: List[Personality] = []
|
|
34
|
-
self.selected_default_persona = None
|
|
35
|
-
for persona in self.persona_configs:
|
|
36
|
-
begin_dialogs = persona.get("begin_dialogs", [])
|
|
37
|
-
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
|
38
|
-
bd_processed = []
|
|
39
|
-
mid_processed = ""
|
|
40
|
-
if begin_dialogs:
|
|
41
|
-
if len(begin_dialogs) % 2 != 0:
|
|
42
|
-
logger.error(
|
|
43
|
-
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
|
44
|
-
)
|
|
45
|
-
begin_dialogs = []
|
|
46
|
-
user_turn = True
|
|
47
|
-
for dialog in begin_dialogs:
|
|
48
|
-
bd_processed.append(
|
|
49
|
-
{
|
|
50
|
-
"role": "user" if user_turn else "assistant",
|
|
51
|
-
"content": dialog,
|
|
52
|
-
"_no_save": None, # 不持久化到 db
|
|
53
|
-
}
|
|
54
|
-
)
|
|
55
|
-
user_turn = not user_turn
|
|
56
|
-
if mood_imitation_dialogs:
|
|
57
|
-
if len(mood_imitation_dialogs) % 2 != 0:
|
|
58
|
-
logger.error(
|
|
59
|
-
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
|
|
60
|
-
)
|
|
61
|
-
mood_imitation_dialogs = []
|
|
62
|
-
user_turn = True
|
|
63
|
-
for dialog in mood_imitation_dialogs:
|
|
64
|
-
role = "A" if user_turn else "B"
|
|
65
|
-
mid_processed += f"{role}: {dialog}\n"
|
|
66
|
-
if not user_turn:
|
|
67
|
-
mid_processed += "\n"
|
|
68
|
-
user_turn = not user_turn
|
|
69
35
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
**persona,
|
|
73
|
-
_begin_dialogs_processed=bd_processed,
|
|
74
|
-
_mood_imitation_dialogs_processed=mid_processed,
|
|
75
|
-
)
|
|
76
|
-
if persona["name"] == self.default_persona_name:
|
|
77
|
-
self.selected_default_persona = persona
|
|
78
|
-
self.personas.append(persona)
|
|
79
|
-
except Exception as e:
|
|
80
|
-
logger.error(f"解析 Persona 配置失败:{e}")
|
|
81
|
-
|
|
82
|
-
if not self.selected_default_persona and len(self.personas) > 0:
|
|
83
|
-
# 默认选择第一个
|
|
84
|
-
self.selected_default_persona = self.personas[0]
|
|
85
|
-
|
|
86
|
-
if not self.selected_default_persona:
|
|
87
|
-
self.selected_default_persona = Personality(
|
|
88
|
-
prompt="You are a helpful and friendly assistant.",
|
|
89
|
-
name="default",
|
|
90
|
-
_begin_dialogs_processed=[],
|
|
91
|
-
_mood_imitation_dialogs_processed="",
|
|
92
|
-
)
|
|
93
|
-
self.personas.append(self.selected_default_persona)
|
|
36
|
+
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
|
37
|
+
self.default_persona_name = persona_mgr.default_persona
|
|
94
38
|
|
|
95
|
-
self.provider_insts:
|
|
39
|
+
self.provider_insts: list[Provider] = []
|
|
96
40
|
"""加载的 Provider 的实例"""
|
|
97
|
-
self.stt_provider_insts:
|
|
41
|
+
self.stt_provider_insts: list[STTProvider] = []
|
|
98
42
|
"""加载的 Speech To Text Provider 的实例"""
|
|
99
|
-
self.tts_provider_insts:
|
|
43
|
+
self.tts_provider_insts: list[TTSProvider] = []
|
|
100
44
|
"""加载的 Text To Speech Provider 的实例"""
|
|
101
|
-
self.
|
|
45
|
+
self.embedding_provider_insts: list[EmbeddingProvider] = []
|
|
46
|
+
"""加载的 Embedding Provider 的实例"""
|
|
47
|
+
self.rerank_provider_insts: list[RerankProvider] = []
|
|
48
|
+
"""加载的 Rerank Provider 的实例"""
|
|
49
|
+
self.inst_map: dict[
|
|
50
|
+
str,
|
|
51
|
+
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
|
52
|
+
] = {}
|
|
102
53
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
|
103
54
|
self.llm_tools = llm_tools
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
55
|
+
|
|
56
|
+
self.curr_provider_inst: Provider | None = None
|
|
57
|
+
"""默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
|
58
|
+
self.curr_stt_provider_inst: STTProvider | None = None
|
|
59
|
+
"""默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
|
60
|
+
self.curr_tts_provider_inst: TTSProvider | None = None
|
|
61
|
+
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
|
110
62
|
self.db_helper = db_helper
|
|
111
63
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
64
|
+
@property
|
|
65
|
+
def persona_configs(self) -> list:
|
|
66
|
+
"""动态获取最新的 persona 配置"""
|
|
67
|
+
return self.persona_mgr.persona_v3_config
|
|
68
|
+
|
|
69
|
+
@property
|
|
70
|
+
def personas(self) -> list:
|
|
71
|
+
"""动态获取最新的 personas 列表"""
|
|
72
|
+
return self.persona_mgr.personas_v3
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def selected_default_persona(self):
|
|
76
|
+
"""动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()"""
|
|
77
|
+
return self.persona_mgr.selected_default_persona_v3
|
|
78
|
+
|
|
79
|
+
async def set_provider(
|
|
80
|
+
self,
|
|
81
|
+
provider_id: str,
|
|
82
|
+
provider_type: ProviderType,
|
|
83
|
+
umo: str | None = None,
|
|
84
|
+
):
|
|
85
|
+
"""设置提供商。
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
provider_id (str): 提供商 ID。
|
|
89
|
+
provider_type (ProviderType): 提供商类型。
|
|
90
|
+
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
|
91
|
+
|
|
92
|
+
Version 4.0.0: 这个版本下已经默认隔离提供商
|
|
93
|
+
|
|
94
|
+
"""
|
|
95
|
+
if provider_id not in self.inst_map:
|
|
96
|
+
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
|
97
|
+
if umo:
|
|
98
|
+
await sp.session_put(
|
|
99
|
+
umo,
|
|
100
|
+
f"provider_perf_{provider_type.value}",
|
|
101
|
+
provider_id,
|
|
102
|
+
)
|
|
103
|
+
return
|
|
104
|
+
# 不启用提供商会话隔离模式的情况
|
|
105
|
+
|
|
106
|
+
prov = self.inst_map[provider_id]
|
|
107
|
+
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
|
108
|
+
prov,
|
|
109
|
+
TTSProvider,
|
|
110
|
+
):
|
|
111
|
+
self.curr_tts_provider_inst = prov
|
|
112
|
+
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
|
113
|
+
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
|
114
|
+
prov,
|
|
115
|
+
STTProvider,
|
|
116
|
+
):
|
|
117
|
+
self.curr_stt_provider_inst = prov
|
|
118
|
+
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
|
119
|
+
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
|
120
|
+
prov,
|
|
121
|
+
Provider,
|
|
122
|
+
):
|
|
123
|
+
self.curr_provider_inst = prov
|
|
124
|
+
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
|
125
|
+
|
|
126
|
+
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
|
127
|
+
"""根据提供商 ID 获取提供商实例"""
|
|
128
|
+
return self.inst_map.get(provider_id)
|
|
129
|
+
|
|
130
|
+
def get_using_provider(
|
|
131
|
+
self,
|
|
132
|
+
provider_type: ProviderType,
|
|
133
|
+
umo=None,
|
|
134
|
+
) -> Provider | STTProvider | TTSProvider | None:
|
|
135
|
+
"""获取正在使用的提供商实例。
|
|
136
|
+
|
|
137
|
+
Args:
|
|
138
|
+
provider_type (ProviderType): 提供商类型。
|
|
139
|
+
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
Provider: 正在使用的提供商实例。
|
|
143
|
+
|
|
144
|
+
"""
|
|
145
|
+
provider = None
|
|
146
|
+
if umo:
|
|
147
|
+
provider_id = sp.get(
|
|
148
|
+
f"provider_perf_{provider_type.value}",
|
|
149
|
+
None,
|
|
150
|
+
scope="umo",
|
|
151
|
+
scope_id=umo,
|
|
152
|
+
)
|
|
153
|
+
if provider_id:
|
|
154
|
+
provider = self.inst_map.get(provider_id)
|
|
155
|
+
if not provider:
|
|
156
|
+
# default setting
|
|
157
|
+
config = self.acm.get_conf(umo)
|
|
158
|
+
if provider_type == ProviderType.CHAT_COMPLETION:
|
|
159
|
+
provider_id = config["provider_settings"].get("default_provider_id")
|
|
160
|
+
provider = self.inst_map.get(provider_id)
|
|
161
|
+
if not provider:
|
|
162
|
+
provider = self.provider_insts[0] if self.provider_insts else None
|
|
163
|
+
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
|
164
|
+
provider_id = config["provider_stt_settings"].get("provider_id")
|
|
165
|
+
if not provider_id:
|
|
166
|
+
return None
|
|
167
|
+
provider = self.inst_map.get(provider_id)
|
|
168
|
+
if not provider:
|
|
169
|
+
provider = (
|
|
170
|
+
self.stt_provider_insts[0] if self.stt_provider_insts else None
|
|
171
|
+
)
|
|
172
|
+
elif provider_type == ProviderType.TEXT_TO_SPEECH:
|
|
173
|
+
provider_id = config["provider_tts_settings"].get("provider_id")
|
|
174
|
+
if not provider_id:
|
|
175
|
+
return None
|
|
176
|
+
provider = self.inst_map.get(provider_id)
|
|
177
|
+
if not provider:
|
|
178
|
+
provider = (
|
|
179
|
+
self.tts_provider_insts[0] if self.tts_provider_insts else None
|
|
180
|
+
)
|
|
181
|
+
else:
|
|
182
|
+
raise ValueError(f"Unknown provider type: {provider_type}")
|
|
183
|
+
return provider
|
|
117
184
|
|
|
118
185
|
async def initialize(self):
|
|
186
|
+
# 逐个初始化提供商
|
|
119
187
|
for provider_config in self.providers_config:
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
188
|
+
try:
|
|
189
|
+
await self.load_provider(provider_config)
|
|
190
|
+
except Exception as e:
|
|
191
|
+
logger.error(traceback.format_exc())
|
|
192
|
+
logger.error(e)
|
|
193
|
+
|
|
194
|
+
# 设置默认提供商
|
|
195
|
+
selected_provider_id = sp.get(
|
|
196
|
+
"curr_provider",
|
|
197
|
+
self.provider_settings.get("default_provider_id"),
|
|
198
|
+
scope="global",
|
|
199
|
+
scope_id="global",
|
|
200
|
+
)
|
|
201
|
+
selected_stt_provider_id = sp.get(
|
|
202
|
+
"curr_provider_stt",
|
|
203
|
+
self.provider_stt_settings.get("provider_id"),
|
|
204
|
+
scope="global",
|
|
205
|
+
scope_id="global",
|
|
206
|
+
)
|
|
207
|
+
selected_tts_provider_id = sp.get(
|
|
208
|
+
"curr_provider_tts",
|
|
209
|
+
self.provider_tts_settings.get("provider_id"),
|
|
210
|
+
scope="global",
|
|
211
|
+
scope_id="global",
|
|
212
|
+
)
|
|
213
|
+
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
|
214
|
+
if not self.curr_provider_inst and self.provider_insts:
|
|
215
|
+
self.curr_provider_inst = self.provider_insts[0]
|
|
124
216
|
|
|
125
|
-
|
|
126
|
-
|
|
217
|
+
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
|
|
218
|
+
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
|
219
|
+
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
|
127
220
|
|
|
128
|
-
|
|
129
|
-
|
|
221
|
+
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
|
|
222
|
+
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
|
223
|
+
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
|
130
224
|
|
|
131
225
|
# 初始化 MCP Client 连接
|
|
132
|
-
asyncio.create_task(
|
|
133
|
-
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
|
|
134
|
-
)
|
|
135
|
-
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
|
|
226
|
+
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
|
136
227
|
|
|
137
228
|
async def load_provider(self, provider_config: dict):
|
|
138
229
|
if not provider_config["enable"]:
|
|
230
|
+
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
|
231
|
+
return
|
|
232
|
+
if provider_config.get("provider_type", "") == "agent_runner":
|
|
139
233
|
return
|
|
140
234
|
|
|
141
235
|
logger.info(
|
|
142
|
-
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ..."
|
|
236
|
+
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...",
|
|
143
237
|
)
|
|
144
238
|
|
|
145
239
|
# 动态导入
|
|
@@ -151,21 +245,12 @@ class ProviderManager:
|
|
|
151
245
|
)
|
|
152
246
|
case "zhipu_chat_completion":
|
|
153
247
|
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
|
248
|
+
case "groq_chat_completion":
|
|
249
|
+
from .sources.groq_source import ProviderGroq as ProviderGroq
|
|
154
250
|
case "anthropic_chat_completion":
|
|
155
251
|
from .sources.anthropic_source import (
|
|
156
252
|
ProviderAnthropic as ProviderAnthropic,
|
|
157
253
|
)
|
|
158
|
-
case "llm_tuner":
|
|
159
|
-
logger.info("加载 LLM Tuner 工具 ...")
|
|
160
|
-
from .sources.llmtuner_source import (
|
|
161
|
-
LLMTunerModelLoader as LLMTunerModelLoader,
|
|
162
|
-
)
|
|
163
|
-
case "dify":
|
|
164
|
-
from .sources.dify_source import ProviderDify as ProviderDify
|
|
165
|
-
case "dashscope":
|
|
166
|
-
from .sources.dashscope_source import (
|
|
167
|
-
ProviderDashscope as ProviderDashscope,
|
|
168
|
-
)
|
|
169
254
|
case "googlegenai_chat_completion":
|
|
170
255
|
from .sources.gemini_source import (
|
|
171
256
|
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
|
@@ -182,6 +267,10 @@ class ProviderManager:
|
|
|
182
267
|
from .sources.whisper_selfhosted_source import (
|
|
183
268
|
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
|
184
269
|
)
|
|
270
|
+
case "xinference_stt":
|
|
271
|
+
from .sources.xinference_stt_provider import (
|
|
272
|
+
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
|
273
|
+
)
|
|
185
274
|
case "openai_tts_api":
|
|
186
275
|
from .sources.openai_tts_api_source import (
|
|
187
276
|
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
|
@@ -190,6 +279,10 @@ class ProviderManager:
|
|
|
190
279
|
from .sources.edge_tts_source import (
|
|
191
280
|
ProviderEdgeTTS as ProviderEdgeTTS,
|
|
192
281
|
)
|
|
282
|
+
case "gsv_tts_selfhost":
|
|
283
|
+
from .sources.gsv_selfhosted_source import (
|
|
284
|
+
ProviderGSVTTS as ProviderGSVTTS,
|
|
285
|
+
)
|
|
193
286
|
case "gsvi_tts_api":
|
|
194
287
|
from .sources.gsvi_tts_source import (
|
|
195
288
|
ProviderGSVITTS as ProviderGSVITTS,
|
|
@@ -202,77 +295,109 @@ class ProviderManager:
|
|
|
202
295
|
from .sources.dashscope_tts import (
|
|
203
296
|
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
|
204
297
|
)
|
|
298
|
+
case "azure_tts":
|
|
299
|
+
from .sources.azure_tts_source import (
|
|
300
|
+
AzureTTSProvider as AzureTTSProvider,
|
|
301
|
+
)
|
|
302
|
+
case "minimax_tts_api":
|
|
303
|
+
from .sources.minimax_tts_api_source import (
|
|
304
|
+
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
|
305
|
+
)
|
|
306
|
+
case "volcengine_tts":
|
|
307
|
+
from .sources.volcengine_tts import (
|
|
308
|
+
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
|
309
|
+
)
|
|
310
|
+
case "gemini_tts":
|
|
311
|
+
from .sources.gemini_tts_source import (
|
|
312
|
+
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
|
313
|
+
)
|
|
314
|
+
case "openai_embedding":
|
|
315
|
+
from .sources.openai_embedding_source import (
|
|
316
|
+
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
|
317
|
+
)
|
|
318
|
+
case "gemini_embedding":
|
|
319
|
+
from .sources.gemini_embedding_source import (
|
|
320
|
+
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
|
321
|
+
)
|
|
322
|
+
case "vllm_rerank":
|
|
323
|
+
from .sources.vllm_rerank_source import (
|
|
324
|
+
VLLMRerankProvider as VLLMRerankProvider,
|
|
325
|
+
)
|
|
326
|
+
case "xinference_rerank":
|
|
327
|
+
from .sources.xinference_rerank_source import (
|
|
328
|
+
XinferenceRerankProvider as XinferenceRerankProvider,
|
|
329
|
+
)
|
|
330
|
+
case "bailian_rerank":
|
|
331
|
+
from .sources.bailian_rerank_source import (
|
|
332
|
+
BailianRerankProvider as BailianRerankProvider,
|
|
333
|
+
)
|
|
205
334
|
except (ImportError, ModuleNotFoundError) as e:
|
|
206
335
|
logger.critical(
|
|
207
|
-
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
|
336
|
+
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
|
208
337
|
)
|
|
209
338
|
return
|
|
210
339
|
except Exception as e:
|
|
211
340
|
logger.critical(
|
|
212
|
-
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因"
|
|
341
|
+
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因",
|
|
213
342
|
)
|
|
214
343
|
return
|
|
215
344
|
|
|
216
345
|
if provider_config["type"] not in provider_cls_map:
|
|
217
346
|
logger.error(
|
|
218
|
-
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。"
|
|
347
|
+
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。",
|
|
219
348
|
)
|
|
220
349
|
return
|
|
221
350
|
|
|
222
351
|
provider_metadata = provider_cls_map[provider_config["type"]]
|
|
223
352
|
try:
|
|
224
353
|
# 按任务实例化提供商
|
|
354
|
+
cls_type = provider_metadata.cls_type
|
|
355
|
+
if not cls_type:
|
|
356
|
+
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
|
357
|
+
return
|
|
358
|
+
|
|
359
|
+
provider_metadata.id = provider_config["id"]
|
|
225
360
|
|
|
226
361
|
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
|
227
362
|
# STT 任务
|
|
228
|
-
inst =
|
|
229
|
-
provider_config, self.provider_settings
|
|
230
|
-
)
|
|
363
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
231
364
|
|
|
232
365
|
if getattr(inst, "initialize", None):
|
|
233
366
|
await inst.initialize()
|
|
234
367
|
|
|
235
368
|
self.stt_provider_insts.append(inst)
|
|
236
369
|
if (
|
|
237
|
-
self.
|
|
238
|
-
|
|
370
|
+
self.provider_stt_settings.get("provider_id")
|
|
371
|
+
== provider_config["id"]
|
|
239
372
|
):
|
|
240
373
|
self.curr_stt_provider_inst = inst
|
|
241
374
|
logger.info(
|
|
242
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
|
375
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
|
243
376
|
)
|
|
244
|
-
if not self.curr_stt_provider_inst
|
|
377
|
+
if not self.curr_stt_provider_inst:
|
|
245
378
|
self.curr_stt_provider_inst = inst
|
|
246
379
|
|
|
247
380
|
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
|
248
381
|
# TTS 任务
|
|
249
|
-
inst =
|
|
250
|
-
provider_config, self.provider_settings
|
|
251
|
-
)
|
|
382
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
252
383
|
|
|
253
384
|
if getattr(inst, "initialize", None):
|
|
254
385
|
await inst.initialize()
|
|
255
386
|
|
|
256
387
|
self.tts_provider_insts.append(inst)
|
|
257
|
-
if (
|
|
258
|
-
self.selected_tts_provider_id == provider_config["id"]
|
|
259
|
-
and self.tts_enabled
|
|
260
|
-
):
|
|
388
|
+
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
|
261
389
|
self.curr_tts_provider_inst = inst
|
|
262
390
|
logger.info(
|
|
263
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
|
391
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
|
264
392
|
)
|
|
265
|
-
if not self.curr_tts_provider_inst
|
|
393
|
+
if not self.curr_tts_provider_inst:
|
|
266
394
|
self.curr_tts_provider_inst = inst
|
|
267
395
|
|
|
268
396
|
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
|
269
397
|
# 文本生成任务
|
|
270
|
-
inst =
|
|
398
|
+
inst = cls_type(
|
|
271
399
|
provider_config,
|
|
272
400
|
self.provider_settings,
|
|
273
|
-
self.db_helper,
|
|
274
|
-
self.provider_settings.get("persistant_history", True),
|
|
275
|
-
self.selected_default_persona,
|
|
276
401
|
)
|
|
277
402
|
|
|
278
403
|
if getattr(inst, "initialize", None):
|
|
@@ -280,72 +405,77 @@ class ProviderManager:
|
|
|
280
405
|
|
|
281
406
|
self.provider_insts.append(inst)
|
|
282
407
|
if (
|
|
283
|
-
self.
|
|
284
|
-
|
|
408
|
+
self.provider_settings.get("default_provider_id")
|
|
409
|
+
== provider_config["id"]
|
|
285
410
|
):
|
|
286
411
|
self.curr_provider_inst = inst
|
|
287
412
|
logger.info(
|
|
288
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
|
413
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
|
289
414
|
)
|
|
290
|
-
if not self.curr_provider_inst
|
|
415
|
+
if not self.curr_provider_inst:
|
|
291
416
|
self.curr_provider_inst = inst
|
|
292
417
|
|
|
418
|
+
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
|
419
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
420
|
+
if getattr(inst, "initialize", None):
|
|
421
|
+
await inst.initialize()
|
|
422
|
+
self.embedding_provider_insts.append(inst)
|
|
423
|
+
elif provider_metadata.provider_type == ProviderType.RERANK:
|
|
424
|
+
inst = cls_type(provider_config, self.provider_settings)
|
|
425
|
+
if getattr(inst, "initialize", None):
|
|
426
|
+
await inst.initialize()
|
|
427
|
+
self.rerank_provider_insts.append(inst)
|
|
428
|
+
|
|
293
429
|
self.inst_map[provider_config["id"]] = inst
|
|
294
430
|
except Exception as e:
|
|
295
|
-
logger.error(traceback.format_exc())
|
|
296
431
|
logger.error(
|
|
297
|
-
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
|
432
|
+
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
|
433
|
+
)
|
|
434
|
+
raise Exception(
|
|
435
|
+
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
|
298
436
|
)
|
|
299
437
|
|
|
300
438
|
async def reload(self, provider_config: dict):
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
self.
|
|
315
|
-
|
|
316
|
-
and self.
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
|
322
|
-
)
|
|
439
|
+
async with self.reload_lock:
|
|
440
|
+
await self.terminate_provider(provider_config["id"])
|
|
441
|
+
if provider_config["enable"]:
|
|
442
|
+
await self.load_provider(provider_config)
|
|
443
|
+
|
|
444
|
+
# 和配置文件保持同步
|
|
445
|
+
self.providers_config = astrbot_config["provider"]
|
|
446
|
+
config_ids = [provider["id"] for provider in self.providers_config]
|
|
447
|
+
logger.info(f"providers in user's config: {config_ids}")
|
|
448
|
+
for key in list(self.inst_map.keys()):
|
|
449
|
+
if key not in config_ids:
|
|
450
|
+
await self.terminate_provider(key)
|
|
451
|
+
|
|
452
|
+
if len(self.provider_insts) == 0:
|
|
453
|
+
self.curr_provider_inst = None
|
|
454
|
+
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
|
455
|
+
self.curr_provider_inst = self.provider_insts[0]
|
|
456
|
+
logger.info(
|
|
457
|
+
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
|
458
|
+
)
|
|
323
459
|
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
logger.info(
|
|
334
|
-
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
|
335
|
-
)
|
|
460
|
+
if len(self.stt_provider_insts) == 0:
|
|
461
|
+
self.curr_stt_provider_inst = None
|
|
462
|
+
elif (
|
|
463
|
+
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
|
|
464
|
+
):
|
|
465
|
+
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
|
466
|
+
logger.info(
|
|
467
|
+
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
|
468
|
+
)
|
|
336
469
|
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
logger.info(
|
|
347
|
-
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
|
348
|
-
)
|
|
470
|
+
if len(self.tts_provider_insts) == 0:
|
|
471
|
+
self.curr_tts_provider_inst = None
|
|
472
|
+
elif (
|
|
473
|
+
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
|
|
474
|
+
):
|
|
475
|
+
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
|
476
|
+
logger.info(
|
|
477
|
+
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
|
478
|
+
)
|
|
349
479
|
|
|
350
480
|
def get_insts(self):
|
|
351
481
|
return self.provider_insts
|
|
@@ -353,15 +483,21 @@ class ProviderManager:
|
|
|
353
483
|
async def terminate_provider(self, provider_id: str):
|
|
354
484
|
if provider_id in self.inst_map:
|
|
355
485
|
logger.info(
|
|
356
|
-
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..."
|
|
486
|
+
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...",
|
|
357
487
|
)
|
|
358
488
|
|
|
359
489
|
if self.inst_map[provider_id] in self.provider_insts:
|
|
360
|
-
self.
|
|
490
|
+
prov_inst = self.inst_map[provider_id]
|
|
491
|
+
if isinstance(prov_inst, Provider):
|
|
492
|
+
self.provider_insts.remove(prov_inst)
|
|
361
493
|
if self.inst_map[provider_id] in self.stt_provider_insts:
|
|
362
|
-
self.
|
|
494
|
+
prov_inst = self.inst_map[provider_id]
|
|
495
|
+
if isinstance(prov_inst, STTProvider):
|
|
496
|
+
self.stt_provider_insts.remove(prov_inst)
|
|
363
497
|
if self.inst_map[provider_id] in self.tts_provider_insts:
|
|
364
|
-
self.
|
|
498
|
+
prov_inst = self.inst_map[provider_id]
|
|
499
|
+
if isinstance(prov_inst, TTSProvider):
|
|
500
|
+
self.tts_provider_insts.remove(prov_inst)
|
|
365
501
|
|
|
366
502
|
if self.inst_map[provider_id] == self.curr_provider_inst:
|
|
367
503
|
self.curr_provider_inst = None
|
|
@@ -371,16 +507,18 @@ class ProviderManager:
|
|
|
371
507
|
self.curr_tts_provider_inst = None
|
|
372
508
|
|
|
373
509
|
if getattr(self.inst_map[provider_id], "terminate", None):
|
|
374
|
-
await self.inst_map[provider_id].terminate()
|
|
510
|
+
await self.inst_map[provider_id].terminate() # type: ignore
|
|
375
511
|
|
|
376
512
|
logger.info(
|
|
377
|
-
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
|
513
|
+
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})",
|
|
378
514
|
)
|
|
379
515
|
del self.inst_map[provider_id]
|
|
380
516
|
|
|
381
517
|
async def terminate(self):
|
|
382
518
|
for provider_inst in self.provider_insts:
|
|
383
519
|
if hasattr(provider_inst, "terminate"):
|
|
384
|
-
await provider_inst.terminate()
|
|
385
|
-
|
|
386
|
-
|
|
520
|
+
await provider_inst.terminate() # type: ignore
|
|
521
|
+
try:
|
|
522
|
+
await self.llm_tools.disable_mcp_server()
|
|
523
|
+
except Exception:
|
|
524
|
+
logger.error("Error while disabling MCP servers", exc_info=True)
|