AstrBot 4.5.6__py3-none-any.whl → 4.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/all.py +2 -1
- astrbot/api/provider/__init__.py +2 -1
- astrbot/core/agent/run_context.py +7 -2
- astrbot/core/agent/runners/base.py +7 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +51 -3
- astrbot/core/agent/tool.py +5 -6
- astrbot/core/astr_agent_context.py +13 -8
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/config/default.py +53 -7
- astrbot/core/exceptions.py +9 -0
- astrbot/core/pipeline/context.py +1 -2
- astrbot/core/pipeline/context_utils.py +0 -65
- astrbot/core/pipeline/process_stage/method/llm_request.py +239 -491
- astrbot/core/pipeline/respond/stage.py +21 -20
- astrbot/core/platform/platform_metadata.py +3 -0
- astrbot/core/platform/register.py +2 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -0
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +16 -5
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +4 -1
- astrbot/core/platform/sources/discord/discord_platform_event.py +16 -7
- astrbot/core/platform/sources/lark/lark_adapter.py +4 -1
- astrbot/core/platform/sources/misskey/misskey_adapter.py +4 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +2 -2
- astrbot/core/platform/sources/slack/slack_adapter.py +2 -0
- astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -0
- astrbot/core/platform/sources/webchat/webchat_event.py +8 -1
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +4 -1
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +16 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +2 -1
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +4 -1
- astrbot/core/provider/__init__.py +2 -2
- astrbot/core/provider/entities.py +40 -18
- astrbot/core/provider/func_tool_manager.py +15 -6
- astrbot/core/provider/manager.py +4 -1
- astrbot/core/provider/provider.py +7 -22
- astrbot/core/provider/register.py +2 -0
- astrbot/core/provider/sources/anthropic_source.py +0 -2
- astrbot/core/provider/sources/coze_source.py +0 -2
- astrbot/core/provider/sources/dashscope_source.py +1 -3
- astrbot/core/provider/sources/dify_source.py +0 -2
- astrbot/core/provider/sources/gemini_source.py +31 -3
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/openai_source.py +67 -21
- astrbot/core/provider/sources/zhipu_source.py +1 -6
- astrbot/core/star/context.py +197 -45
- astrbot/core/star/register/star_handler.py +30 -10
- astrbot/dashboard/routes/chat.py +5 -0
- {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/METADATA +2 -2
- {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/RECORD +54 -49
- {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/WHEEL +0 -0
- {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -10,7 +10,6 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
|
|
|
10
10
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
11
11
|
from astrbot.core.star.star_handler import EventType
|
|
12
12
|
from astrbot.core.utils.path_util import path_Mapping
|
|
13
|
-
from astrbot.core.utils.session_lock import session_lock_manager
|
|
14
13
|
|
|
15
14
|
from ..context import PipelineContext, call_event_hook
|
|
16
15
|
from ..stage import Stage, register_stage
|
|
@@ -169,12 +168,15 @@ class RespondStage(Stage):
|
|
|
169
168
|
logger.warning("async_stream 为空,跳过发送。")
|
|
170
169
|
return
|
|
171
170
|
# 流式结果直接交付平台适配器处理
|
|
172
|
-
|
|
173
|
-
"
|
|
174
|
-
|
|
171
|
+
realtime_segmenting = (
|
|
172
|
+
self.config.get("provider_settings", {}).get(
|
|
173
|
+
"unsupported_streaming_strategy",
|
|
174
|
+
"realtime_segmenting",
|
|
175
|
+
)
|
|
176
|
+
== "realtime_segmenting"
|
|
175
177
|
)
|
|
176
178
|
logger.info(f"应用流式输出({event.get_platform_id()})")
|
|
177
|
-
await event.send_streaming(result.async_stream,
|
|
179
|
+
await event.send_streaming(result.async_stream, realtime_segmenting)
|
|
178
180
|
return
|
|
179
181
|
if len(result.chain) > 0:
|
|
180
182
|
# 检查路径映射
|
|
@@ -218,21 +220,20 @@ class RespondStage(Stage):
|
|
|
218
220
|
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
|
|
219
221
|
)
|
|
220
222
|
return
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
)
|
|
223
|
+
for comp in result.chain:
|
|
224
|
+
i = await self._calc_comp_interval(comp)
|
|
225
|
+
await asyncio.sleep(i)
|
|
226
|
+
try:
|
|
227
|
+
if comp.type in need_separately:
|
|
228
|
+
await event.send(MessageChain([comp]))
|
|
229
|
+
else:
|
|
230
|
+
await event.send(MessageChain([*header_comps, comp]))
|
|
231
|
+
header_comps.clear()
|
|
232
|
+
except Exception as e:
|
|
233
|
+
logger.error(
|
|
234
|
+
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
|
235
|
+
exc_info=True,
|
|
236
|
+
)
|
|
236
237
|
else:
|
|
237
238
|
if all(
|
|
238
239
|
comp.type in {ComponentType.Reply, ComponentType.At}
|
|
@@ -14,6 +14,7 @@ def register_platform_adapter(
|
|
|
14
14
|
default_config_tmpl: dict | None = None,
|
|
15
15
|
adapter_display_name: str | None = None,
|
|
16
16
|
logo_path: str | None = None,
|
|
17
|
+
support_streaming_message: bool = True,
|
|
17
18
|
):
|
|
18
19
|
"""用于注册平台适配器的带参装饰器。
|
|
19
20
|
|
|
@@ -42,6 +43,7 @@ def register_platform_adapter(
|
|
|
42
43
|
default_config_tmpl=default_config_tmpl,
|
|
43
44
|
adapter_display_name=adapter_display_name,
|
|
44
45
|
logo_path=logo_path,
|
|
46
|
+
support_streaming_message=support_streaming_message,
|
|
45
47
|
)
|
|
46
48
|
platform_registry.append(pm)
|
|
47
49
|
platform_cls_map[adapter_name] = cls
|
|
@@ -29,6 +29,7 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
|
|
29
29
|
@register_platform_adapter(
|
|
30
30
|
"aiocqhttp",
|
|
31
31
|
"适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
|
|
32
|
+
support_streaming_message=False,
|
|
32
33
|
)
|
|
33
34
|
class AiocqhttpAdapter(Platform):
|
|
34
35
|
def __init__(
|
|
@@ -49,6 +50,7 @@ class AiocqhttpAdapter(Platform):
|
|
|
49
50
|
name="aiocqhttp",
|
|
50
51
|
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
|
51
52
|
id=self.config.get("id"),
|
|
53
|
+
support_streaming_message=False,
|
|
52
54
|
)
|
|
53
55
|
|
|
54
56
|
self.bot = CQHttp(
|
|
@@ -37,7 +37,9 @@ class MyEventHandler(dingtalk_stream.EventHandler):
|
|
|
37
37
|
return AckMessage.STATUS_OK, "OK"
|
|
38
38
|
|
|
39
39
|
|
|
40
|
-
@register_platform_adapter(
|
|
40
|
+
@register_platform_adapter(
|
|
41
|
+
"dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
|
|
42
|
+
)
|
|
41
43
|
class DingtalkPlatformAdapter(Platform):
|
|
42
44
|
def __init__(
|
|
43
45
|
self,
|
|
@@ -74,6 +76,14 @@ class DingtalkPlatformAdapter(Platform):
|
|
|
74
76
|
)
|
|
75
77
|
self.client_ = client # 用于 websockets 的 client
|
|
76
78
|
|
|
79
|
+
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
|
|
80
|
+
if not dingtalk_id:
|
|
81
|
+
return dingtalk_id
|
|
82
|
+
prefix = "$:LWCP_v1:$"
|
|
83
|
+
if dingtalk_id.startswith(prefix):
|
|
84
|
+
return dingtalk_id[len(prefix) :]
|
|
85
|
+
return dingtalk_id
|
|
86
|
+
|
|
77
87
|
async def send_by_session(
|
|
78
88
|
self,
|
|
79
89
|
session: MessageSesion,
|
|
@@ -86,6 +96,7 @@ class DingtalkPlatformAdapter(Platform):
|
|
|
86
96
|
name="dingtalk",
|
|
87
97
|
description="钉钉机器人官方 API 适配器",
|
|
88
98
|
id=self.config.get("id"),
|
|
99
|
+
support_streaming_message=False,
|
|
89
100
|
)
|
|
90
101
|
|
|
91
102
|
async def convert_msg(
|
|
@@ -102,10 +113,10 @@ class DingtalkPlatformAdapter(Platform):
|
|
|
102
113
|
else MessageType.FRIEND_MESSAGE
|
|
103
114
|
)
|
|
104
115
|
abm.sender = MessageMember(
|
|
105
|
-
user_id=message.sender_id,
|
|
116
|
+
user_id=self._id_to_sid(message.sender_id),
|
|
106
117
|
nickname=message.sender_nick,
|
|
107
118
|
)
|
|
108
|
-
abm.self_id = message.chatbot_user_id
|
|
119
|
+
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
|
109
120
|
abm.message_id = message.message_id
|
|
110
121
|
abm.raw_message = message
|
|
111
122
|
|
|
@@ -113,8 +124,8 @@ class DingtalkPlatformAdapter(Platform):
|
|
|
113
124
|
# 处理所有被 @ 的用户(包括机器人自己,因 at_users 已包含)
|
|
114
125
|
if message.at_users:
|
|
115
126
|
for user in message.at_users:
|
|
116
|
-
if user.dingtalk_id:
|
|
117
|
-
abm.message.append(At(qq=user.dingtalk_id))
|
|
127
|
+
if id := self._id_to_sid(user.dingtalk_id):
|
|
128
|
+
abm.message.append(At(qq=id))
|
|
118
129
|
abm.group_id = message.conversation_id
|
|
119
130
|
if self.unique_session:
|
|
120
131
|
abm.session_id = abm.sender.user_id
|
|
@@ -34,7 +34,9 @@ else:
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
# 注册平台适配器
|
|
37
|
-
@register_platform_adapter(
|
|
37
|
+
@register_platform_adapter(
|
|
38
|
+
"discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
|
|
39
|
+
)
|
|
38
40
|
class DiscordPlatformAdapter(Platform):
|
|
39
41
|
def __init__(
|
|
40
42
|
self,
|
|
@@ -111,6 +113,7 @@ class DiscordPlatformAdapter(Platform):
|
|
|
111
113
|
"Discord 适配器",
|
|
112
114
|
id=self.config.get("id"),
|
|
113
115
|
default_config_tmpl=self.config,
|
|
116
|
+
support_streaming_message=False,
|
|
114
117
|
)
|
|
115
118
|
|
|
116
119
|
@override
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import base64
|
|
3
3
|
import binascii
|
|
4
|
-
import sys
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
5
|
from io import BytesIO
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
|
|
@@ -21,11 +21,6 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata
|
|
|
21
21
|
from .client import DiscordBotClient
|
|
22
22
|
from .components import DiscordEmbed, DiscordView
|
|
23
23
|
|
|
24
|
-
if sys.version_info >= (3, 12):
|
|
25
|
-
from typing import override
|
|
26
|
-
else:
|
|
27
|
-
from typing_extensions import override
|
|
28
|
-
|
|
29
24
|
|
|
30
25
|
# 自定义Discord视图组件(兼容旧版本)
|
|
31
26
|
class DiscordViewComponent(BaseMessageComponent):
|
|
@@ -49,7 +44,6 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|
|
49
44
|
self.client = client
|
|
50
45
|
self.interaction_followup_webhook = interaction_followup_webhook
|
|
51
46
|
|
|
52
|
-
@override
|
|
53
47
|
async def send(self, message: MessageChain):
|
|
54
48
|
"""发送消息到Discord平台"""
|
|
55
49
|
# 解析消息链为 Discord 所需的对象
|
|
@@ -98,6 +92,21 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|
|
98
92
|
|
|
99
93
|
await super().send(message)
|
|
100
94
|
|
|
95
|
+
async def send_streaming(
|
|
96
|
+
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
|
97
|
+
):
|
|
98
|
+
buffer = None
|
|
99
|
+
async for chain in generator:
|
|
100
|
+
if not buffer:
|
|
101
|
+
buffer = chain
|
|
102
|
+
else:
|
|
103
|
+
buffer.chain.extend(chain.chain)
|
|
104
|
+
if not buffer:
|
|
105
|
+
return None
|
|
106
|
+
buffer.squash_plain()
|
|
107
|
+
await self.send(buffer)
|
|
108
|
+
return await super().send_streaming(generator, use_fallback)
|
|
109
|
+
|
|
101
110
|
async def _get_channel(self) -> discord.abc.Messageable | None:
|
|
102
111
|
"""获取当前事件对应的频道对象"""
|
|
103
112
|
try:
|
|
@@ -23,7 +23,9 @@ from ...register import register_platform_adapter
|
|
|
23
23
|
from .lark_event import LarkMessageEvent
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
@register_platform_adapter(
|
|
26
|
+
@register_platform_adapter(
|
|
27
|
+
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
|
|
28
|
+
)
|
|
27
29
|
class LarkPlatformAdapter(Platform):
|
|
28
30
|
def __init__(
|
|
29
31
|
self,
|
|
@@ -115,6 +117,7 @@ class LarkPlatformAdapter(Platform):
|
|
|
115
117
|
name="lark",
|
|
116
118
|
description="飞书机器人官方 API 适配器",
|
|
117
119
|
id=self.config.get("id"),
|
|
120
|
+
support_streaming_message=False,
|
|
118
121
|
)
|
|
119
122
|
|
|
120
123
|
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
|
@@ -45,7 +45,9 @@ MAX_FILE_UPLOAD_COUNT = 16
|
|
|
45
45
|
DEFAULT_UPLOAD_CONCURRENCY = 3
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
@register_platform_adapter(
|
|
48
|
+
@register_platform_adapter(
|
|
49
|
+
"misskey", "Misskey 平台适配器", support_streaming_message=False
|
|
50
|
+
)
|
|
49
51
|
class MisskeyPlatformAdapter(Platform):
|
|
50
52
|
def __init__(
|
|
51
53
|
self,
|
|
@@ -120,6 +122,7 @@ class MisskeyPlatformAdapter(Platform):
|
|
|
120
122
|
description="Misskey 平台适配器",
|
|
121
123
|
id=self.config.get("id", "misskey"),
|
|
122
124
|
default_config_tmpl=default_config,
|
|
125
|
+
support_streaming_message=False,
|
|
123
126
|
)
|
|
124
127
|
|
|
125
128
|
async def run(self):
|
|
@@ -29,8 +29,7 @@ from astrbot.core.platform.astr_message_event import MessageSession
|
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
@register_platform_adapter(
|
|
32
|
-
"satori",
|
|
33
|
-
"Satori 协议适配器",
|
|
32
|
+
"satori", "Satori 协议适配器", support_streaming_message=False
|
|
34
33
|
)
|
|
35
34
|
class SatoriPlatformAdapter(Platform):
|
|
36
35
|
def __init__(
|
|
@@ -60,6 +59,7 @@ class SatoriPlatformAdapter(Platform):
|
|
|
60
59
|
name="satori",
|
|
61
60
|
description="Satori 通用协议适配器",
|
|
62
61
|
id=self.config["id"],
|
|
62
|
+
support_streaming_message=False,
|
|
63
63
|
)
|
|
64
64
|
|
|
65
65
|
self.ws: ClientConnection | None = None
|
|
@@ -30,6 +30,7 @@ from .slack_event import SlackMessageEvent
|
|
|
30
30
|
@register_platform_adapter(
|
|
31
31
|
"slack",
|
|
32
32
|
"适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
|
33
|
+
support_streaming_message=False,
|
|
33
34
|
)
|
|
34
35
|
class SlackAdapter(Platform):
|
|
35
36
|
def __init__(
|
|
@@ -68,6 +69,7 @@ class SlackAdapter(Platform):
|
|
|
68
69
|
name="slack",
|
|
69
70
|
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
|
70
71
|
id=self.config.get("id"),
|
|
72
|
+
support_streaming_message=False,
|
|
71
73
|
)
|
|
72
74
|
|
|
73
75
|
# 初始化 Slack Web Client
|
|
@@ -163,6 +163,9 @@ class WebChatAdapter(Platform):
|
|
|
163
163
|
_, _, payload = message.raw_message # type: ignore
|
|
164
164
|
message_event.set_extra("selected_provider", payload.get("selected_provider"))
|
|
165
165
|
message_event.set_extra("selected_model", payload.get("selected_model"))
|
|
166
|
+
message_event.set_extra(
|
|
167
|
+
"enable_streaming", payload.get("enable_streaming", True)
|
|
168
|
+
)
|
|
166
169
|
|
|
167
170
|
self.commit_event(message_event)
|
|
168
171
|
|
|
@@ -109,6 +109,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|
|
109
109
|
|
|
110
110
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
|
111
111
|
final_data = ""
|
|
112
|
+
reasoning_content = ""
|
|
112
113
|
cid = self.session_id.split("!")[-1]
|
|
113
114
|
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
|
114
115
|
async for chain in generator:
|
|
@@ -124,16 +125,22 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|
|
124
125
|
)
|
|
125
126
|
final_data = ""
|
|
126
127
|
continue
|
|
127
|
-
|
|
128
|
+
|
|
129
|
+
r = await WebChatMessageEvent._send(
|
|
128
130
|
chain,
|
|
129
131
|
session_id=self.session_id,
|
|
130
132
|
streaming=True,
|
|
131
133
|
)
|
|
134
|
+
if chain.type == "reasoning":
|
|
135
|
+
reasoning_content += chain.get_plain_text()
|
|
136
|
+
else:
|
|
137
|
+
final_data += r
|
|
132
138
|
|
|
133
139
|
await web_chat_back_queue.put(
|
|
134
140
|
{
|
|
135
141
|
"type": "complete", # complete means we return the final result
|
|
136
142
|
"data": final_data,
|
|
143
|
+
"reasoning": reasoning_content,
|
|
137
144
|
"streaming": True,
|
|
138
145
|
"cid": cid,
|
|
139
146
|
},
|
|
@@ -32,7 +32,9 @@ except ImportError as e:
|
|
|
32
32
|
)
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
@register_platform_adapter(
|
|
35
|
+
@register_platform_adapter(
|
|
36
|
+
"wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
|
|
37
|
+
)
|
|
36
38
|
class WeChatPadProAdapter(Platform):
|
|
37
39
|
def __init__(
|
|
38
40
|
self,
|
|
@@ -51,6 +53,7 @@ class WeChatPadProAdapter(Platform):
|
|
|
51
53
|
name="wechatpadpro",
|
|
52
54
|
description="WeChatPadPro 消息平台适配器",
|
|
53
55
|
id=self.config.get("id", "wechatpadpro"),
|
|
56
|
+
support_streaming_message=False,
|
|
54
57
|
)
|
|
55
58
|
|
|
56
59
|
# 保存配置信息
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import base64
|
|
3
3
|
import io
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
4
5
|
from typing import TYPE_CHECKING
|
|
5
6
|
|
|
6
7
|
import aiohttp
|
|
@@ -50,6 +51,21 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
|
|
|
50
51
|
await self._send_voice(session, comp)
|
|
51
52
|
await super().send(message)
|
|
52
53
|
|
|
54
|
+
async def send_streaming(
|
|
55
|
+
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
|
56
|
+
):
|
|
57
|
+
buffer = None
|
|
58
|
+
async for chain in generator:
|
|
59
|
+
if not buffer:
|
|
60
|
+
buffer = chain
|
|
61
|
+
else:
|
|
62
|
+
buffer.chain.extend(chain.chain)
|
|
63
|
+
if not buffer:
|
|
64
|
+
return None
|
|
65
|
+
buffer.squash_plain()
|
|
66
|
+
await self.send(buffer)
|
|
67
|
+
return await super().send_streaming(generator, use_fallback)
|
|
68
|
+
|
|
53
69
|
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
|
|
54
70
|
b64 = await comp.convert_to_base64()
|
|
55
71
|
raw = self._validate_base64(b64)
|
|
@@ -110,7 +110,7 @@ class WecomServer:
|
|
|
110
110
|
await self.shutdown_event.wait()
|
|
111
111
|
|
|
112
112
|
|
|
113
|
-
@register_platform_adapter("wecom", "wecom 适配器")
|
|
113
|
+
@register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
|
|
114
114
|
class WecomPlatformAdapter(Platform):
|
|
115
115
|
def __init__(
|
|
116
116
|
self,
|
|
@@ -196,6 +196,7 @@ class WecomPlatformAdapter(Platform):
|
|
|
196
196
|
"wecom",
|
|
197
197
|
"wecom 适配器",
|
|
198
198
|
id=self.config.get("id", "wecom"),
|
|
199
|
+
support_streaming_message=False,
|
|
199
200
|
)
|
|
200
201
|
|
|
201
202
|
@override
|
|
@@ -113,7 +113,9 @@ class WecomServer:
|
|
|
113
113
|
await self.shutdown_event.wait()
|
|
114
114
|
|
|
115
115
|
|
|
116
|
-
@register_platform_adapter(
|
|
116
|
+
@register_platform_adapter(
|
|
117
|
+
"weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
|
|
118
|
+
)
|
|
117
119
|
class WeixinOfficialAccountPlatformAdapter(Platform):
|
|
118
120
|
def __init__(
|
|
119
121
|
self,
|
|
@@ -195,6 +197,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|
|
195
197
|
"weixin_official_account",
|
|
196
198
|
"微信公众平台 适配器",
|
|
197
199
|
id=self.config.get("id", "weixin_official_account"),
|
|
200
|
+
support_streaming_message=False,
|
|
198
201
|
)
|
|
199
202
|
|
|
200
203
|
@override
|
|
@@ -1,4 +1,4 @@
|
|
|
1
1
|
from .entities import ProviderMetaData
|
|
2
|
-
from .provider import
|
|
2
|
+
from .provider import Provider, STTProvider
|
|
3
3
|
|
|
4
|
-
__all__ = ["
|
|
4
|
+
__all__ = ["Provider", "ProviderMetaData", "STTProvider"]
|
|
@@ -30,18 +30,31 @@ class ProviderType(enum.Enum):
|
|
|
30
30
|
|
|
31
31
|
|
|
32
32
|
@dataclass
|
|
33
|
-
class
|
|
33
|
+
class ProviderMeta:
|
|
34
|
+
"""The basic metadata of a provider instance."""
|
|
35
|
+
|
|
36
|
+
id: str
|
|
37
|
+
"""the unique id of the provider instance that user configured"""
|
|
38
|
+
model: str | None
|
|
39
|
+
"""the model name of the provider instance currently used"""
|
|
34
40
|
type: str
|
|
35
|
-
"""
|
|
36
|
-
desc: str = ""
|
|
37
|
-
"""提供商适配器描述"""
|
|
41
|
+
"""the name of the provider adapter, such as openai, ollama"""
|
|
38
42
|
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
|
39
|
-
|
|
43
|
+
"""the capability type of the provider adapter"""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
|
|
47
|
+
class ProviderMetaData(ProviderMeta):
|
|
48
|
+
"""The metadata of a provider adapter for registration."""
|
|
40
49
|
|
|
50
|
+
desc: str = ""
|
|
51
|
+
"""the short description of the provider adapter"""
|
|
52
|
+
cls_type: Any = None
|
|
53
|
+
"""the class type of the provider adapter"""
|
|
41
54
|
default_config_tmpl: dict | None = None
|
|
42
|
-
"""
|
|
55
|
+
"""the default configuration template of the provider adapter"""
|
|
43
56
|
provider_display_name: str | None = None
|
|
44
|
-
"""
|
|
57
|
+
"""the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
|
|
45
58
|
|
|
46
59
|
|
|
47
60
|
@dataclass
|
|
@@ -60,12 +73,20 @@ class ToolCallsResult:
|
|
|
60
73
|
]
|
|
61
74
|
return ret
|
|
62
75
|
|
|
76
|
+
def to_openai_messages_model(
|
|
77
|
+
self,
|
|
78
|
+
) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
|
|
79
|
+
return [
|
|
80
|
+
self.tool_calls_info,
|
|
81
|
+
*self.tool_calls_result,
|
|
82
|
+
]
|
|
83
|
+
|
|
63
84
|
|
|
64
85
|
@dataclass
|
|
65
86
|
class ProviderRequest:
|
|
66
|
-
prompt: str
|
|
87
|
+
prompt: str | None = None
|
|
67
88
|
"""提示词"""
|
|
68
|
-
session_id: str = ""
|
|
89
|
+
session_id: str | None = ""
|
|
69
90
|
"""会话 ID"""
|
|
70
91
|
image_urls: list[str] = field(default_factory=list)
|
|
71
92
|
"""图片 URL 列表"""
|
|
@@ -181,25 +202,28 @@ class ProviderRequest:
|
|
|
181
202
|
@dataclass
|
|
182
203
|
class LLMResponse:
|
|
183
204
|
role: str
|
|
184
|
-
"""
|
|
205
|
+
"""The role of the message, e.g., assistant, tool, err"""
|
|
185
206
|
result_chain: MessageChain | None = None
|
|
186
|
-
"""
|
|
207
|
+
"""A chain of message components representing the text completion from LLM."""
|
|
187
208
|
tools_call_args: list[dict[str, Any]] = field(default_factory=list)
|
|
188
|
-
"""
|
|
209
|
+
"""Tool call arguments."""
|
|
189
210
|
tools_call_name: list[str] = field(default_factory=list)
|
|
190
|
-
"""
|
|
211
|
+
"""Tool call names."""
|
|
191
212
|
tools_call_ids: list[str] = field(default_factory=list)
|
|
192
|
-
"""
|
|
213
|
+
"""Tool call IDs."""
|
|
214
|
+
reasoning_content: str = ""
|
|
215
|
+
"""The reasoning content extracted from the LLM, if any."""
|
|
193
216
|
|
|
194
217
|
raw_completion: (
|
|
195
218
|
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
|
|
196
219
|
) = None
|
|
197
|
-
|
|
220
|
+
"""The raw completion response from the LLM provider."""
|
|
198
221
|
|
|
199
222
|
_completion_text: str = ""
|
|
223
|
+
"""The plain text of the completion."""
|
|
200
224
|
|
|
201
225
|
is_chunk: bool = False
|
|
202
|
-
"""
|
|
226
|
+
"""Indicates if the response is a chunked response."""
|
|
203
227
|
|
|
204
228
|
def __init__(
|
|
205
229
|
self,
|
|
@@ -213,7 +237,6 @@ class LLMResponse:
|
|
|
213
237
|
| GenerateContentResponse
|
|
214
238
|
| AnthropicMessage
|
|
215
239
|
| None = None,
|
|
216
|
-
_new_record: dict[str, Any] | None = None,
|
|
217
240
|
is_chunk: bool = False,
|
|
218
241
|
):
|
|
219
242
|
"""初始化 LLMResponse
|
|
@@ -241,7 +264,6 @@ class LLMResponse:
|
|
|
241
264
|
self.tools_call_name = tools_call_name
|
|
242
265
|
self.tools_call_ids = tools_call_ids
|
|
243
266
|
self.raw_completion = raw_completion
|
|
244
|
-
self._new_record = _new_record
|
|
245
267
|
self.is_chunk = is_chunk
|
|
246
268
|
|
|
247
269
|
@property
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import copy
|
|
4
5
|
import json
|
|
5
6
|
import os
|
|
6
7
|
from collections.abc import Awaitable, Callable
|
|
@@ -24,7 +25,16 @@ SUPPORTED_TYPES = [
|
|
|
24
25
|
"boolean",
|
|
25
26
|
] # json schema 支持的数据类型
|
|
26
27
|
|
|
27
|
-
|
|
28
|
+
PY_TO_JSON_TYPE = {
|
|
29
|
+
"int": "number",
|
|
30
|
+
"float": "number",
|
|
31
|
+
"bool": "boolean",
|
|
32
|
+
"str": "string",
|
|
33
|
+
"dict": "object",
|
|
34
|
+
"list": "array",
|
|
35
|
+
"tuple": "array",
|
|
36
|
+
"set": "array",
|
|
37
|
+
}
|
|
28
38
|
# alias
|
|
29
39
|
FuncTool = FunctionTool
|
|
30
40
|
|
|
@@ -106,7 +116,7 @@ class FunctionToolManager:
|
|
|
106
116
|
def spec_to_func(
|
|
107
117
|
self,
|
|
108
118
|
name: str,
|
|
109
|
-
func_args: list,
|
|
119
|
+
func_args: list[dict],
|
|
110
120
|
desc: str,
|
|
111
121
|
handler: Callable[..., Awaitable[Any]],
|
|
112
122
|
) -> FuncTool:
|
|
@@ -115,10 +125,9 @@ class FunctionToolManager:
|
|
|
115
125
|
"properties": {},
|
|
116
126
|
}
|
|
117
127
|
for param in func_args:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
}
|
|
128
|
+
p = copy.deepcopy(param)
|
|
129
|
+
p.pop("name", None)
|
|
130
|
+
params["properties"][param["name"]] = p
|
|
122
131
|
return FuncTool(
|
|
123
132
|
name=name,
|
|
124
133
|
parameters=params,
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -241,6 +241,8 @@ class ProviderManager:
|
|
|
241
241
|
)
|
|
242
242
|
case "zhipu_chat_completion":
|
|
243
243
|
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
|
244
|
+
case "groq_chat_completion":
|
|
245
|
+
from .sources.groq_source import ProviderGroq as ProviderGroq
|
|
244
246
|
case "anthropic_chat_completion":
|
|
245
247
|
from .sources.anthropic_source import (
|
|
246
248
|
ProviderAnthropic as ProviderAnthropic,
|
|
@@ -354,6 +356,8 @@ class ProviderManager:
|
|
|
354
356
|
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
|
355
357
|
return
|
|
356
358
|
|
|
359
|
+
provider_metadata.id = provider_config["id"]
|
|
360
|
+
|
|
357
361
|
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
|
358
362
|
# STT 任务
|
|
359
363
|
inst = cls_type(provider_config, self.provider_settings)
|
|
@@ -394,7 +398,6 @@ class ProviderManager:
|
|
|
394
398
|
inst = cls_type(
|
|
395
399
|
provider_config,
|
|
396
400
|
self.provider_settings,
|
|
397
|
-
self.selected_default_persona,
|
|
398
401
|
)
|
|
399
402
|
|
|
400
403
|
if getattr(inst, "initialize", None):
|