AstrBot 4.5.6__py3-none-any.whl → 4.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. astrbot/api/all.py +2 -1
  2. astrbot/api/provider/__init__.py +2 -1
  3. astrbot/core/agent/run_context.py +7 -2
  4. astrbot/core/agent/runners/base.py +7 -0
  5. astrbot/core/agent/runners/tool_loop_agent_runner.py +51 -3
  6. astrbot/core/agent/tool.py +5 -6
  7. astrbot/core/astr_agent_context.py +13 -8
  8. astrbot/core/astr_agent_hooks.py +36 -0
  9. astrbot/core/astr_agent_run_util.py +80 -0
  10. astrbot/core/astr_agent_tool_exec.py +246 -0
  11. astrbot/core/config/default.py +53 -7
  12. astrbot/core/exceptions.py +9 -0
  13. astrbot/core/pipeline/context.py +1 -2
  14. astrbot/core/pipeline/context_utils.py +0 -65
  15. astrbot/core/pipeline/process_stage/method/llm_request.py +239 -491
  16. astrbot/core/pipeline/respond/stage.py +21 -20
  17. astrbot/core/platform/platform_metadata.py +3 -0
  18. astrbot/core/platform/register.py +2 -0
  19. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -0
  20. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +16 -5
  21. astrbot/core/platform/sources/discord/discord_platform_adapter.py +4 -1
  22. astrbot/core/platform/sources/discord/discord_platform_event.py +16 -7
  23. astrbot/core/platform/sources/lark/lark_adapter.py +4 -1
  24. astrbot/core/platform/sources/misskey/misskey_adapter.py +4 -1
  25. astrbot/core/platform/sources/satori/satori_adapter.py +2 -2
  26. astrbot/core/platform/sources/slack/slack_adapter.py +2 -0
  27. astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -0
  28. astrbot/core/platform/sources/webchat/webchat_event.py +8 -1
  29. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +4 -1
  30. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +16 -0
  31. astrbot/core/platform/sources/wecom/wecom_adapter.py +2 -1
  32. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +4 -1
  33. astrbot/core/provider/__init__.py +2 -2
  34. astrbot/core/provider/entities.py +40 -18
  35. astrbot/core/provider/func_tool_manager.py +15 -6
  36. astrbot/core/provider/manager.py +4 -1
  37. astrbot/core/provider/provider.py +7 -22
  38. astrbot/core/provider/register.py +2 -0
  39. astrbot/core/provider/sources/anthropic_source.py +0 -2
  40. astrbot/core/provider/sources/coze_source.py +0 -2
  41. astrbot/core/provider/sources/dashscope_source.py +1 -3
  42. astrbot/core/provider/sources/dify_source.py +0 -2
  43. astrbot/core/provider/sources/gemini_source.py +31 -3
  44. astrbot/core/provider/sources/groq_source.py +15 -0
  45. astrbot/core/provider/sources/openai_source.py +67 -21
  46. astrbot/core/provider/sources/zhipu_source.py +1 -6
  47. astrbot/core/star/context.py +197 -45
  48. astrbot/core/star/register/star_handler.py +30 -10
  49. astrbot/dashboard/routes/chat.py +5 -0
  50. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/METADATA +2 -2
  51. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/RECORD +54 -49
  52. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/WHEEL +0 -0
  53. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/entry_points.txt +0 -0
  54. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/licenses/LICENSE +0 -0
@@ -10,7 +10,6 @@ from astrbot.core.message.message_event_result import MessageChain, ResultConten
10
10
  from astrbot.core.platform.astr_message_event import AstrMessageEvent
11
11
  from astrbot.core.star.star_handler import EventType
12
12
  from astrbot.core.utils.path_util import path_Mapping
13
- from astrbot.core.utils.session_lock import session_lock_manager
14
13
 
15
14
  from ..context import PipelineContext, call_event_hook
16
15
  from ..stage import Stage, register_stage
@@ -169,12 +168,15 @@ class RespondStage(Stage):
169
168
  logger.warning("async_stream 为空,跳过发送。")
