AstrBot 4.5.1__py3-none-any.whl → 4.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +10 -11
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -36
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +7 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__main__.py +5 -5
- astrbot/cli/commands/__init__.py +3 -3
- astrbot/cli/commands/cmd_conf.py +19 -16
- astrbot/cli/commands/cmd_init.py +3 -2
- astrbot/cli/commands/cmd_plug.py +8 -10
- astrbot/cli/commands/cmd_run.py +5 -6
- astrbot/cli/utils/__init__.py +6 -6
- astrbot/cli/utils/basic.py +14 -14
- astrbot/cli/utils/plugin.py +24 -15
- astrbot/cli/utils/version_comparator.py +10 -12
- astrbot/core/__init__.py +8 -6
- astrbot/core/agent/agent.py +3 -2
- astrbot/core/agent/handoff.py +6 -2
- astrbot/core/agent/hooks.py +9 -6
- astrbot/core/agent/mcp_client.py +50 -15
- astrbot/core/agent/message.py +168 -0
- astrbot/core/agent/response.py +2 -1
- astrbot/core/agent/run_context.py +2 -3
- astrbot/core/agent/runners/base.py +10 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
- astrbot/core/agent/tool.py +60 -41
- astrbot/core/agent/tool_executor.py +9 -3
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astrbot_config_mgr.py +29 -9
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +28 -26
- astrbot/core/config/default.py +4 -6
- astrbot/core/conversation_mgr.py +105 -36
- astrbot/core/core_lifecycle.py +68 -54
- astrbot/core/db/__init__.py +33 -18
- astrbot/core/db/migration/helper.py +12 -10
- astrbot/core/db/migration/migra_3_to_4.py +53 -34
- astrbot/core/db/migration/migra_45_to_46.py +1 -1
- astrbot/core/db/migration/shared_preferences_v3.py +2 -1
- astrbot/core/db/migration/sqlite_v3.py +26 -23
- astrbot/core/db/po.py +27 -18
- astrbot/core/db/sqlite.py +74 -45
- astrbot/core/db/vec_db/base.py +10 -14
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
- astrbot/core/event_bus.py +8 -6
- astrbot/core/file_token_service.py +6 -5
- astrbot/core/initial_loader.py +7 -5
- astrbot/core/knowledge_base/chunking/__init__.py +1 -3
- astrbot/core/knowledge_base/chunking/base.py +1 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
- astrbot/core/knowledge_base/chunking/recursive.py +16 -10
- astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
- astrbot/core/knowledge_base/kb_helper.py +30 -17
- astrbot/core/knowledge_base/kb_mgr.py +6 -7
- astrbot/core/knowledge_base/models.py +10 -4
- astrbot/core/knowledge_base/parsers/__init__.py +3 -5
- astrbot/core/knowledge_base/parsers/base.py +1 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
- astrbot/core/knowledge_base/parsers/util.py +1 -1
- astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
- astrbot/core/knowledge_base/retrieval/manager.py +17 -14
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
- astrbot/core/log.py +21 -13
- astrbot/core/message/components.py +123 -217
- astrbot/core/message/message_event_result.py +24 -24
- astrbot/core/persona_mgr.py +20 -11
- astrbot/core/pipeline/__init__.py +7 -7
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +4 -1
- astrbot/core/pipeline/context_utils.py +77 -7
- astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
- astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
- astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
- astrbot/core/pipeline/process_stage/stage.py +13 -10
- astrbot/core/pipeline/process_stage/utils.py +6 -5
- astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
- astrbot/core/pipeline/respond/stage.py +23 -20
- astrbot/core/pipeline/result_decorate/stage.py +31 -23
- astrbot/core/pipeline/scheduler.py +12 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -8
- astrbot/core/pipeline/stage.py +10 -4
- astrbot/core/pipeline/waking_check/stage.py +24 -18
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +76 -110
- astrbot/core/platform/astrbot_message.py +11 -13
- astrbot/core/platform/manager.py +16 -15
- astrbot/core/platform/message_session.py +5 -3
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +4 -4
- astrbot/core/platform/register.py +8 -8
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
- astrbot/core/platform/sources/discord/client.py +9 -6
- astrbot/core/platform/sources/discord/components.py +18 -14
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
- astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
- astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
- astrbot/core/platform/sources/lark/lark_event.py +21 -14
- astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
- astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
- astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
- astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
- astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
- astrbot/core/platform/sources/satori/satori_event.py +34 -25
- astrbot/core/platform/sources/slack/client.py +11 -9
- astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
- astrbot/core/platform/sources/slack/slack_event.py +34 -24
- astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
- astrbot/core/platform/sources/telegram/tg_event.py +32 -18
- astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
- astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
- astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
- astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
- astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
- astrbot/core/platform_message_history_mgr.py +5 -3
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +61 -75
- astrbot/core/provider/func_tool_manager.py +59 -55
- astrbot/core/provider/manager.py +32 -22
- astrbot/core/provider/provider.py +72 -46
- astrbot/core/provider/register.py +7 -7
- astrbot/core/provider/sources/anthropic_source.py +48 -30
- astrbot/core/provider/sources/azure_tts_source.py +17 -13
- astrbot/core/provider/sources/coze_api_client.py +27 -17
- astrbot/core/provider/sources/coze_source.py +104 -87
- astrbot/core/provider/sources/dashscope_source.py +18 -11
- astrbot/core/provider/sources/dashscope_tts.py +36 -23
- astrbot/core/provider/sources/dify_source.py +25 -20
- astrbot/core/provider/sources/edge_tts_source.py +21 -17
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
- astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
- astrbot/core/provider/sources/gemini_source.py +72 -58
- astrbot/core/provider/sources/gemini_tts_source.py +8 -6
- astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
- astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
- astrbot/core/provider/sources/openai_embedding_source.py +6 -8
- astrbot/core/provider/sources/openai_source.py +77 -69
- astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
- astrbot/core/provider/sources/volcengine_tts.py +38 -31
- astrbot/core/provider/sources/whisper_api_source.py +14 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
- astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
- astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
- astrbot/core/star/__init__.py +16 -11
- astrbot/core/star/config.py +10 -15
- astrbot/core/star/context.py +97 -75
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +30 -28
- astrbot/core/star/filter/command_group.py +27 -24
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +4 -2
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -19
- astrbot/core/star/register/star.py +6 -2
- astrbot/core/star/register/star_handler.py +96 -73
- astrbot/core/star/session_llm_manager.py +48 -14
- astrbot/core/star/session_plugin_manager.py +29 -15
- astrbot/core/star/star.py +1 -2
- astrbot/core/star/star_handler.py +13 -8
- astrbot/core/star/star_manager.py +151 -59
- astrbot/core/star/star_tools.py +44 -37
- astrbot/core/star/updator.py +10 -10
- astrbot/core/umop_config_router.py +10 -4
- astrbot/core/updator.py +13 -5
- astrbot/core/utils/astrbot_path.py +3 -5
- astrbot/core/utils/dify_api_client.py +33 -15
- astrbot/core/utils/io.py +66 -42
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +7 -7
- astrbot/core/utils/path_util.py +15 -16
- astrbot/core/utils/pip_installer.py +5 -5
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +45 -20
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/network_strategy.py +35 -26
- astrbot/core/utils/t2i/renderer.py +11 -5
- astrbot/core/utils/t2i/template_manager.py +14 -15
- astrbot/core/utils/tencent_record_helper.py +19 -13
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +43 -40
- astrbot/dashboard/routes/__init__.py +18 -18
- astrbot/dashboard/routes/auth.py +10 -8
- astrbot/dashboard/routes/chat.py +30 -21
- astrbot/dashboard/routes/config.py +92 -75
- astrbot/dashboard/routes/conversation.py +46 -39
- astrbot/dashboard/routes/file.py +4 -2
- astrbot/dashboard/routes/knowledge_base.py +47 -40
- astrbot/dashboard/routes/log.py +9 -4
- astrbot/dashboard/routes/persona.py +19 -16
- astrbot/dashboard/routes/plugin.py +69 -55
- astrbot/dashboard/routes/route.py +3 -1
- astrbot/dashboard/routes/session_management.py +130 -116
- astrbot/dashboard/routes/stat.py +34 -34
- astrbot/dashboard/routes/t2i.py +15 -12
- astrbot/dashboard/routes/tools.py +47 -52
- astrbot/dashboard/routes/update.py +32 -28
- astrbot/dashboard/server.py +30 -26
- astrbot/dashboard/utils.py +8 -4
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/METADATA +2 -1
- astrbot-4.5.2.dist-info/RECORD +261 -0
- astrbot-4.5.1.dist-info/RECORD +0 -260
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
import abc
|
|
2
2
|
import asyncio
|
|
3
|
-
from
|
|
4
|
-
from
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
from astrbot.core.agent.message import Message
|
|
5
7
|
from astrbot.core.agent.tool import ToolSet
|
|
8
|
+
from astrbot.core.db.po import Personality
|
|
6
9
|
from astrbot.core.provider.entities import (
|
|
7
10
|
LLMResponse,
|
|
8
|
-
ToolCallsResult,
|
|
9
11
|
ProviderType,
|
|
10
12
|
RerankResult,
|
|
13
|
+
ToolCallsResult,
|
|
11
14
|
)
|
|
12
15
|
from astrbot.core.provider.register import provider_cls_map
|
|
13
|
-
from astrbot.core.db.po import Personality
|
|
14
|
-
from dataclasses import dataclass
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
@dataclass
|
|
@@ -23,24 +24,28 @@ class ProviderMeta:
|
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class AbstractProvider(abc.ABC):
|
|
27
|
+
"""Provider Abstract Class"""
|
|
28
|
+
|
|
26
29
|
def __init__(self, provider_config: dict) -> None:
|
|
27
30
|
super().__init__()
|
|
28
31
|
self.model_name = ""
|
|
29
32
|
self.provider_config = provider_config
|
|
30
33
|
|
|
31
34
|
def set_model(self, model_name: str):
|
|
32
|
-
"""
|
|
35
|
+
"""Set the current model name"""
|
|
33
36
|
self.model_name = model_name
|
|
34
37
|
|
|
35
38
|
def get_model(self) -> str:
|
|
36
|
-
"""
|
|
39
|
+
"""Get the current model name"""
|
|
37
40
|
return self.model_name
|
|
38
41
|
|
|
39
42
|
def meta(self) -> ProviderMeta:
|
|
40
|
-
"""
|
|
43
|
+
"""Get the provider metadata"""
|
|
41
44
|
provider_type_name = self.provider_config["type"]
|
|
42
45
|
meta_data = provider_cls_map.get(provider_type_name)
|
|
43
46
|
provider_type = meta_data.provider_type if meta_data else None
|
|
47
|
+
if provider_type is None:
|
|
48
|
+
raise ValueError(f"Cannot find provider type: {provider_type_name}")
|
|
44
49
|
return ProviderMeta(
|
|
45
50
|
id=self.provider_config["id"],
|
|
46
51
|
model=self.get_model(),
|
|
@@ -50,6 +55,8 @@ class AbstractProvider(abc.ABC):
|
|
|
50
55
|
|
|
51
56
|
|
|
52
57
|
class Provider(AbstractProvider):
|
|
58
|
+
"""Chat Provider"""
|
|
59
|
+
|
|
53
60
|
def __init__(
|
|
54
61
|
self,
|
|
55
62
|
provider_config: dict,
|
|
@@ -65,99 +72,114 @@ class Provider(AbstractProvider):
|
|
|
65
72
|
|
|
66
73
|
@abc.abstractmethod
|
|
67
74
|
def get_current_key(self) -> str:
|
|
68
|
-
raise NotImplementedError
|
|
75
|
+
raise NotImplementedError
|
|
69
76
|
|
|
70
|
-
def get_keys(self) ->
|
|
77
|
+
def get_keys(self) -> list[str]:
|
|
71
78
|
"""获得提供商 Key"""
|
|
72
79
|
keys = self.provider_config.get("key", [""])
|
|
73
80
|
return keys or [""]
|
|
74
81
|
|
|
75
82
|
@abc.abstractmethod
|
|
76
83
|
def set_key(self, key: str):
|
|
77
|
-
raise NotImplementedError
|
|
84
|
+
raise NotImplementedError
|
|
78
85
|
|
|
79
86
|
@abc.abstractmethod
|
|
80
|
-
async def get_models(self) ->
|
|
87
|
+
async def get_models(self) -> list[str]:
|
|
81
88
|
"""获得支持的模型列表"""
|
|
82
|
-
raise NotImplementedError
|
|
89
|
+
raise NotImplementedError
|
|
83
90
|
|
|
84
91
|
@abc.abstractmethod
|
|
85
92
|
async def text_chat(
|
|
86
93
|
self,
|
|
87
|
-
prompt: str,
|
|
88
|
-
session_id: str = None,
|
|
89
|
-
image_urls: list[str] = None,
|
|
90
|
-
func_tool: ToolSet = None,
|
|
91
|
-
contexts: list = None,
|
|
92
|
-
system_prompt: str = None,
|
|
93
|
-
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
|
94
|
+
prompt: str | None = None,
|
|
95
|
+
session_id: str | None = None,
|
|
96
|
+
image_urls: list[str] | None = None,
|
|
97
|
+
func_tool: ToolSet | None = None,
|
|
98
|
+
contexts: list[Message] | list[dict] | None = None,
|
|
99
|
+
system_prompt: str | None = None,
|
|
100
|
+
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
|
94
101
|
model: str | None = None,
|
|
95
102
|
**kwargs,
|
|
96
103
|
) -> LLMResponse:
|
|
97
104
|
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
|
98
105
|
|
|
99
106
|
Args:
|
|
100
|
-
prompt:
|
|
107
|
+
prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
|
|
101
108
|
session_id: 会话 ID(此属性已经被废弃)
|
|
102
109
|
image_urls: 图片 URL 列表
|
|
103
|
-
tools:
|
|
104
|
-
contexts:
|
|
110
|
+
tools: tool set
|
|
111
|
+
contexts: 上下文,和 prompt 二选一使用
|
|
105
112
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
|
106
113
|
kwargs: 其他参数
|
|
107
114
|
|
|
108
115
|
Notes:
|
|
109
116
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
|
110
117
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
|
118
|
+
|
|
111
119
|
"""
|
|
112
120
|
...
|
|
113
121
|
|
|
114
122
|
async def text_chat_stream(
|
|
115
123
|
self,
|
|
116
|
-
prompt: str,
|
|
117
|
-
session_id: str = None,
|
|
118
|
-
image_urls: list[str] = None,
|
|
119
|
-
func_tool: ToolSet = None,
|
|
120
|
-
contexts: list = None,
|
|
121
|
-
system_prompt: str = None,
|
|
122
|
-
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
|
124
|
+
prompt: str | None = None,
|
|
125
|
+
session_id: str | None = None,
|
|
126
|
+
image_urls: list[str] | None = None,
|
|
127
|
+
func_tool: ToolSet | None = None,
|
|
128
|
+
contexts: list[Message] | list[dict] | None = None,
|
|
129
|
+
system_prompt: str | None = None,
|
|
130
|
+
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
|
123
131
|
model: str | None = None,
|
|
124
132
|
**kwargs,
|
|
125
133
|
) -> AsyncGenerator[LLMResponse, None]:
|
|
126
134
|
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
|
127
135
|
|
|
128
136
|
Args:
|
|
129
|
-
prompt:
|
|
137
|
+
prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
|
|
130
138
|
session_id: 会话 ID(此属性已经被废弃)
|
|
131
139
|
image_urls: 图片 URL 列表
|
|
132
|
-
tools:
|
|
133
|
-
contexts:
|
|
140
|
+
tools: tool set
|
|
141
|
+
contexts: 上下文,和 prompt 二选一使用
|
|
134
142
|
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
|
135
143
|
kwargs: 其他参数
|
|
136
144
|
|
|
137
145
|
Notes:
|
|
138
146
|
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
|
139
147
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
|
148
|
+
|
|
140
149
|
"""
|
|
141
150
|
...
|
|
142
151
|
|
|
143
|
-
async def pop_record(self, context:
|
|
144
|
-
"""
|
|
145
|
-
弹出 context 第一条非系统提示词对话记录
|
|
146
|
-
"""
|
|
152
|
+
async def pop_record(self, context: list):
|
|
153
|
+
"""弹出 context 第一条非系统提示词对话记录"""
|
|
147
154
|
poped = 0
|
|
148
155
|
indexs_to_pop = []
|
|
149
156
|
for idx, record in enumerate(context):
|
|
150
157
|
if record["role"] == "system":
|
|
151
158
|
continue
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
break
|
|
159
|
+
indexs_to_pop.append(idx)
|
|
160
|
+
poped += 1
|
|
161
|
+
if poped == 2:
|
|
162
|
+
break
|
|
157
163
|
|
|
158
164
|
for idx in reversed(indexs_to_pop):
|
|
159
165
|
context.pop(idx)
|
|
160
166
|
|
|
167
|
+
def _ensure_message_to_dicts(
|
|
168
|
+
self,
|
|
169
|
+
messages: list[dict] | list[Message] | None,
|
|
170
|
+
) -> list[dict]:
|
|
171
|
+
"""Convert a list of Message objects to a list of dictionaries."""
|
|
172
|
+
if not messages:
|
|
173
|
+
return []
|
|
174
|
+
dicts: list[dict] = []
|
|
175
|
+
for message in messages:
|
|
176
|
+
if isinstance(message, Message):
|
|
177
|
+
dicts.append(message.model_dump())
|
|
178
|
+
else:
|
|
179
|
+
dicts.append(message)
|
|
180
|
+
|
|
181
|
+
return dicts
|
|
182
|
+
|
|
161
183
|
|
|
162
184
|
class STTProvider(AbstractProvider):
|
|
163
185
|
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
|
@@ -168,7 +190,7 @@ class STTProvider(AbstractProvider):
|
|
|
168
190
|
@abc.abstractmethod
|
|
169
191
|
async def get_text(self, audio_url: str) -> str:
|
|
170
192
|
"""获取音频的文本"""
|
|
171
|
-
raise NotImplementedError
|
|
193
|
+
raise NotImplementedError
|
|
172
194
|
|
|
173
195
|
|
|
174
196
|
class TTSProvider(AbstractProvider):
|
|
@@ -180,7 +202,7 @@ class TTSProvider(AbstractProvider):
|
|
|
180
202
|
@abc.abstractmethod
|
|
181
203
|
async def get_audio(self, text: str) -> str:
|
|
182
204
|
"""获取文本的音频,返回音频文件路径"""
|
|
183
|
-
raise NotImplementedError
|
|
205
|
+
raise NotImplementedError
|
|
184
206
|
|
|
185
207
|
|
|
186
208
|
class EmbeddingProvider(AbstractProvider):
|
|
@@ -223,6 +245,7 @@ class EmbeddingProvider(AbstractProvider):
|
|
|
223
245
|
|
|
224
246
|
Returns:
|
|
225
247
|
向量列表
|
|
248
|
+
|
|
226
249
|
"""
|
|
227
250
|
semaphore = asyncio.Semaphore(tasks_limit)
|
|
228
251
|
all_embeddings: list[list[float]] = []
|
|
@@ -246,7 +269,7 @@ class EmbeddingProvider(AbstractProvider):
|
|
|
246
269
|
# 最后一次重试失败,记录失败的批次
|
|
247
270
|
failed_batches.append((batch_idx, batch_texts))
|
|
248
271
|
raise Exception(
|
|
249
|
-
f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {
|
|
272
|
+
f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}",
|
|
250
273
|
)
|
|
251
274
|
# 等待一段时间后重试,使用指数退避
|
|
252
275
|
await asyncio.sleep(2**attempt)
|
|
@@ -279,7 +302,10 @@ class RerankProvider(AbstractProvider):
|
|
|
279
302
|
|
|
280
303
|
@abc.abstractmethod
|
|
281
304
|
async def rerank(
|
|
282
|
-
self,
|
|
305
|
+
self,
|
|
306
|
+
query: str,
|
|
307
|
+
documents: list[str],
|
|
308
|
+
top_n: int | None = None,
|
|
283
309
|
) -> list[RerankResult]:
|
|
284
310
|
"""获取查询和文档的重排序分数"""
|
|
285
311
|
...
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from typing import List, Dict
|
|
2
|
-
from .entities import ProviderMetaData, ProviderType
|
|
3
1
|
from astrbot.core import logger
|
|
2
|
+
|
|
3
|
+
from .entities import ProviderMetaData, ProviderType
|
|
4
4
|
from .func_tool_manager import FuncCall
|
|
5
5
|
|
|
6
|
-
provider_registry:
|
|
6
|
+
provider_registry: list[ProviderMetaData] = []
|
|
7
7
|
"""维护了通过装饰器注册的 Provider"""
|
|
8
|
-
provider_cls_map:
|
|
8
|
+
provider_cls_map: dict[str, ProviderMetaData] = {}
|
|
9
9
|
"""维护了 Provider 类型名称和 ProviderMetadata 的映射"""
|
|
10
10
|
|
|
11
11
|
llm_tools = FuncCall()
|
|
@@ -15,15 +15,15 @@ def register_provider_adapter(
|
|
|
15
15
|
provider_type_name: str,
|
|
16
16
|
desc: str,
|
|
17
17
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
|
|
18
|
-
default_config_tmpl: dict = None,
|
|
19
|
-
provider_display_name: str = None,
|
|
18
|
+
default_config_tmpl: dict | None = None,
|
|
19
|
+
provider_display_name: str | None = None,
|
|
20
20
|
):
|
|
21
21
|
"""用于注册平台适配器的带参装饰器"""
|
|
22
22
|
|
|
23
23
|
def decorator(cls):
|
|
24
24
|
if provider_type_name in provider_cls_map:
|
|
25
25
|
raise ValueError(
|
|
26
|
-
f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。"
|
|
26
|
+
f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。",
|
|
27
27
|
)
|
|
28
28
|
|
|
29
29
|
# 添加必备选项
|
|
@@ -1,23 +1,24 @@
|
|
|
1
|
-
import json
|
|
2
|
-
import anthropic
|
|
3
1
|
import base64
|
|
4
|
-
|
|
2
|
+
import json
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
5
4
|
from mimetypes import guess_type
|
|
6
5
|
|
|
6
|
+
import anthropic
|
|
7
7
|
from anthropic import AsyncAnthropic
|
|
8
8
|
from anthropic.types import Message
|
|
9
9
|
|
|
10
|
-
from astrbot.core.utils.io import download_image_by_url
|
|
11
|
-
from astrbot.api.provider import Provider
|
|
12
10
|
from astrbot import logger
|
|
11
|
+
from astrbot.api.provider import Provider
|
|
12
|
+
from astrbot.core.provider.entities import LLMResponse
|
|
13
13
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
|
14
|
+
from astrbot.core.utils.io import download_image_by_url
|
|
15
|
+
|
|
14
16
|
from ..register import register_provider_adapter
|
|
15
|
-
from astrbot.core.provider.entities import LLMResponse
|
|
16
|
-
from typing import AsyncGenerator
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
@register_provider_adapter(
|
|
20
|
-
"anthropic_chat_completion",
|
|
20
|
+
"anthropic_chat_completion",
|
|
21
|
+
"Anthropic Claude API 提供商适配器",
|
|
21
22
|
)
|
|
22
23
|
class ProviderAnthropic(Provider):
|
|
23
24
|
def __init__(
|
|
@@ -33,7 +34,7 @@ class ProviderAnthropic(Provider):
|
|
|
33
34
|
)
|
|
34
35
|
|
|
35
36
|
self.chosen_api_key: str = ""
|
|
36
|
-
self.api_keys:
|
|
37
|
+
self.api_keys: list = super().get_keys()
|
|
37
38
|
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
|
38
39
|
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
|
39
40
|
self.timeout = provider_config.get("timeout", 120)
|
|
@@ -41,7 +42,9 @@ class ProviderAnthropic(Provider):
|
|
|
41
42
|
self.timeout = int(self.timeout)
|
|
42
43
|
|
|
43
44
|
self.client = AsyncAnthropic(
|
|
44
|
-
api_key=self.chosen_api_key,
|
|
45
|
+
api_key=self.chosen_api_key,
|
|
46
|
+
timeout=self.timeout,
|
|
47
|
+
base_url=self.base_url,
|
|
45
48
|
)
|
|
46
49
|
|
|
47
50
|
self.set_model(provider_config["model_config"]["model"])
|
|
@@ -54,6 +57,7 @@ class ProviderAnthropic(Provider):
|
|
|
54
57
|
Returns:
|
|
55
58
|
system_prompt: 系统提示内容
|
|
56
59
|
new_messages: 处理后的消息列表,去除系统提示
|
|
60
|
+
|
|
57
61
|
"""
|
|
58
62
|
system_prompt = ""
|
|
59
63
|
new_messages = []
|
|
@@ -73,18 +77,19 @@ class ProviderAnthropic(Provider):
|
|
|
73
77
|
"input": (
|
|
74
78
|
json.loads(tool_call["function"]["arguments"])
|
|
75
79
|
if isinstance(
|
|
76
|
-
tool_call["function"]["arguments"],
|
|
80
|
+
tool_call["function"]["arguments"],
|
|
81
|
+
str,
|
|
77
82
|
)
|
|
78
83
|
else tool_call["function"]["arguments"]
|
|
79
84
|
),
|
|
80
85
|
"id": tool_call["id"],
|
|
81
|
-
}
|
|
86
|
+
},
|
|
82
87
|
)
|
|
83
88
|
new_messages.append(
|
|
84
89
|
{
|
|
85
90
|
"role": "assistant",
|
|
86
91
|
"content": blocks,
|
|
87
|
-
}
|
|
92
|
+
},
|
|
88
93
|
)
|
|
89
94
|
elif message["role"] == "tool":
|
|
90
95
|
new_messages.append(
|
|
@@ -95,9 +100,9 @@ class ProviderAnthropic(Provider):
|
|
|
95
100
|
"type": "tool_result",
|
|
96
101
|
"tool_use_id": message["tool_call_id"],
|
|
97
102
|
"content": message["content"],
|
|
98
|
-
}
|
|
103
|
+
},
|
|
99
104
|
],
|
|
100
|
-
}
|
|
105
|
+
},
|
|
101
106
|
)
|
|
102
107
|
else:
|
|
103
108
|
new_messages.append(message)
|
|
@@ -135,7 +140,9 @@ class ProviderAnthropic(Provider):
|
|
|
135
140
|
return llm_response
|
|
136
141
|
|
|
137
142
|
async def _query_stream(
|
|
138
|
-
self,
|
|
143
|
+
self,
|
|
144
|
+
payloads: dict,
|
|
145
|
+
tools: ToolSet | None,
|
|
139
146
|
) -> AsyncGenerator[LLMResponse, None]:
|
|
140
147
|
if tools:
|
|
141
148
|
if tool_list := tools.get_func_desc_anthropic_style():
|
|
@@ -154,7 +161,9 @@ class ProviderAnthropic(Provider):
|
|
|
154
161
|
if event.content_block.type == "text":
|
|
155
162
|
# 文本块开始
|
|
156
163
|
yield LLMResponse(
|
|
157
|
-
role="assistant",
|
|
164
|
+
role="assistant",
|
|
165
|
+
completion_text="",
|
|
166
|
+
is_chunk=True,
|
|
158
167
|
)
|
|
159
168
|
elif event.content_block.type == "tool_use":
|
|
160
169
|
# 工具使用块开始,初始化缓冲区
|
|
@@ -198,7 +207,7 @@ class ProviderAnthropic(Provider):
|
|
|
198
207
|
"id": tool_info["id"],
|
|
199
208
|
"name": tool_info["name"],
|
|
200
209
|
"input": tool_info["input"],
|
|
201
|
-
}
|
|
210
|
+
},
|
|
202
211
|
)
|
|
203
212
|
|
|
204
213
|
yield LLMResponse(
|
|
@@ -218,7 +227,9 @@ class ProviderAnthropic(Provider):
|
|
|
218
227
|
|
|
219
228
|
# 返回最终的完整结果
|
|
220
229
|
final_response = LLMResponse(
|
|
221
|
-
role="assistant",
|
|
230
|
+
role="assistant",
|
|
231
|
+
completion_text=final_text,
|
|
232
|
+
is_chunk=False,
|
|
222
233
|
)
|
|
223
234
|
|
|
224
235
|
if final_tool_calls:
|
|
@@ -232,7 +243,7 @@ class ProviderAnthropic(Provider):
|
|
|
232
243
|
|
|
233
244
|
async def text_chat(
|
|
234
245
|
self,
|
|
235
|
-
prompt,
|
|
246
|
+
prompt=None,
|
|
236
247
|
session_id=None,
|
|
237
248
|
image_urls=None,
|
|
238
249
|
func_tool=None,
|
|
@@ -244,8 +255,13 @@ class ProviderAnthropic(Provider):
|
|
|
244
255
|
) -> LLMResponse:
|
|
245
256
|
if contexts is None:
|
|
246
257
|
contexts = []
|
|
247
|
-
new_record =
|
|
248
|
-
|
|
258
|
+
new_record = None
|
|
259
|
+
if prompt is not None:
|
|
260
|
+
new_record = await self.assemble_context(prompt, image_urls)
|
|
261
|
+
context_query = self._ensure_message_to_dicts(contexts)
|
|
262
|
+
if new_record:
|
|
263
|
+
context_query.append(new_record)
|
|
264
|
+
|
|
249
265
|
if system_prompt:
|
|
250
266
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
|
251
267
|
|
|
@@ -295,8 +311,12 @@ class ProviderAnthropic(Provider):
|
|
|
295
311
|
):
|
|
296
312
|
if contexts is None:
|
|
297
313
|
contexts = []
|
|
298
|
-
new_record =
|
|
299
|
-
|
|
314
|
+
new_record = None
|
|
315
|
+
if prompt is not None:
|
|
316
|
+
new_record = await self.assemble_context(prompt, image_urls)
|
|
317
|
+
context_query = self._ensure_message_to_dicts(contexts)
|
|
318
|
+
if new_record:
|
|
319
|
+
context_query.append(new_record)
|
|
300
320
|
if system_prompt:
|
|
301
321
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
|
302
322
|
|
|
@@ -326,7 +346,7 @@ class ProviderAnthropic(Provider):
|
|
|
326
346
|
async for llm_response in self._query_stream(payloads, func_tool):
|
|
327
347
|
yield llm_response
|
|
328
348
|
|
|
329
|
-
async def assemble_context(self, text: str, image_urls:
|
|
349
|
+
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
|
330
350
|
"""组装上下文,支持文本和图片"""
|
|
331
351
|
if not image_urls:
|
|
332
352
|
return {"role": "user", "content": text}
|
|
@@ -365,15 +385,13 @@ class ProviderAnthropic(Provider):
|
|
|
365
385
|
else image_data
|
|
366
386
|
),
|
|
367
387
|
},
|
|
368
|
-
}
|
|
388
|
+
},
|
|
369
389
|
)
|
|
370
390
|
|
|
371
391
|
return {"role": "user", "content": content}
|
|
372
392
|
|
|
373
393
|
async def encode_image_bs64(self, image_url: str) -> str:
|
|
374
|
-
"""
|
|
375
|
-
将图片转换为 base64
|
|
376
|
-
"""
|
|
394
|
+
"""将图片转换为 base64"""
|
|
377
395
|
if image_url.startswith("base64://"):
|
|
378
396
|
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
|
379
397
|
with open(image_url, "rb") as f:
|
|
@@ -384,7 +402,7 @@ class ProviderAnthropic(Provider):
|
|
|
384
402
|
def get_current_key(self) -> str:
|
|
385
403
|
return self.chosen_api_key
|
|
386
404
|
|
|
387
|
-
async def get_models(self) ->
|
|
405
|
+
async def get_models(self) -> list[str]:
|
|
388
406
|
models_str = []
|
|
389
407
|
models = await self.client.models.list()
|
|
390
408
|
models = sorted(models.data, key=lambda x: x.id)
|
|
@@ -1,15 +1,15 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
1
|
+
import asyncio
|
|
2
|
+
import hashlib
|
|
3
3
|
import json
|
|
4
4
|
import re
|
|
5
|
-
import
|
|
6
|
-
import
|
|
7
|
-
import
|
|
5
|
+
import secrets
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
8
|
from pathlib import Path
|
|
9
|
-
from typing import Dict
|
|
10
9
|
from xml.sax.saxutils import escape
|
|
11
10
|
|
|
12
11
|
from httpx import AsyncClient, Timeout
|
|
12
|
+
|
|
13
13
|
from astrbot.core.config.default import VERSION
|
|
14
14
|
|
|
15
15
|
from ..entities import ProviderType
|
|
@@ -21,7 +21,7 @@ TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class OTTSProvider:
|
|
24
|
-
def __init__(self, config:
|
|
24
|
+
def __init__(self, config: dict):
|
|
25
25
|
self.skey = config["OTTS_SKEY"]
|
|
26
26
|
self.api_url = config["OTTS_URL"]
|
|
27
27
|
self.auth_time_url = config["OTTS_AUTH_TIME"]
|
|
@@ -54,11 +54,13 @@ class OTTSProvider:
|
|
|
54
54
|
async def _generate_signature(self) -> str:
|
|
55
55
|
await self._sync_time()
|
|
56
56
|
timestamp = int(time.time()) + self.time_offset
|
|
57
|
-
nonce = "".join(
|
|
57
|
+
nonce = "".join(
|
|
58
|
+
secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
|
|
59
|
+
)
|
|
58
60
|
path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
|
|
59
61
|
return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
|
|
60
62
|
|
|
61
|
-
async def get_audio(self, text: str, voice_params:
|
|
63
|
+
async def get_audio(self, text: str, voice_params: dict) -> str:
|
|
62
64
|
file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav"
|
|
63
65
|
signature = await self._generate_signature()
|
|
64
66
|
for attempt in range(self.retry_count):
|
|
@@ -86,7 +88,7 @@ class OTTSProvider:
|
|
|
86
88
|
return str(file_path.resolve())
|
|
87
89
|
except Exception as e:
|
|
88
90
|
if attempt == self.retry_count - 1:
|
|
89
|
-
raise RuntimeError(f"OTTS请求失败: {
|
|
91
|
+
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
|
90
92
|
await asyncio.sleep(0.5 * (attempt + 1))
|
|
91
93
|
|
|
92
94
|
|
|
@@ -94,7 +96,8 @@ class AzureNativeProvider(TTSProvider):
|
|
|
94
96
|
def __init__(self, provider_config: dict, provider_settings: dict):
|
|
95
97
|
super().__init__(provider_config, provider_settings)
|
|
96
98
|
self.subscription_key = provider_config.get(
|
|
97
|
-
"azure_tts_subscription_key",
|
|
99
|
+
"azure_tts_subscription_key",
|
|
100
|
+
"",
|
|
98
101
|
).strip()
|
|
99
102
|
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
|
|
100
103
|
raise ValueError("无效的Azure订阅密钥")
|
|
@@ -119,7 +122,7 @@ class AzureNativeProvider(TTSProvider):
|
|
|
119
122
|
"User-Agent": f"AstrBot/{VERSION}",
|
|
120
123
|
"Content-Type": "application/ssml+xml",
|
|
121
124
|
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
|
|
122
|
-
}
|
|
125
|
+
},
|
|
123
126
|
)
|
|
124
127
|
return self
|
|
125
128
|
|
|
@@ -132,7 +135,8 @@ class AzureNativeProvider(TTSProvider):
|
|
|
132
135
|
f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
|
|
133
136
|
)
|
|
134
137
|
response = await self.client.post(
|
|
135
|
-
token_url,
|
|
138
|
+
token_url,
|
|
139
|
+
headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
|
|
136
140
|
)
|
|
137
141
|
response.raise_for_status()
|
|
138
142
|
self.token = response.text
|