AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +16 -4
- astrbot/api/all.py +2 -1
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -34
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +8 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__init__.py +1 -0
- astrbot/cli/__main__.py +18 -197
- astrbot/cli/commands/__init__.py +6 -0
- astrbot/cli/commands/cmd_conf.py +209 -0
- astrbot/cli/commands/cmd_init.py +56 -0
- astrbot/cli/commands/cmd_plug.py +245 -0
- astrbot/cli/commands/cmd_run.py +62 -0
- astrbot/cli/utils/__init__.py +18 -0
- astrbot/cli/utils/basic.py +76 -0
- astrbot/cli/utils/plugin.py +246 -0
- astrbot/cli/utils/version_comparator.py +90 -0
- astrbot/core/__init__.py +17 -19
- astrbot/core/agent/agent.py +14 -0
- astrbot/core/agent/handoff.py +38 -0
- astrbot/core/agent/hooks.py +30 -0
- astrbot/core/agent/mcp_client.py +385 -0
- astrbot/core/agent/message.py +175 -0
- astrbot/core/agent/response.py +14 -0
- astrbot/core/agent/run_context.py +22 -0
- astrbot/core/agent/runners/__init__.py +3 -0
- astrbot/core/agent/runners/base.py +65 -0
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
- astrbot/core/agent/tool.py +285 -0
- astrbot/core/agent/tool_executor.py +17 -0
- astrbot/core/astr_agent_context.py +19 -0
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/astrbot_config_mgr.py +275 -0
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +60 -20
- astrbot/core/config/default.py +1972 -453
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/conversation_mgr.py +285 -75
- astrbot/core/core_lifecycle.py +167 -62
- astrbot/core/db/__init__.py +305 -102
- astrbot/core/db/migration/helper.py +69 -0
- astrbot/core/db/migration/migra_3_to_4.py +357 -0
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/migration/shared_preferences_v3.py +48 -0
- astrbot/core/db/migration/sqlite_v3.py +497 -0
- astrbot/core/db/po.py +259 -55
- astrbot/core/db/sqlite.py +773 -528
- astrbot/core/db/vec_db/base.py +73 -0
- astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
- astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
- astrbot/core/event_bus.py +26 -22
- astrbot/core/exceptions.py +9 -0
- astrbot/core/file_token_service.py +98 -0
- astrbot/core/initial_loader.py +19 -10
- astrbot/core/knowledge_base/chunking/__init__.py +9 -0
- astrbot/core/knowledge_base/chunking/base.py +25 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
- astrbot/core/knowledge_base/chunking/recursive.py +161 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
- astrbot/core/knowledge_base/kb_helper.py +642 -0
- astrbot/core/knowledge_base/kb_mgr.py +330 -0
- astrbot/core/knowledge_base/models.py +120 -0
- astrbot/core/knowledge_base/parsers/__init__.py +13 -0
- astrbot/core/knowledge_base/parsers/base.py +51 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +276 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
- astrbot/core/log.py +21 -15
- astrbot/core/message/components.py +413 -287
- astrbot/core/message/message_event_result.py +35 -24
- astrbot/core/persona_mgr.py +192 -0
- astrbot/core/pipeline/__init__.py +14 -14
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +7 -1
- astrbot/core/pipeline/context_utils.py +107 -0
- astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
- astrbot/core/pipeline/process_stage/stage.py +21 -15
- astrbot/core/pipeline/process_stage/utils.py +125 -0
- astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
- astrbot/core/pipeline/respond/stage.py +142 -101
- astrbot/core/pipeline/result_decorate/stage.py +124 -57
- astrbot/core/pipeline/scheduler.py +21 -16
- astrbot/core/pipeline/session_status_check/stage.py +37 -0
- astrbot/core/pipeline/stage.py +11 -76
- astrbot/core/pipeline/waking_check/stage.py +69 -33
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +107 -129
- astrbot/core/platform/astrbot_message.py +32 -12
- astrbot/core/platform/manager.py +62 -18
- astrbot/core/platform/message_session.py +30 -0
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +9 -4
- astrbot/core/platform/register.py +12 -7
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
- astrbot/core/platform/sources/discord/client.py +129 -0
- astrbot/core/platform/sources/discord/components.py +139 -0
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
- astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
- astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
- astrbot/core/platform/sources/lark/lark_event.py +39 -13
- astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
- astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
- astrbot/core/platform/sources/satori/satori_event.py +432 -0
- astrbot/core/platform/sources/slack/client.py +164 -0
- astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
- astrbot/core/platform/sources/slack/slack_event.py +253 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
- astrbot/core/platform/sources/telegram/tg_event.py +136 -36
- astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
- astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
- astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
- astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
- astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
- astrbot/core/platform_message_history_mgr.py +49 -0
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +154 -98
- astrbot/core/provider/func_tool_manager.py +446 -458
- astrbot/core/provider/manager.py +345 -207
- astrbot/core/provider/provider.py +188 -73
- astrbot/core/provider/register.py +9 -7
- astrbot/core/provider/sources/anthropic_source.py +295 -115
- astrbot/core/provider/sources/azure_tts_source.py +224 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/dashscope_tts.py +138 -14
- astrbot/core/provider/sources/edge_tts_source.py +24 -19
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
- astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
- astrbot/core/provider/sources/gemini_source.py +310 -132
- astrbot/core/provider/sources/gemini_tts_source.py +81 -0
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
- astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
- astrbot/core/provider/sources/openai_embedding_source.py +40 -0
- astrbot/core/provider/sources/openai_source.py +241 -145
- astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
- astrbot/core/provider/sources/volcengine_tts.py +115 -0
- astrbot/core/provider/sources/whisper_api_source.py +18 -13
- astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/provider/sources/zhipu_source.py +6 -73
- astrbot/core/star/__init__.py +43 -11
- astrbot/core/star/config.py +17 -18
- astrbot/core/star/context.py +362 -138
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +111 -35
- astrbot/core/star/filter/command_group.py +46 -34
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +45 -12
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -15
- astrbot/core/star/register/star.py +41 -13
- astrbot/core/star/register/star_handler.py +236 -86
- astrbot/core/star/session_llm_manager.py +280 -0
- astrbot/core/star/session_plugin_manager.py +170 -0
- astrbot/core/star/star.py +36 -43
- astrbot/core/star/star_handler.py +47 -85
- astrbot/core/star/star_manager.py +442 -260
- astrbot/core/star/star_tools.py +167 -45
- astrbot/core/star/updator.py +17 -20
- astrbot/core/umop_config_router.py +106 -0
- astrbot/core/updator.py +38 -13
- astrbot/core/utils/astrbot_path.py +39 -0
- astrbot/core/utils/command_parser.py +1 -1
- astrbot/core/utils/io.py +119 -60
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +11 -10
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/path_util.py +63 -62
- astrbot/core/utils/pip_installer.py +37 -15
- astrbot/core/utils/session_lock.py +29 -0
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +174 -34
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/local_strategy.py +386 -238
- astrbot/core/utils/t2i/network_strategy.py +109 -49
- astrbot/core/utils/t2i/renderer.py +29 -14
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +111 -0
- astrbot/core/utils/tencent_record_helper.py +115 -1
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +112 -65
- astrbot/dashboard/routes/__init__.py +20 -13
- astrbot/dashboard/routes/auth.py +20 -9
- astrbot/dashboard/routes/chat.py +297 -141
- astrbot/dashboard/routes/config.py +652 -55
- astrbot/dashboard/routes/conversation.py +107 -37
- astrbot/dashboard/routes/file.py +26 -0
- astrbot/dashboard/routes/knowledge_base.py +1244 -0
- astrbot/dashboard/routes/log.py +27 -2
- astrbot/dashboard/routes/persona.py +202 -0
- astrbot/dashboard/routes/plugin.py +197 -139
- astrbot/dashboard/routes/route.py +27 -7
- astrbot/dashboard/routes/session_management.py +354 -0
- astrbot/dashboard/routes/stat.py +85 -18
- astrbot/dashboard/routes/static_file.py +5 -2
- astrbot/dashboard/routes/t2i.py +233 -0
- astrbot/dashboard/routes/tools.py +184 -120
- astrbot/dashboard/routes/update.py +59 -36
- astrbot/dashboard/server.py +96 -36
- astrbot/dashboard/utils.py +165 -0
- astrbot-4.7.0.dist-info/METADATA +294 -0
- astrbot-4.7.0.dist-info/RECORD +274 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
- astrbot/core/db/plugin/sqlite_impl.py +0 -112
- astrbot/core/db/sqlite_init.sql +0 -50
- astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
- astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
- astrbot/core/platform/sources/gewechat/client.py +0 -806
- astrbot/core/platform/sources/gewechat/downloader.py +0 -55
- astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
- astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
- astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
- astrbot/core/provider/sources/dashscope_source.py +0 -203
- astrbot/core/provider/sources/dify_source.py +0 -281
- astrbot/core/provider/sources/llmtuner_source.py +0 -132
- astrbot/core/rag/embedding/openai_source.py +0 -20
- astrbot/core/rag/knowledge_db_mgr.py +0 -94
- astrbot/core/rag/store/__init__.py +0 -9
- astrbot/core/rag/store/chroma_db.py +0 -42
- astrbot/core/utils/dify_api_client.py +0 -152
- astrbot-3.5.6.dist-info/METADATA +0 -249
- astrbot-3.5.6.dist-info/RECORD +0 -158
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,161 +1,170 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
|
|
3
|
-
from
|
|
4
|
-
from typing import TypedDict, AsyncGenerator
|
|
5
|
-
from astrbot.core.provider.func_tool_manager import FuncCall
|
|
6
|
-
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
|
7
|
-
from dataclasses import dataclass
|
|
2
|
+
import asyncio
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
8
4
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
_mood_imitation_dialogs_processed: str = ""
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@dataclass
|
|
22
|
-
class ProviderMeta:
|
|
23
|
-
id: str
|
|
24
|
-
model: str
|
|
25
|
-
type: str
|
|
5
|
+
from astrbot.core.agent.message import Message
|
|
6
|
+
from astrbot.core.agent.tool import ToolSet
|
|
7
|
+
from astrbot.core.provider.entities import (
|
|
8
|
+
LLMResponse,
|
|
9
|
+
ProviderMeta,
|
|
10
|
+
RerankResult,
|
|
11
|
+
ToolCallsResult,
|
|
12
|
+
)
|
|
13
|
+
from astrbot.core.provider.register import provider_cls_map
|
|
26
14
|
|
|
27
15
|
|
|
28
16
|
class AbstractProvider(abc.ABC):
|
|
17
|
+
"""Provider Abstract Class"""
|
|
18
|
+
|
|
29
19
|
def __init__(self, provider_config: dict) -> None:
|
|
30
20
|
super().__init__()
|
|
31
21
|
self.model_name = ""
|
|
32
22
|
self.provider_config = provider_config
|
|
33
23
|
|
|
34
24
|
def set_model(self, model_name: str):
|
|
35
|
-
"""
|
|
25
|
+
"""Set the current model name"""
|
|
36
26
|
self.model_name = model_name
|
|
37
27
|
|
|
38
28
|
def get_model(self) -> str:
|
|
39
|
-
"""
|
|
29
|
+
"""Get the current model name"""
|
|
40
30
|
return self.model_name
|
|
41
31
|
|
|
42
32
|
def meta(self) -> ProviderMeta:
|
|
43
|
-
"""
|
|
44
|
-
|
|
45
|
-
|
|
33
|
+
"""Get the provider metadata"""
|
|
34
|
+
provider_type_name = self.provider_config["type"]
|
|
35
|
+
meta_data = provider_cls_map.get(provider_type_name)
|
|
36
|
+
if not meta_data:
|
|
37
|
+
raise ValueError(f"Provider type {provider_type_name} not registered")
|
|
38
|
+
meta = ProviderMeta(
|
|
39
|
+
id=self.provider_config.get("id", "default"),
|
|
46
40
|
model=self.get_model(),
|
|
47
|
-
type=
|
|
41
|
+
type=provider_type_name,
|
|
42
|
+
provider_type=meta_data.provider_type,
|
|
48
43
|
)
|
|
44
|
+
return meta
|
|
49
45
|
|
|
50
46
|
|
|
51
47
|
class Provider(AbstractProvider):
|
|
48
|
+
"""Chat Provider"""
|
|
49
|
+
|
|
52
50
|
def __init__(
|
|
53
51
|
self,
|
|
54
52
|
provider_config: dict,
|
|
55
53
|
provider_settings: dict,
|
|
56
|
-
persistant_history: bool = True,
|
|
57
|
-
db_helper: BaseDatabase = None,
|
|
58
|
-
default_persona: Personality = None,
|
|
59
54
|
) -> None:
|
|
60
55
|
super().__init__(provider_config)
|
|
61
|
-
|
|
62
56
|
self.provider_settings = provider_settings
|
|
63
57
|
|
|
64
|
-
self.curr_personality: Personality = default_persona
|
|
65
|
-
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
|
66
|
-
|
|
67
58
|
@abc.abstractmethod
|
|
68
59
|
def get_current_key(self) -> str:
|
|
69
|
-
raise NotImplementedError
|
|
60
|
+
raise NotImplementedError
|
|
70
61
|
|
|
71
|
-
def get_keys(self) ->
|
|
62
|
+
def get_keys(self) -> list[str]:
|
|
72
63
|
"""获得提供商 Key"""
|
|
73
|
-
|
|
64
|
+
keys = self.provider_config.get("key", [""])
|
|
65
|
+
return keys or [""]
|
|
74
66
|
|
|
75
67
|
@abc.abstractmethod
|
|
76
68
|
def set_key(self, key: str):
|
|
77
|
-
raise NotImplementedError
|
|
69
|
+
raise NotImplementedError
|
|
78
70
|
|
|
79
71
|
@abc.abstractmethod
|
|
80
|
-
def get_models(self) ->
|
|
72
|
+
async def get_models(self) -> list[str]:
|
|
81
73
|
"""获得支持的模型列表"""
|
|
82
|
-
raise NotImplementedError
|
|
74
|
+
raise NotImplementedError
|
|
83
75
|
|
|
84
76
|
@abc.abstractmethod
|
|
85
77
|
async def text_chat(
|
|
86
78
|
self,
|
|
87
|
-
prompt: str,
|
|
88
|
-
session_id: str = None,
|
|
89
|
-
image_urls:
|
|
90
|
-
func_tool:
|
|
91
|
-
contexts:
|
|
92
|
-
system_prompt: str = None,
|
|
93
|
-
tool_calls_result: ToolCallsResult = None,
|
|
79
|
+
prompt: str | None = None,
|
|
80
|
+
session_id: str | None = None,
|
|
81
|
+
image_urls: list[str] | None = None,
|
|
82
|
+
func_tool: ToolSet | None = None,
|
|
83
|
+
contexts: list[Message] | list[dict] | None = None,
|
|
84
|
+
system_prompt: str | None = None,
|
|
85
|
+
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
|
86
|
+
model: str | None = None,
|
|
94
87
|
**kwargs,
|
|
95
88
|
) -> LLMResponse:
|
|
96
89
|
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
|
97
90
|
|
|
98
91
|
Args:
|
|
99
|
-
prompt:
|
|
92
|
+
prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
|
|
100
93
|
session_id: 会话 ID(此属性已经被废弃)
|
|
101
94
|
image_urls: 图片 URL 列表
|
|
102
|
-
tools:
|
|
103
|
-
contexts:
|
|
95
|
+
tools: tool set
|
|
96
|
+
contexts: 上下文,和 prompt 二选一使用
|
|
104
97
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
|
105
98
|
kwargs: 其他参数
|
|
106
99
|
|
|
107
100
|
Notes:
|
|
108
101
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
|
109
102
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
|
103
|
+
|
|
110
104
|
"""
|
|
111
105
|
...
|
|
112
106
|
|
|
113
107
|
async def text_chat_stream(
|
|
114
108
|
self,
|
|
115
|
-
prompt: str,
|
|
116
|
-
session_id: str = None,
|
|
117
|
-
image_urls:
|
|
118
|
-
func_tool:
|
|
119
|
-
contexts:
|
|
120
|
-
system_prompt: str = None,
|
|
121
|
-
tool_calls_result: ToolCallsResult = None,
|
|
109
|
+
prompt: str | None = None,
|
|
110
|
+
session_id: str | None = None,
|
|
111
|
+
image_urls: list[str] | None = None,
|
|
112
|
+
func_tool: ToolSet | None = None,
|
|
113
|
+
contexts: list[Message] | list[dict] | None = None,
|
|
114
|
+
system_prompt: str | None = None,
|
|
115
|
+
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
|
116
|
+
model: str | None = None,
|
|
122
117
|
**kwargs,
|
|
123
118
|
) -> AsyncGenerator[LLMResponse, None]:
|
|
124
119
|
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
|
125
120
|
|
|
126
121
|
Args:
|
|
127
|
-
prompt:
|
|
122
|
+
prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
|
|
128
123
|
session_id: 会话 ID(此属性已经被废弃)
|
|
129
124
|
image_urls: 图片 URL 列表
|
|
130
|
-
tools:
|
|
131
|
-
contexts:
|
|
125
|
+
tools: tool set
|
|
126
|
+
contexts: 上下文,和 prompt 二选一使用
|
|
132
127
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
|
133
128
|
kwargs: 其他参数
|
|
134
129
|
|
|
135
130
|
Notes:
|
|
136
131
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
|
137
132
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
|
133
|
+
|
|
138
134
|
"""
|
|
139
135
|
...
|
|
140
136
|
|
|
141
|
-
async def pop_record(self, context:
|
|
142
|
-
"""
|
|
143
|
-
弹出 context 第一条非系统提示词对话记录
|
|
144
|
-
"""
|
|
137
|
+
async def pop_record(self, context: list):
|
|
138
|
+
"""弹出 context 第一条非系统提示词对话记录"""
|
|
145
139
|
poped = 0
|
|
146
140
|
indexs_to_pop = []
|
|
147
141
|
for idx, record in enumerate(context):
|
|
148
142
|
if record["role"] == "system":
|
|
149
143
|
continue
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
break
|
|
144
|
+
indexs_to_pop.append(idx)
|
|
145
|
+
poped += 1
|
|
146
|
+
if poped == 2:
|
|
147
|
+
break
|
|
155
148
|
|
|
156
149
|
for idx in reversed(indexs_to_pop):
|
|
157
150
|
context.pop(idx)
|
|
158
151
|
|
|
152
|
+
def _ensure_message_to_dicts(
|
|
153
|
+
self,
|
|
154
|
+
messages: list[dict] | list[Message] | None,
|
|
155
|
+
) -> list[dict]:
|
|
156
|
+
"""Convert a list of Message objects to a list of dictionaries."""
|
|
157
|
+
if not messages:
|
|
158
|
+
return []
|
|
159
|
+
dicts: list[dict] = []
|
|
160
|
+
for message in messages:
|
|
161
|
+
if isinstance(message, Message):
|
|
162
|
+
dicts.append(message.model_dump())
|
|
163
|
+
else:
|
|
164
|
+
dicts.append(message)
|
|
165
|
+
|
|
166
|
+
return dicts
|
|
167
|
+
|
|
159
168
|
|
|
160
169
|
class STTProvider(AbstractProvider):
|
|
161
170
|
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
@@ -166,7 +175,7 @@ class STTProvider(AbstractProvider):
|
|
|
166
175
|
@abc.abstractmethod
|
|
167
176
|
async def get_text(self, audio_url: str) -> str:
|
|
168
177
|
"""获取音频的文本"""
|
|
169
|
-
raise NotImplementedError
|
|
178
|
+
raise NotImplementedError
|
|
170
179
|
|
|
171
180
|
|
|
172
181
|
class TTSProvider(AbstractProvider):
|
|
@@ -178,4 +187,110 @@ class TTSProvider(AbstractProvider):
|
|
|
178
187
|
@abc.abstractmethod
|
|
179
188
|
async def get_audio(self, text: str) -> str:
|
|
180
189
|
"""获取文本的音频,返回音频文件路径"""
|
|
181
|
-
raise NotImplementedError
|
|
190
|
+
raise NotImplementedError
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class EmbeddingProvider(AbstractProvider):
|
|
194
|
+
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
195
|
+
super().__init__(provider_config)
|
|
196
|
+
self.provider_config = provider_config
|
|
197
|
+
self.provider_settings = provider_settings
|
|
198
|
+
|
|
199
|
+
@abc.abstractmethod
|
|
200
|
+
async def get_embedding(self, text: str) -> list[float]:
|
|
201
|
+
"""获取文本的向量"""
|
|
202
|
+
...
|
|
203
|
+
|
|
204
|
+
@abc.abstractmethod
|
|
205
|
+
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
|
206
|
+
"""批量获取文本的向量"""
|
|
207
|
+
...
|
|
208
|
+
|
|
209
|
+
@abc.abstractmethod
|
|
210
|
+
def get_dim(self) -> int:
|
|
211
|
+
"""获取向量的维度"""
|
|
212
|
+
...
|
|
213
|
+
|
|
214
|
+
async def get_embeddings_batch(
|
|
215
|
+
self,
|
|
216
|
+
texts: list[str],
|
|
217
|
+
batch_size: int = 16,
|
|
218
|
+
tasks_limit: int = 3,
|
|
219
|
+
max_retries: int = 3,
|
|
220
|
+
progress_callback=None,
|
|
221
|
+
) -> list[list[float]]:
|
|
222
|
+
"""批量获取文本的向量,分批处理以节省内存
|
|
223
|
+
|
|
224
|
+
Args:
|
|
225
|
+
texts: 文本列表
|
|
226
|
+
batch_size: 每批处理的文本数量
|
|
227
|
+
tasks_limit: 并发任务数量限制
|
|
228
|
+
max_retries: 失败时的最大重试次数
|
|
229
|
+
progress_callback: 进度回调函数,接收参数 (current, total)
|
|
230
|
+
|
|
231
|
+
Returns:
|
|
232
|
+
向量列表
|
|
233
|
+
|
|
234
|
+
"""
|
|
235
|
+
semaphore = asyncio.Semaphore(tasks_limit)
|
|
236
|
+
all_embeddings: list[list[float]] = []
|
|
237
|
+
failed_batches: list[tuple[int, list[str]]] = []
|
|
238
|
+
completed_count = 0
|
|
239
|
+
total_count = len(texts)
|
|
240
|
+
|
|
241
|
+
async def process_batch(batch_idx: int, batch_texts: list[str]):
|
|
242
|
+
nonlocal completed_count
|
|
243
|
+
async with semaphore:
|
|
244
|
+
for attempt in range(max_retries):
|
|
245
|
+
try:
|
|
246
|
+
batch_embeddings = await self.get_embeddings(batch_texts)
|
|
247
|
+
all_embeddings.extend(batch_embeddings)
|
|
248
|
+
completed_count += len(batch_texts)
|
|
249
|
+
if progress_callback:
|
|
250
|
+
await progress_callback(completed_count, total_count)
|
|
251
|
+
return
|
|
252
|
+
except Exception as e:
|
|
253
|
+
if attempt == max_retries - 1:
|
|
254
|
+
# 最后一次重试失败,记录失败的批次
|
|
255
|
+
failed_batches.append((batch_idx, batch_texts))
|
|
256
|
+
raise Exception(
|
|
257
|
+
f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}",
|
|
258
|
+
)
|
|
259
|
+
# 等待一段时间后重试,使用指数退避
|
|
260
|
+
await asyncio.sleep(2**attempt)
|
|
261
|
+
|
|
262
|
+
tasks = []
|
|
263
|
+
for i in range(0, len(texts), batch_size):
|
|
264
|
+
batch_texts = texts[i : i + batch_size]
|
|
265
|
+
batch_idx = i // batch_size
|
|
266
|
+
tasks.append(process_batch(batch_idx, batch_texts))
|
|
267
|
+
|
|
268
|
+
# 收集所有任务的结果,包括失败的任务
|
|
269
|
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
270
|
+
|
|
271
|
+
# 检查是否有失败的任务
|
|
272
|
+
errors = [r for r in results if isinstance(r, Exception)]
|
|
273
|
+
if errors:
|
|
274
|
+
error_msg = (
|
|
275
|
+
f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
|
|
276
|
+
)
|
|
277
|
+
raise Exception(error_msg)
|
|
278
|
+
|
|
279
|
+
return all_embeddings
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class RerankProvider(AbstractProvider):
|
|
283
|
+
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
284
|
+
super().__init__(provider_config)
|
|
285
|
+
self.provider_config = provider_config
|
|
286
|
+
self.provider_settings = provider_settings
|
|
287
|
+
|
|
288
|
+
@abc.abstractmethod
|
|
289
|
+
async def rerank(
|
|
290
|
+
self,
|
|
291
|
+
query: str,
|
|
292
|
+
documents: list[str],
|
|
293
|
+
top_n: int | None = None,
|
|
294
|
+
) -> list[RerankResult]:
|
|
295
|
+
"""获取查询和文档的重排序分数"""
|
|
296
|
+
...
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from typing import List, Dict
|
|
2
|
-
from .entities import ProviderMetaData, ProviderType
|
|
3
1
|
from astrbot.core import logger
|
|
2
|
+
|
|
3
|
+
from .entities import ProviderMetaData, ProviderType
|
|
4
4
|
from .func_tool_manager import FuncCall
|
|
5
5
|
|
|
6
|
-
provider_registry:
|
|
6
|
+
provider_registry: list[ProviderMetaData] = []
|
|
7
7
|
"""维护了通过装饰器注册的 Provider"""
|
|
8
|
-
provider_cls_map:
|
|
8
|
+
provider_cls_map: dict[str, ProviderMetaData] = {}
|
|
9
9
|
"""维护了 Provider 类型名称和 ProviderMetadata 的映射"""
|
|
10
10
|
|
|
11
11
|
llm_tools = FuncCall()
|
|
@@ -15,15 +15,15 @@ def register_provider_adapter(
|
|
|
15
15
|
provider_type_name: str,
|
|
16
16
|
desc: str,
|
|
17
17
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
|
|
18
|
-
default_config_tmpl: dict = None,
|
|
19
|
-
provider_display_name: str = None,
|
|
18
|
+
default_config_tmpl: dict | None = None,
|
|
19
|
+
provider_display_name: str | None = None,
|
|
20
20
|
):
|
|
21
21
|
"""用于注册平台适配器的带参装饰器"""
|
|
22
22
|
|
|
23
23
|
def decorator(cls):
|
|
24
24
|
if provider_type_name in provider_cls_map:
|
|
25
25
|
raise ValueError(
|
|
26
|
-
f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。"
|
|
26
|
+
f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。",
|
|
27
27
|
)
|
|
28
28
|
|
|
29
29
|
# 添加必备选项
|
|
@@ -36,6 +36,8 @@ def register_provider_adapter(
|
|
|
36
36
|
default_config_tmpl["id"] = provider_type_name
|
|
37
37
|
|
|
38
38
|
pm = ProviderMetaData(
|
|
39
|
+
id="default", # will be replaced when instantiated
|
|
40
|
+
model=None,
|
|
39
41
|
type=provider_type_name,
|
|
40
42
|
desc=desc,
|
|
41
43
|
provider_type=provider_type,
|