170
169
  return
171
170
  # 流式结果直接交付平台适配器处理
172
- use_fallback = self.config.get("provider_settings", {}).get(
173
- "streaming_segmented",
174
- False,
171
+ realtime_segmenting = (
172
+ self.config.get("provider_settings", {}).get(
173
+ "unsupported_streaming_strategy",
174
+ "realtime_segmenting",
175
+ )
176
+ == "realtime_segmenting"
175
177
  )
176
178
  logger.info(f"应用流式输出({event.get_platform_id()})")
177
- await event.send_streaming(result.async_stream, use_fallback)
179
+ await event.send_streaming(result.async_stream, realtime_segmenting)
178
180
  return
179
181
  if len(result.chain) > 0:
180
182
  # 检查路径映射
@@ -218,21 +220,20 @@ class RespondStage(Stage):
218
220
  f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}",
219
221
  )
220
222
  return
221
- async with session_lock_manager.acquire_lock(event.unified_msg_origin):
222
- for comp in result.chain:
223
- i = await self._calc_comp_interval(comp)
224
- await asyncio.sleep(i)
225
- try:
226
- if comp.type in need_separately:
227
- await event.send(MessageChain([comp]))
228
- else:
229
- await event.send(MessageChain([*header_comps, comp]))
230
- header_comps.clear()
231
- except Exception as e:
232
- logger.error(
233
- f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
234
- exc_info=True,
235
- )
223
+ for comp in result.chain:
224
+ i = await self._calc_comp_interval(comp)
225
+ await asyncio.sleep(i)
226
+ try:
227
+ if comp.type in need_separately:
228
+ await event.send(MessageChain([comp]))
229
+ else:
230
+ await event.send(MessageChain([*header_comps, comp]))
231
+ header_comps.clear()
232
+ except Exception as e:
233
+ logger.error(
234
+ f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
235
+ exc_info=True,
236
+ )
236
237
  else:
