AstrBot 4.6.1__py3-none-any.whl → 4.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +3 -3
- astrbot/core/agent/runners/base.py +7 -4
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/{utils → agent/runners/dify}/dify_api_client.py +51 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -6
- astrbot/core/config/default.py +141 -26
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/core_lifecycle.py +11 -13
- astrbot/core/db/po.py +1 -1
- astrbot/core/db/sqlite.py +2 -2
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/{llm_request.py → agent_sub_stages/internal.py} +13 -34
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +1 -1
- astrbot/core/pipeline/process_stage/stage.py +8 -5
- astrbot/core/pipeline/result_decorate/stage.py +15 -5
- astrbot/core/provider/manager.py +43 -41
- astrbot/core/star/session_llm_manager.py +0 -107
- astrbot/core/star/session_plugin_manager.py +0 -81
- astrbot/core/umop_config_router.py +19 -0
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/shared_preferences.py +1 -28
- astrbot/dashboard/routes/chat.py +13 -1
- astrbot/dashboard/routes/config.py +20 -16
- astrbot/dashboard/routes/knowledge_base.py +0 -156
- astrbot/dashboard/routes/session_management.py +311 -606
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/METADATA +1 -1
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/RECORD +34 -30
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/WHEEL +1 -1
- astrbot/core/provider/sources/coze_source.py +0 -650
- astrbot/core/provider/sources/dashscope_source.py +0 -207
- astrbot/core/provider/sources/dify_source.py +0 -285
- /astrbot/core/{provider/sources → agent/runners/coze}/coze_api_client.py +0 -0
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/entry_points.txt +0 -0
- {astrbot-4.6.1.dist-info → astrbot-4.7.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,27 +21,24 @@ from astrbot.core.provider.entities import (
|
|
|
21
21
|
LLMResponse,
|
|
22
22
|
ProviderRequest,
|
|
23
23
|
)
|
|
24
|
-
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
|
25
24
|
from astrbot.core.star.star_handler import EventType, star_map
|
|
26
25
|
from astrbot.core.utils.metrics import Metric
|
|
27
26
|
from astrbot.core.utils.session_lock import session_lock_manager
|
|
28
27
|
|
|
29
|
-
from
|
|
30
|
-
from
|
|
31
|
-
from
|
|
32
|
-
from
|
|
33
|
-
from
|
|
34
|
-
from
|
|
35
|
-
from
|
|
28
|
+
from .....astr_agent_context import AgentContextWrapper
|
|
29
|
+
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
|
30
|
+
from .....astr_agent_run_util import AgentRunner, run_agent
|
|
31
|
+
from .....astr_agent_tool_exec import FunctionToolExecutor
|
|
32
|
+
from ....context import PipelineContext, call_event_hook
|
|
33
|
+
from ...stage import Stage
|
|
34
|
+
from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
|
|
36
35
|
|
|
37
36
|
|
|
38
|
-
class
|
|
37
|
+
class InternalAgentSubStage(Stage):
|
|
39
38
|
async def initialize(self, ctx: PipelineContext) -> None:
|
|
40
39
|
self.ctx = ctx
|
|
41
40
|
conf = ctx.astrbot_config
|
|
42
41
|
settings = conf["provider_settings"]
|
|
43
|
-
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
|
44
|
-
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
|
45
42
|
self.max_context_length = settings["max_context_length"] # int
|
|
46
43
|
self.dequeue_context_length: int = min(
|
|
47
44
|
max(1, settings["dequeue_context_length"]),
|
|
@@ -59,13 +56,6 @@ class LLMRequestSubStage(Stage):
|
|
|
59
56
|
self.show_reasoning = settings.get("display_reasoning_text", False)
|
|
60
57
|
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
|
|
61
58
|
|
|
62
|
-
for bwp in self.bot_wake_prefixs:
|
|
63
|
-
if self.provider_wake_prefix.startswith(bwp):
|
|
64
|
-
logger.info(
|
|
65
|
-
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
|
66
|
-
)
|
|
67
|
-
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
|
|
68
|
-
|
|
69
59
|
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
|
70
60
|
|
|
71
61
|
def _select_provider(self, event: AstrMessageEvent):
|
|
@@ -304,21 +294,10 @@ class LLMRequestSubStage(Stage):
|
|
|
304
294
|
return fixed_messages
|
|
305
295
|
|
|
306
296
|
async def process(
|
|
307
|
-
self,
|
|
308
|
-
|
|
309
|
-
_nested: bool = False,
|
|
310
|
-
) -> None | AsyncGenerator[None, None]:
|
|
297
|
+
self, event: AstrMessageEvent, provider_wake_prefix: str
|
|
298
|
+
) -> AsyncGenerator[None, None]:
|
|
311
299
|
req: ProviderRequest | None = None
|
|
312
300
|
|
|
313
|
-
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
|
314
|
-
logger.debug("未启用 LLM 能力,跳过处理。")
|
|
315
|
-
return
|
|
316
|
-
|
|
317
|
-
# 检查会话级别的LLM启停状态
|
|
318
|
-
if not SessionServiceManager.should_process_llm_request(event):
|
|
319
|
-
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
|
320
|
-
return
|
|
321
|
-
|
|
322
301
|
provider = self._select_provider(event)
|
|
323
302
|
if provider is None:
|
|
324
303
|
return
|
|
@@ -348,12 +327,12 @@ class LLMRequestSubStage(Stage):
|
|
|
348
327
|
req.image_urls = []
|
|
349
328
|
if sel_model := event.get_extra("selected_model"):
|
|
350
329
|
req.model = sel_model
|
|
351
|
-
if
|
|
352
|
-
|
|
330
|
+
if provider_wake_prefix and not event.message_str.startswith(
|
|
331
|
+
provider_wake_prefix
|
|
353
332
|
):
|
|
354
333
|
return
|
|
355
334
|
|
|
356
|
-
req.prompt = event.message_str[len(
|
|
335
|
+
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
|
357
336
|
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
|
358
337
|
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
|
359
338
|
for comp in event.message_obj.message:
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from collections.abc import AsyncGenerator
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from astrbot.core import logger
|
|
6
|
+
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
|
|
7
|
+
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
|
|
8
|
+
DashscopeAgentRunner,
|
|
9
|
+
)
|
|
10
|
+
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
|
|
11
|
+
from astrbot.core.message.components import Image
|
|
12
|
+
from astrbot.core.message.message_event_result import (
|
|
13
|
+
MessageChain,
|
|
14
|
+
MessageEventResult,
|
|
15
|
+
ResultContentType,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from astrbot.core.agent.runners.base import BaseAgentRunner
|
|
20
|
+
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
21
|
+
from astrbot.core.provider.entities import (
|
|
22
|
+
ProviderRequest,
|
|
23
|
+
)
|
|
24
|
+
from astrbot.core.star.star_handler import EventType
|
|
25
|
+
from astrbot.core.utils.metrics import Metric
|
|
26
|
+
|
|
27
|
+
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
|
|
28
|
+
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
|
29
|
+
from ....context import PipelineContext, call_event_hook
|
|
30
|
+
from ...stage import Stage
|
|
31
|
+
|
|
32
|
+
AGENT_RUNNER_TYPE_KEY = {
|
|
33
|
+
"dify": "dify_agent_runner_provider_id",
|
|
34
|
+
"coze": "coze_agent_runner_provider_id",
|
|
35
|
+
"dashscope": "dashscope_agent_runner_provider_id",
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def run_third_party_agent(
|
|
40
|
+
runner: "BaseAgentRunner",
|
|
41
|
+
stream_to_general: bool = False,
|
|
42
|
+
) -> AsyncGenerator[MessageChain | None, None]:
|
|
43
|
+
"""
|
|
44
|
+
运行第三方 agent runner 并转换响应格式
|
|
45
|
+
类似于 run_agent 函数,但专门处理第三方 agent runner
|
|
46
|
+
"""
|
|
47
|
+
try:
|
|
48
|
+
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
|
|
49
|
+
if resp.type == "streaming_delta":
|
|
50
|
+
if stream_to_general:
|
|
51
|
+
continue
|
|
52
|
+
yield resp.data["chain"]
|
|
53
|
+
elif resp.type == "llm_result":
|
|
54
|
+
if stream_to_general:
|
|
55
|
+
yield resp.data["chain"]
|
|
56
|
+
except Exception as e:
|
|
57
|
+
logger.error(f"Third party agent runner error: {e}")
|
|
58
|
+
err_msg = (
|
|
59
|
+
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
|
|
60
|
+
f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
|
61
|
+
)
|
|
62
|
+
yield MessageChain().message(err_msg)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ThirdPartyAgentSubStage(Stage):
|
|
66
|
+
async def initialize(self, ctx: PipelineContext) -> None:
|
|
67
|
+
self.ctx = ctx
|
|
68
|
+
self.conf = ctx.astrbot_config
|
|
69
|
+
self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
|
|
70
|
+
self.prov_id = self.conf["provider_settings"].get(
|
|
71
|
+
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
|
|
72
|
+
"",
|
|
73
|
+
)
|
|
74
|
+
settings = ctx.astrbot_config["provider_settings"]
|
|
75
|
+
self.streaming_response: bool = settings["streaming_response"]
|
|
76
|
+
self.unsupported_streaming_strategy: str = settings[
|
|
77
|
+
"unsupported_streaming_strategy"
|
|
78
|
+
]
|
|
79
|
+
|
|
80
|
+
async def process(
|
|
81
|
+
self, event: AstrMessageEvent, provider_wake_prefix: str
|
|
82
|
+
) -> AsyncGenerator[None, None]:
|
|
83
|
+
req: ProviderRequest | None = None
|
|
84
|
+
|
|
85
|
+
if provider_wake_prefix and not event.message_str.startswith(
|
|
86
|
+
provider_wake_prefix
|
|
87
|
+
):
|
|
88
|
+
return
|
|
89
|
+
|
|
90
|
+
self.prov_cfg: dict = next(
|
|
91
|
+
(p for p in self.conf["provider"] if p["id"] == self.prov_id),
|
|
92
|
+
{},
|
|
93
|
+
)
|
|
94
|
+
if not self.prov_id or not self.prov_cfg:
|
|
95
|
+
logger.error(
|
|
96
|
+
"Third Party Agent Runner provider ID is not configured properly."
|
|
97
|
+
)
|
|
98
|
+
return
|
|
99
|
+
|
|
100
|
+
# make provider request
|
|
101
|
+
req = ProviderRequest()
|
|
102
|
+
req.session_id = event.unified_msg_origin
|
|
103
|
+
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
|
104
|
+
for comp in event.message_obj.message:
|
|
105
|
+
if isinstance(comp, Image):
|
|
106
|
+
image_path = await comp.convert_to_base64()
|
|
107
|
+
req.image_urls.append(image_path)
|
|
108
|
+
|
|
109
|
+
if not req.prompt and not req.image_urls:
|
|
110
|
+
return
|
|
111
|
+
|
|
112
|
+
# call event hook
|
|
113
|
+
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
|
114
|
+
return
|
|
115
|
+
|
|
116
|
+
if self.runner_type == "dify":
|
|
117
|
+
runner = DifyAgentRunner[AstrAgentContext]()
|
|
118
|
+
elif self.runner_type == "coze":
|
|
119
|
+
runner = CozeAgentRunner[AstrAgentContext]()
|
|
120
|
+
elif self.runner_type == "dashscope":
|
|
121
|
+
runner = DashscopeAgentRunner[AstrAgentContext]()
|
|
122
|
+
else:
|
|
123
|
+
raise ValueError(
|
|
124
|
+
f"Unsupported third party agent runner type: {self.runner_type}",
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
astr_agent_ctx = AstrAgentContext(
|
|
128
|
+
context=self.ctx.plugin_manager.context,
|
|
129
|
+
event=event,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
streaming_response = self.streaming_response
|
|
133
|
+
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
|
134
|
+
streaming_response = bool(enable_streaming)
|
|
135
|
+
|
|
136
|
+
stream_to_general = (
|
|
137
|
+
self.unsupported_streaming_strategy == "turn_off"
|
|
138
|
+
and not event.platform_meta.support_streaming_message
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
await runner.reset(
|
|
142
|
+
request=req,
|
|
143
|
+
run_context=AgentContextWrapper(
|
|
144
|
+
context=astr_agent_ctx,
|
|
145
|
+
tool_call_timeout=60,
|
|
146
|
+
),
|
|
147
|
+
agent_hooks=MAIN_AGENT_HOOKS,
|
|
148
|
+
provider_config=self.prov_cfg,
|
|
149
|
+
streaming=streaming_response,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
if streaming_response and not stream_to_general:
|
|
153
|
+
# 流式响应
|
|
154
|
+
event.set_result(
|
|
155
|
+
MessageEventResult()
|
|
156
|
+
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
|
157
|
+
.set_async_stream(
|
|
158
|
+
run_third_party_agent(
|
|
159
|
+
runner,
|
|
160
|
+
stream_to_general=False,
|
|
161
|
+
),
|
|
162
|
+
),
|
|
163
|
+
)
|
|
164
|
+
yield
|
|
165
|
+
if runner.done():
|
|
166
|
+
final_resp = runner.get_final_llm_resp()
|
|
167
|
+
if final_resp and final_resp.result_chain:
|
|
168
|
+
event.set_result(
|
|
169
|
+
MessageEventResult(
|
|
170
|
+
chain=final_resp.result_chain.chain or [],
|
|
171
|
+
result_content_type=ResultContentType.STREAMING_FINISH,
|
|
172
|
+
),
|
|
173
|
+
)
|
|
174
|
+
else:
|
|
175
|
+
# 非流式响应或转换为普通响应
|
|
176
|
+
async for _ in run_third_party_agent(
|
|
177
|
+
runner,
|
|
178
|
+
stream_to_general=stream_to_general,
|
|
179
|
+
):
|
|
180
|
+
yield
|
|
181
|
+
|
|
182
|
+
final_resp = runner.get_final_llm_resp()
|
|
183
|
+
|
|
184
|
+
if not final_resp or not final_resp.result_chain:
|
|
185
|
+
logger.warning("Agent Runner 未返回最终结果。")
|
|
186
|
+
return
|
|
187
|
+
|
|
188
|
+
event.set_result(
|
|
189
|
+
MessageEventResult(
|
|
190
|
+
chain=final_resp.result_chain.chain or [],
|
|
191
|
+
result_content_type=ResultContentType.LLM_RESULT,
|
|
192
|
+
),
|
|
193
|
+
)
|
|
194
|
+
yield
|
|
195
|
+
|
|
196
|
+
asyncio.create_task(
|
|
197
|
+
Metric.upload(
|
|
198
|
+
llm_tick=1,
|
|
199
|
+
model_name=self.runner_type,
|
|
200
|
+
provider_type=self.runner_type,
|
|
201
|
+
),
|
|
202
|
+
)
|
|
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
|
|
|
24
24
|
async def process(
|
|
25
25
|
self,
|
|
26
26
|
event: AstrMessageEvent,
|
|
27
|
-
) ->
|
|
27
|
+
) -> AsyncGenerator[None, None]:
|
|
28
28
|
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
|
29
29
|
"activated_handlers",
|
|
30
30
|
)
|
|
@@ -7,7 +7,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata
|
|
|
7
7
|
|
|
8
8
|
from ..context import PipelineContext
|
|
9
9
|
from ..stage import Stage, register_stage
|
|
10
|
-
from .method.
|
|
10
|
+
from .method.agent_request import AgentRequestSubStage
|
|
11
11
|
from .method.star_request import StarRequestSubStage
|
|
12
12
|
|
|
13
13
|
|
|
@@ -17,9 +17,12 @@ class ProcessStage(Stage):
|
|
|
17
17
|
self.ctx = ctx
|
|
18
18
|
self.config = ctx.astrbot_config
|
|
19
19
|
self.plugin_manager = ctx.plugin_manager
|
|
20
|
-
self.llm_request_sub_stage = LLMRequestSubStage()
|
|
21
|
-
await self.llm_request_sub_stage.initialize(ctx)
|
|
22
20
|
|
|
21
|
+
# initialize agent sub stage
|
|
22
|
+
self.agent_sub_stage = AgentRequestSubStage()
|
|
23
|
+
await self.agent_sub_stage.initialize(ctx)
|
|
24
|
+
|
|
25
|
+
# initialize star request sub stage
|
|
23
26
|
self.star_request_sub_stage = StarRequestSubStage()
|
|
24
27
|
await self.star_request_sub_stage.initialize(ctx)
|
|
25
28
|
|
|
@@ -39,7 +42,7 @@ class ProcessStage(Stage):
|
|
|
39
42
|
# Handler 的 LLM 请求
|
|
40
43
|
event.set_extra("provider_request", resp)
|
|
41
44
|
_t = False
|
|
42
|
-
async for _ in self.
|
|
45
|
+
async for _ in self.agent_sub_stage.process(event):
|
|
43
46
|
_t = True
|
|
44
47
|
yield
|
|
45
48
|
if not _t:
|
|
@@ -67,5 +70,5 @@ class ProcessStage(Stage):
|
|
|
67
70
|
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
|
68
71
|
return
|
|
69
72
|
|
|
70
|
-
async for _ in self.
|
|
73
|
+
async for _ in self.agent_sub_stage.process(event):
|
|
71
74
|
yield
|
|
@@ -161,11 +161,21 @@ class ResultDecorateStage(Stage):
|
|
|
161
161
|
# 不分段回复
|
|
162
162
|
new_chain.append(comp)
|
|
163
163
|
continue
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
164
|
+
try:
|
|
165
|
+
split_response = re.findall(
|
|
166
|
+
self.regex,
|
|
167
|
+
comp.text,
|
|
168
|
+
re.DOTALL | re.MULTILINE,
|
|
169
|
+
)
|
|
170
|
+
except re.error:
|
|
171
|
+
logger.error(
|
|
172
|
+
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
|
|
173
|
+
)
|
|
174
|
+
split_response = re.findall(
|
|
175
|
+
r".*?[。?!~…]+|.+$",
|
|
176
|
+
comp.text,
|
|
177
|
+
re.DOTALL | re.MULTILINE,
|
|
178
|
+
)
|
|
169
179
|
if not split_response:
|
|
170
180
|
new_chain.append(comp)
|
|
171
181
|
continue
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import traceback
|
|
3
3
|
|
|
4
|
-
from astrbot.core import logger, sp
|
|
4
|
+
from astrbot.core import astrbot_config, logger, sp
|
|
5
5
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
6
6
|
from astrbot.core.db import BaseDatabase
|
|
7
7
|
|
|
@@ -24,6 +24,7 @@ class ProviderManager:
|
|
|
24
24
|
db_helper: BaseDatabase,
|
|
25
25
|
persona_mgr: PersonaManager,
|
|
26
26
|
):
|
|
27
|
+
self.reload_lock = asyncio.Lock()
|
|
27
28
|
self.persona_mgr = persona_mgr
|
|
28
29
|
self.acm = acm
|
|
29
30
|
config = acm.confs["default"]
|
|
@@ -226,6 +227,9 @@ class ProviderManager:
|
|
|
226
227
|
|
|
227
228
|
async def load_provider(self, provider_config: dict):
|
|
228
229
|
if not provider_config["enable"]:
|
|
230
|
+
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
|
231
|
+
return
|
|
232
|
+
if provider_config.get("provider_type", "") == "agent_runner":
|
|
229
233
|
return
|
|
230
234
|
|
|
231
235
|
logger.info(
|
|
@@ -247,14 +251,6 @@ class ProviderManager:
|
|
|
247
251
|
from .sources.anthropic_source import (
|
|
248
252
|
ProviderAnthropic as ProviderAnthropic,
|
|
249
253
|
)
|
|
250
|
-
case "dify":
|
|
251
|
-
from .sources.dify_source import ProviderDify as ProviderDify
|
|
252
|
-
case "coze":
|
|
253
|
-
from .sources.coze_source import ProviderCoze as ProviderCoze
|
|
254
|
-
case "dashscope":
|
|
255
|
-
from .sources.dashscope_source import (
|
|
256
|
-
ProviderDashscope as ProviderDashscope,
|
|
257
|
-
)
|
|
258
254
|
case "googlegenai_chat_completion":
|
|
259
255
|
from .sources.gemini_source import (
|
|
260
256
|
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
|
@@ -440,40 +436,46 @@ class ProviderManager:
|
|
|
440
436
|
)
|
|
441
437
|
|
|
442
438
|
async def reload(self, provider_config: dict):
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
# 和配置文件保持同步
|
|
448
|
-
config_ids = [provider["id"] for provider in self.providers_config]
|
|
449
|
-
logger.debug(f"providers in user's config: {config_ids}")
|
|
450
|
-
for key in list(self.inst_map.keys()):
|
|
451
|
-
if key not in config_ids:
|
|
452
|
-
await self.terminate_provider(key)
|
|
453
|
-
|
|
454
|
-
if len(self.provider_insts) == 0:
|
|
455
|
-
self.curr_provider_inst = None
|
|
456
|
-
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
|
457
|
-
self.curr_provider_inst = self.provider_insts[0]
|
|
458
|
-
logger.info(
|
|
459
|
-
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
|
460
|
-
)
|
|
439
|
+
async with self.reload_lock:
|
|
440
|
+
await self.terminate_provider(provider_config["id"])
|
|
441
|
+
if provider_config["enable"]:
|
|
442
|
+
await self.load_provider(provider_config)
|
|
461
443
|
|
|
462
|
-
|
|
463
|
-
self.
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
444
|
+
# 和配置文件保持同步
|
|
445
|
+
self.providers_config = astrbot_config["provider"]
|
|
446
|
+
config_ids = [provider["id"] for provider in self.providers_config]
|
|
447
|
+
logger.info(f"providers in user's config: {config_ids}")
|
|
448
|
+
for key in list(self.inst_map.keys()):
|
|
449
|
+
if key not in config_ids:
|
|
450
|
+
await self.terminate_provider(key)
|
|
469
451
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
452
|
+
if len(self.provider_insts) == 0:
|
|
453
|
+
self.curr_provider_inst = None
|
|
454
|
+
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
|
455
|
+
self.curr_provider_inst = self.provider_insts[0]
|
|
456
|
+
logger.info(
|
|
457
|
+
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
if len(self.stt_provider_insts) == 0:
|
|
461
|
+
self.curr_stt_provider_inst = None
|
|
462
|
+
elif (
|
|
463
|
+
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
|
|
464
|
+
):
|
|
465
|
+
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
|
466
|
+
logger.info(
|
|
467
|
+
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
|
468
|
+
)
|
|
469
|
+
|
|
470
|
+
if len(self.tts_provider_insts) == 0:
|
|
471
|
+
self.curr_tts_provider_inst = None
|
|
472
|
+
elif (
|
|
473
|
+
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
|
|
474
|
+
):
|
|
475
|
+
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
|
476
|
+
logger.info(
|
|
477
|
+
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
|
478
|
+
)
|
|
477
479
|
|
|
478
480
|
def get_insts(self):
|
|
479
481
|
return self.provider_insts
|
|
@@ -171,110 +171,3 @@ class SessionServiceManager:
|
|
|
171
171
|
|
|
172
172
|
# 如果没有配置,默认为启用(兼容性考虑)
|
|
173
173
|
return True
|
|
174
|
-
|
|
175
|
-
@staticmethod
|
|
176
|
-
def set_session_status(session_id: str, enabled: bool) -> None:
|
|
177
|
-
"""设置会话的整体启停状态
|
|
178
|
-
|
|
179
|
-
Args:
|
|
180
|
-
session_id: 会话ID (unified_msg_origin)
|
|
181
|
-
enabled: True表示启用,False表示禁用
|
|
182
|
-
|
|
183
|
-
"""
|
|
184
|
-
session_config = (
|
|
185
|
-
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
|
186
|
-
)
|
|
187
|
-
session_config["session_enabled"] = enabled
|
|
188
|
-
sp.put(
|
|
189
|
-
"session_service_config",
|
|
190
|
-
session_config,
|
|
191
|
-
scope="umo",
|
|
192
|
-
scope_id=session_id,
|
|
193
|
-
)
|
|
194
|
-
|
|
195
|
-
logger.info(
|
|
196
|
-
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}",
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
@staticmethod
|
|
200
|
-
def should_process_session_request(event: AstrMessageEvent) -> bool:
|
|
201
|
-
"""检查是否应该处理会话请求(会话整体启停检查)
|
|
202
|
-
|
|
203
|
-
Args:
|
|
204
|
-
event: 消息事件
|
|
205
|
-
|
|
206
|
-
Returns:
|
|
207
|
-
bool: True表示应该处理,False表示跳过
|
|
208
|
-
|
|
209
|
-
"""
|
|
210
|
-
session_id = event.unified_msg_origin
|
|
211
|
-
return SessionServiceManager.is_session_enabled(session_id)
|
|
212
|
-
|
|
213
|
-
# =============================================================================
|
|
214
|
-
# 会话命名相关方法
|
|
215
|
-
# =============================================================================
|
|
216
|
-
|
|
217
|
-
@staticmethod
|
|
218
|
-
def get_session_custom_name(session_id: str) -> str | None:
|
|
219
|
-
"""获取会话的自定义名称
|
|
220
|
-
|
|
221
|
-
Args:
|
|
222
|
-
session_id: 会话ID (unified_msg_origin)
|
|
223
|
-
|
|
224
|
-
Returns:
|
|
225
|
-
str: 自定义名称,如果没有设置则返回None
|
|
226
|
-
|
|
227
|
-
"""
|
|
228
|
-
session_services = sp.get(
|
|
229
|
-
"session_service_config",
|
|
230
|
-
{},
|
|
231
|
-
scope="umo",
|
|
232
|
-
scope_id=session_id,
|
|
233
|
-
)
|
|
234
|
-
return session_services.get("custom_name")
|
|
235
|
-
|
|
236
|
-
@staticmethod
|
|
237
|
-
def set_session_custom_name(session_id: str, custom_name: str) -> None:
|
|
238
|
-
"""设置会话的自定义名称
|
|
239
|
-
|
|
240
|
-
Args:
|
|
241
|
-
session_id: 会话ID (unified_msg_origin)
|
|
242
|
-
custom_name: 自定义名称,可以为空字符串来清除名称
|
|
243
|
-
|
|
244
|
-
"""
|
|
245
|
-
session_config = (
|
|
246
|
-
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
|
247
|
-
)
|
|
248
|
-
if custom_name and custom_name.strip():
|
|
249
|
-
session_config["custom_name"] = custom_name.strip()
|
|
250
|
-
else:
|
|
251
|
-
# 如果传入空名称,则删除自定义名称
|
|
252
|
-
session_config.pop("custom_name", None)
|
|
253
|
-
sp.put(
|
|
254
|
-
"session_service_config",
|
|
255
|
-
session_config,
|
|
256
|
-
scope="umo",
|
|
257
|
-
scope_id=session_id,
|
|
258
|
-
)
|
|
259
|
-
|
|
260
|
-
logger.info(
|
|
261
|
-
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}",
|
|
262
|
-
)
|
|
263
|
-
|
|
264
|
-
@staticmethod
|
|
265
|
-
def get_session_display_name(session_id: str) -> str:
|
|
266
|
-
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
|
|
267
|
-
|
|
268
|
-
Args:
|
|
269
|
-
session_id: 会话ID (unified_msg_origin)
|
|
270
|
-
|
|
271
|
-
Returns:
|
|
272
|
-
str: 显示名称
|
|
273
|
-
|
|
274
|
-
"""
|
|
275
|
-
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
|
276
|
-
if custom_name:
|
|
277
|
-
return custom_name
|
|
278
|
-
|
|
279
|
-
# 如果没有自定义名称,返回session_id的最后一段
|
|
280
|
-
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|