237
238
  if all(
238
239
  comp.type in {ComponentType.Reply, ComponentType.At}
@@ -16,3 +16,6 @@ class PlatformMetadata:
16
16
  """显示在 WebUI 配置页中的平台名称,如空则是 name"""
17
17
  logo_path: str | None = None
18
18
  """平台适配器的 logo 文件路径(相对于插件目录)"""
19
+
20
+ support_streaming_message: bool = True
21
+ """平台是否支持真实流式传输"""
@@ -14,6 +14,7 @@ def register_platform_adapter(
14
14
  default_config_tmpl: dict | None = None,
15
15
  adapter_display_name: str | None = None,
16
16
  logo_path: str | None = None,
17
+ support_streaming_message: bool = True,
17
18
  ):
18
19
  """用于注册平台适配器的带参装饰器。
19
20
 
@@ -42,6 +43,7 @@ def register_platform_adapter(
42
43
  default_config_tmpl=default_config_tmpl,
43
44
  adapter_display_name=adapter_display_name,
44
45
  logo_path=logo_path,
46
+ support_streaming_message=support_streaming_message,
45
47
  )
46
48
  platform_registry.append(pm)
47
49
  platform_cls_map[adapter_name] = cls
@@ -29,6 +29,7 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
29
29
  @register_platform_adapter(
30
30
  "aiocqhttp",
31
31
  "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。",
32
+ support_streaming_message=False,
32
33
  )
33
34
  class AiocqhttpAdapter(Platform):
34
35
  def __init__(
@@ -49,6 +50,7 @@ class AiocqhttpAdapter(Platform):
49
50
  name="aiocqhttp",
50
51
  description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
51
52
  id=self.config.get("id"),
53
+ support_streaming_message=False,
52
54
  )
53
55
 
54
56
  self.bot = CQHttp(
@@ -37,7 +37,9 @@ class MyEventHandler(dingtalk_stream.EventHandler):
37
37
  return AckMessage.STATUS_OK, "OK"
38
38
 
39
39
 
40
- @register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
40
+ @register_platform_adapter(
41
+ "dingtalk", "钉钉机器人官方 API 适配器", support_streaming_message=False
42
+ )
41
43
  class DingtalkPlatformAdapter(Platform):
42
44
  def __init__(
43
45
  self,
@@ -74,6 +76,14 @@ class DingtalkPlatformAdapter(Platform):
74
76
  )
75
77
  self.client_ = client # 用于 websockets 的 client
76
78
 
79
+ def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
80
+ if not dingtalk_id:
81
+ return dingtalk_id
82
+ prefix = "$:LWCP_v1:$"
83
+ if dingtalk_id.startswith(prefix):
84
+ return dingtalk_id[len(prefix) :]
85
+ return dingtalk_id
86
+
77
87
  async def send_by_session(
78
88
  self,
79
89
  session: MessageSesion,
@@ -86,6 +96,7 @@ class DingtalkPlatformAdapter(Platform):
86
96
  name="dingtalk",
87
97
  description="钉钉机器人官方 API 适配器",
88
98
  id=self.config.get("id"),
99
+ support_streaming_message=False,
89
100
  )
90
101
 
91
102
  async def convert_msg(
@@ -102,10 +113,10 @@ class DingtalkPlatformAdapter(Platform):
102
113
  else MessageType.FRIEND_MESSAGE
103
114
  )
104
115
  abm.sender = MessageMember(
105
- user_id=message.sender_id,
116
+ user_id=self._id_to_sid(message.sender_id),
106
117
  nickname=message.sender_nick,
107
118
  )
108
- abm.self_id = message.chatbot_user_id
119
+ abm.self_id = self._id_to_sid(message.chatbot_user_id)
109
120
  abm.message_id = message.message_id
110
121
  abm.raw_message = message
111
122
 
@@ -113,8 +124,8 @@ class DingtalkPlatformAdapter(Platform):
113
124
  # 处理所有被 @ 的用户(包括机器人自己,因 at_users 已包含)
114
125
  if message.at_users:
115
126
  for user in message.at_users:
116
- if user.dingtalk_id:
117
- abm.message.append(At(qq=user.dingtalk_id))
127
+ if id := self._id_to_sid(user.dingtalk_id):
128
+ abm.message.append(At(qq=id))
118
129
  abm.group_id = message.conversation_id
119
130
  if self.unique_session:
120
131
  abm.session_id = abm.sender.user_id
@@ -34,7 +34,9 @@ else:
34
34
 
35
35
 
36
36
  # 注册平台适配器
37
- @register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)")
37
+ @register_platform_adapter(
38
+ "discord", "Discord 适配器 (基于 Pycord)", support_streaming_message=False
39
+ )
38
40
  class DiscordPlatformAdapter(Platform):
39
41
  def __init__(
40
42
  self,
@@ -111,6 +113,7 @@ class DiscordPlatformAdapter(Platform):
111
113
  "Discord 适配器",
112
114
  id=self.config.get("id"),
113
115
  default_config_tmpl=self.config,
116
+ support_streaming_message=False,
114
117
  )
115
118
 
116
119
  @override
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
  import base64
3
3
  import binascii
4
- import sys
4
+ from collections.abc import AsyncGenerator
5
5
  from io import BytesIO
6
6
  from pathlib import Path
7
7
 
@@ -21,11 +21,6 @@ from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata
21
21
  from .client import DiscordBotClient
22
22
  from .components import DiscordEmbed, DiscordView
23
23
 
24
- if sys.version_info >= (3, 12):
25
- from typing import override
26
- else:
27
- from typing_extensions import override
28
-
29
24
 
30
25
  # 自定义Discord视图组件(兼容旧版本)
31
26
  class DiscordViewComponent(BaseMessageComponent):
@@ -49,7 +44,6 @@ class DiscordPlatformEvent(AstrMessageEvent):
49
44
  self.client = client
50
45
  self.interaction_followup_webhook = interaction_followup_webhook
51
46
 
52
- @override
53
47
  async def send(self, message: MessageChain):
54
48
  """发送消息到Discord平台"""
55
49
  # 解析消息链为 Discord 所需的对象
@@ -98,6 +92,21 @@ class DiscordPlatformEvent(AstrMessageEvent):
98
92
 
99
93
  await super().send(message)
100
94
 
95
+ async def send_streaming(
96
+ self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
97
+ ):
98
+ buffer = None
99
+ async for chain in generator:
100
+ if not buffer:
101
+ buffer = chain
102
+ else:
103
+ buffer.chain.extend(chain.chain)
104
+ if not buffer:
105
+ return None
106
+ buffer.squash_plain()
107
+ await self.send(buffer)
108
+ return await super().send_streaming(generator, use_fallback)
109
+
101
110
  async def _get_channel(self) -> discord.abc.Messageable | None:
102
111
  """获取当前事件对应的频道对象"""
103
112
  try:
@@ -23,7 +23,9 @@ from ...register import register_platform_adapter
23
23
  from .lark_event import LarkMessageEvent
24
24
 
25
25
 
26
- @register_platform_adapter("lark", "飞书机器人官方 API 适配器")
26
+ @register_platform_adapter(
27
+ "lark", "飞书机器人官方 API 适配器", support_streaming_message=False
28
+ )
27
29
  class LarkPlatformAdapter(Platform):
28
30
  def __init__(
29
31
  self,
@@ -115,6 +117,7 @@ class LarkPlatformAdapter(Platform):
115
117
  name="lark",
116
118
  description="飞书机器人官方 API 适配器",
117
119
  id=self.config.get("id"),
120
+ support_streaming_message=False,
118
121
  )
119
122
 
120
123
  async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
@@ -45,7 +45,9 @@ MAX_FILE_UPLOAD_COUNT = 16
45
45
  DEFAULT_UPLOAD_CONCURRENCY = 3
46
46
 
47
47
 
48
- @register_platform_adapter("misskey", "Misskey 平台适配器")
48
+ @register_platform_adapter(
49
+ "misskey", "Misskey 平台适配器", support_streaming_message=False
50
+ )
49
51
  class MisskeyPlatformAdapter(Platform):
50
52
  def __init__(
51
53
  self,
@@ -120,6 +122,7 @@ class MisskeyPlatformAdapter(Platform):
120
122
  description="Misskey 平台适配器",
121
123
  id=self.config.get("id", "misskey"),
122
124
  default_config_tmpl=default_config,
125
+ support_streaming_message=False,
123
126
  )
124
127
 
125
128
  async def run(self):
@@ -29,8 +29,7 @@ from astrbot.core.platform.astr_message_event import MessageSession
29
29
 
30
30
 
31
31
  @register_platform_adapter(
32
- "satori",
33
- "Satori 协议适配器",
32
+ "satori", "Satori 协议适配器", support_streaming_message=False
34
33
  )
35
34
  class SatoriPlatformAdapter(Platform):
36
35
  def __init__(
@@ -60,6 +59,7 @@ class SatoriPlatformAdapter(Platform):
60
59
  name="satori",
61
60
  description="Satori 通用协议适配器",
62
61
  id=self.config["id"],
62
+ support_streaming_message=False,
63
63
  )
64
64
 
65
65
  self.ws: ClientConnection | None = None
@@ -30,6 +30,7 @@ from .slack_event import SlackMessageEvent
30
30
  @register_platform_adapter(
31
31
  "slack",
32
32
  "适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
33
+ support_streaming_message=False,
33
34
  )
34
35
  class SlackAdapter(Platform):
35
36
  def __init__(
@@ -68,6 +69,7 @@ class SlackAdapter(Platform):
68
69
  name="slack",
69
70
  description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
70
71
  id=self.config.get("id"),
72
+ support_streaming_message=False,
71
73
  )
72
74
 
73
75
  # 初始化 Slack Web Client
@@ -163,6 +163,9 @@ class WebChatAdapter(Platform):
163
163
  _, _, payload = message.raw_message # type: ignore
164
164
  message_event.set_extra("selected_provider", payload.get("selected_provider"))
165
165
  message_event.set_extra("selected_model", payload.get("selected_model"))
166
+ message_event.set_extra(
167
+ "enable_streaming", payload.get("enable_streaming", True)
168
+ )
166
169
 
167
170
  self.commit_event(message_event)
168
171
 
@@ -109,6 +109,7 @@ class WebChatMessageEvent(AstrMessageEvent):
109
109
 
110
110
  async def send_streaming(self, generator, use_fallback: bool = False):
111
111
  final_data = ""
112
+ reasoning_content = ""
112
113
  cid = self.session_id.split("!")[-1]
113
114
  web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
114
115
  async for chain in generator:
@@ -124,16 +125,22 @@ class WebChatMessageEvent(AstrMessageEvent):
124
125
  )
125
126
  final_data = ""
126
127
  continue
127
- final_data += await WebChatMessageEvent._send(
128
+
129
+ r = await WebChatMessageEvent._send(
128
130
  chain,
129
131
  session_id=self.session_id,
130
132
  streaming=True,
131
133
  )
134
+ if chain.type == "reasoning":
135
+ reasoning_content += chain.get_plain_text()
136
+ else:
137
+ final_data += r
132
138
 
133
139
  await web_chat_back_queue.put(
134
140
  {
135
141
  "type": "complete", # complete means we return the final result
136
142
  "data": final_data,
143
+ "reasoning": reasoning_content,
137
144
  "streaming": True,
138
145
  "cid": cid,
139
146
  },
@@ -32,7 +32,9 @@ except ImportError as e:
32
32
  )
33
33
 
34
34
 
35
- @register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器")
35
+ @register_platform_adapter(
36
+ "wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
37
+ )
36
38
  class WeChatPadProAdapter(Platform):
37
39
  def __init__(
38
40
  self,
@@ -51,6 +53,7 @@ class WeChatPadProAdapter(Platform):
51
53
  name="wechatpadpro",
52
54
  description="WeChatPadPro 消息平台适配器",
53
55
  id=self.config.get("id", "wechatpadpro"),
56
+ support_streaming_message=False,
54
57
  )
55
58
 
56
59
  # 保存配置信息
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  import base64
3
3
  import io
4
+ from collections.abc import AsyncGenerator
4
5
  from typing import TYPE_CHECKING
5
6
 
6
7
  import aiohttp
@@ -50,6 +51,21 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
50
51
  await self._send_voice(session, comp)
51
52
  await super().send(message)
52
53
 
54
+ async def send_streaming(
55
+ self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
56
+ ):
57
+ buffer = None
58
+ async for chain in generator:
59
+ if not buffer:
60
+ buffer = chain
61
+ else:
62
+ buffer.chain.extend(chain.chain)
63
+ if not buffer:
64
+ return None
65
+ buffer.squash_plain()
66
+ await self.send(buffer)
67
+ return await super().send_streaming(generator, use_fallback)
68
+
53
69
  async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
54
70
  b64 = await comp.convert_to_base64()
55
71
  raw = self._validate_base64(b64)
@@ -110,7 +110,7 @@ class WecomServer:
110
110
  await self.shutdown_event.wait()
111
111
 
112
112
 
113
- @register_platform_adapter("wecom", "wecom 适配器")
113
+ @register_platform_adapter("wecom", "wecom 适配器", support_streaming_message=False)
114
114
  class WecomPlatformAdapter(Platform):
115
115
  def __init__(
116
116
  self,
@@ -196,6 +196,7 @@ class WecomPlatformAdapter(Platform):
196
196
  "wecom",
197
197
  "wecom 适配器",
198
198
  id=self.config.get("id", "wecom"),
199
+ support_streaming_message=False,
199
200
  )
200
201
 
201
202
  @override
@@ -113,7 +113,9 @@ class WecomServer:
113
113
  await self.shutdown_event.wait()
114
114
 
115
115
 
116
- @register_platform_adapter("weixin_official_account", "微信公众平台 适配器")
116
+ @register_platform_adapter(
117
+ "weixin_official_account", "微信公众平台 适配器", support_streaming_message=False
118
+ )
117
119
  class WeixinOfficialAccountPlatformAdapter(Platform):
118
120
  def __init__(
119
121
  self,
@@ -195,6 +197,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
195
197
  "weixin_official_account",
196
198
  "微信公众平台 适配器",
197
199
  id=self.config.get("id", "weixin_official_account"),
200
+ support_streaming_message=False,
198
201
  )
199
202
 
200
203
  @override
@@ -1,4 +1,4 @@
1
1
  from .entities import ProviderMetaData
2
- from .provider import Personality, Provider, STTProvider
2
+ from .provider import Provider, STTProvider
3
3
 
4
- __all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"]
4
+ __all__ = ["Provider", "ProviderMetaData", "STTProvider"]
@@ -30,18 +30,31 @@ class ProviderType(enum.Enum):
30
30
 
31
31
 
32
32
  @dataclass
33
- class ProviderMetaData:
33
+ class ProviderMeta:
34
+ """The basic metadata of a provider instance."""
35
+
36
+ id: str
37
+ """the unique id of the provider instance that user configured"""
38
+ model: str | None
39
+ """the model name of the provider instance currently used"""
34
40
  type: str
35
- """提供商适配器名称,如 openai, ollama"""
36
- desc: str = ""
37
- """提供商适配器描述"""
41
+ """the name of the provider adapter, such as openai, ollama"""
38
42
  provider_type: ProviderType = ProviderType.CHAT_COMPLETION
39
- cls_type: Any = None
43
+ """the capability type of the provider adapter"""
44
+
45
+
46
+ @dataclass
47
+ class ProviderMetaData(ProviderMeta):
48
+ """The metadata of a provider adapter for registration."""
40
49
 
50
+ desc: str = ""
51
+ """the short description of the provider adapter"""
52
+ cls_type: Any = None
53
+ """the class type of the provider adapter"""
41
54
  default_config_tmpl: dict | None = None
42
- """平台的默认配置模板"""
55
+ """the default configuration template of the provider adapter"""
43
56
  provider_display_name: str | None = None
44
- """显示在 WebUI 配置页中的提供商名称,如空则是 type"""
57
+ """the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""
45
58
 
46
59
 
47
60
  @dataclass
@@ -60,12 +73,20 @@ class ToolCallsResult:
60
73
  ]
61
74
  return ret
62
75
 
76
+ def to_openai_messages_model(
77
+ self,
78
+ ) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
79
+ return [
80
+ self.tool_calls_info,
81
+ *self.tool_calls_result,
82
+ ]
83
+
63
84
 
64
85
  @dataclass
65
86
  class ProviderRequest:
66
- prompt: str
87
+ prompt: str | None = None
67
88
  """提示词"""
68
- session_id: str = ""
89
+ session_id: str | None = ""
69
90
  """会话 ID"""
70
91
  image_urls: list[str] = field(default_factory=list)
71
92
  """图片 URL 列表"""
@@ -181,25 +202,28 @@ class ProviderRequest:
181
202
  @dataclass
182
203
  class LLMResponse:
183
204
  role: str
184
- """角色, assistant, tool, err"""
205
+ """The role of the message, e.g., assistant, tool, err"""
185
206
  result_chain: MessageChain | None = None
186
- """返回的消息链"""
207
+ """A chain of message components representing the text completion from LLM."""
187
208
  tools_call_args: list[dict[str, Any]] = field(default_factory=list)
188
- """工具调用参数"""
209
+ """Tool call arguments."""
189
210
  tools_call_name: list[str] = field(default_factory=list)
190
- """工具调用名称"""
211
+ """Tool call names."""
191
212
  tools_call_ids: list[str] = field(default_factory=list)
192
- """工具调用 ID"""
213
+ """Tool call IDs."""
214
+ reasoning_content: str = ""
215
+ """The reasoning content extracted from the LLM, if any."""
193
216
 
194
217
  raw_completion: (
195
218
  ChatCompletion | GenerateContentResponse | AnthropicMessage | None
196
219
  ) = None
197
- _new_record: dict[str, Any] | None = None
220
+ """The raw completion response from the LLM provider."""
198
221
 
199
222
  _completion_text: str = ""
223
+ """The plain text of the completion."""
200
224
 
201
225
  is_chunk: bool = False
202
- """是否是流式输出的单个 Chunk"""
226
+ """Indicates if the response is a chunked response."""
203
227
 
204
228
  def __init__(
205
229
  self,
@@ -213,7 +237,6 @@ class LLMResponse:
213
237
  | GenerateContentResponse
214
238
  | AnthropicMessage
215
239
  | None = None,
216
- _new_record: dict[str, Any] | None = None,
217
240
  is_chunk: bool = False,
218
241
  ):
219
242
  """初始化 LLMResponse
@@ -241,7 +264,6 @@ class LLMResponse:
241
264
  self.tools_call_name = tools_call_name
242
265
  self.tools_call_ids = tools_call_ids
243
266
  self.raw_completion = raw_completion
244
- self._new_record = _new_record
245
267
  self.is_chunk = is_chunk
246
268
 
247
269
  @property
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import copy
4
5
  import json
5
6
  import os
6
7
  from collections.abc import Awaitable, Callable
@@ -24,7 +25,16 @@ SUPPORTED_TYPES = [
24
25
  "boolean",
25
26
  ] # json schema 支持的数据类型
26
27
 
27
-
28
+ PY_TO_JSON_TYPE = {
29
+ "int": "number",
30
+ "float": "number",
31
+ "bool": "boolean",
32
+ "str": "string",
33
+ "dict": "object",
34
+ "list": "array",
35
+ "tuple": "array",
36
+ "set": "array",
37
+ }
28
38
  # alias
29
39
  FuncTool = FunctionTool
30
40
 
@@ -106,7 +116,7 @@ class FunctionToolManager:
106
116
  def spec_to_func(
107
117
  self,
108
118
  name: str,
109
- func_args: list,
119
+ func_args: list[dict],
110
120
  desc: str,
111
121
  handler: Callable[..., Awaitable[Any]],
112
122
  ) -> FuncTool:
@@ -115,10 +125,9 @@ class FunctionToolManager:
115
125
  "properties": {},
116
126
  }
117
127
  for param in func_args:
118
- params["properties"][param["name"]] = {
119
- "type": param["type"],
120
- "description": param["description"],
121
- }
128
+ p = copy.deepcopy(param)
129
+ p.pop("name", None)
130
+ params["properties"][param["name"]] = p
122
131
  return FuncTool(
123
132
  name=name,
124
133
  parameters=params,
@@ -241,6 +241,8 @@ class ProviderManager:
241
241
  )
242
242
  case "zhipu_chat_completion":
243
243
  from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
244
+ case "groq_chat_completion":
245
+ from .sources.groq_source import ProviderGroq as ProviderGroq
244
246
  case "anthropic_chat_completion":
245
247
  from .sources.anthropic_source import (
246
248
  ProviderAnthropic as ProviderAnthropic,
@@ -354,6 +356,8 @@ class ProviderManager:
354
356
  logger.error(f"无法找到 {provider_metadata.type} 的类")
355
357
  return
356
358
 
359
+ provider_metadata.id = provider_config["id"]
360
+
357
361
  if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
358
362
  # STT 任务
359
363
  inst = cls_type(provider_config, self.provider_settings)
@@ -394,7 +398,6 @@ class ProviderManager:
394
398
  inst = cls_type(
395
399
  provider_config,
396
400
  self.provider_settings,
397
- self.selected_default_persona,
398
401
  )
399
402
 
400
403
  if getattr(inst, "initialize", None):