AstrBot 4.5.3__py3-none-any.whl → 4.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. astrbot/api/all.py +2 -1
  2. astrbot/api/provider/__init__.py +2 -1
  3. astrbot/core/agent/run_context.py +7 -2
  4. astrbot/core/agent/runners/base.py +7 -0
  5. astrbot/core/agent/runners/tool_loop_agent_runner.py +51 -3
  6. astrbot/core/agent/tool.py +5 -6
  7. astrbot/core/astr_agent_context.py +13 -8
  8. astrbot/core/astr_agent_hooks.py +36 -0
  9. astrbot/core/astr_agent_run_util.py +80 -0
  10. astrbot/core/astr_agent_tool_exec.py +246 -0
  11. astrbot/core/config/default.py +53 -7
  12. astrbot/core/exceptions.py +9 -0
  13. astrbot/core/pipeline/context.py +1 -2
  14. astrbot/core/pipeline/context_utils.py +0 -65
  15. astrbot/core/pipeline/process_stage/method/llm_request.py +239 -491
  16. astrbot/core/pipeline/respond/stage.py +21 -20
  17. astrbot/core/platform/platform_metadata.py +3 -0
  18. astrbot/core/platform/register.py +2 -0
  19. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -0
  20. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +16 -5
  21. astrbot/core/platform/sources/discord/discord_platform_adapter.py +4 -1
  22. astrbot/core/platform/sources/discord/discord_platform_event.py +16 -7
  23. astrbot/core/platform/sources/lark/lark_adapter.py +4 -1
  24. astrbot/core/platform/sources/misskey/misskey_adapter.py +4 -1
  25. astrbot/core/platform/sources/satori/satori_adapter.py +2 -2
  26. astrbot/core/platform/sources/slack/slack_adapter.py +2 -0
  27. astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -0
  28. astrbot/core/platform/sources/webchat/webchat_event.py +8 -1
  29. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +4 -1
  30. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +16 -0
  31. astrbot/core/platform/sources/wecom/wecom_adapter.py +2 -1
  32. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +4 -1
  33. astrbot/core/provider/__init__.py +2 -2
  34. astrbot/core/provider/entities.py +40 -18
  35. astrbot/core/provider/func_tool_manager.py +15 -6
  36. astrbot/core/provider/manager.py +4 -1
  37. astrbot/core/provider/provider.py +7 -22
  38. astrbot/core/provider/register.py +2 -0
  39. astrbot/core/provider/sources/anthropic_source.py +0 -2
  40. astrbot/core/provider/sources/coze_source.py +0 -2
  41. astrbot/core/provider/sources/dashscope_source.py +1 -3
  42. astrbot/core/provider/sources/dify_source.py +0 -2
  43. astrbot/core/provider/sources/gemini_source.py +31 -3
  44. astrbot/core/provider/sources/groq_source.py +15 -0
  45. astrbot/core/provider/sources/openai_source.py +67 -21
  46. astrbot/core/provider/sources/zhipu_source.py +1 -6
  47. astrbot/core/star/context.py +197 -45
  48. astrbot/core/star/register/star_handler.py +30 -10
  49. astrbot/dashboard/routes/chat.py +5 -0
  50. {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/METADATA +55 -65
  51. {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/RECORD +54 -49
  52. {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/WHEEL +0 -0
  53. {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/entry_points.txt +0 -0
  54. {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/licenses/LICENSE +0 -0
@@ -1,28 +1,18 @@
1
1
  import abc
2
2
  import asyncio
3
3
  from collections.abc import AsyncGenerator
4
- from dataclasses import dataclass
5
4
 
6
5
  from astrbot.core.agent.message import Message
7
6
  from astrbot.core.agent.tool import ToolSet
8
- from astrbot.core.db.po import Personality
9
7
  from astrbot.core.provider.entities import (
10
8
  LLMResponse,
11
- ProviderType,
9
+ ProviderMeta,
12
10
  RerankResult,
13
11
  ToolCallsResult,
14
12
  )
15
13
  from astrbot.core.provider.register import provider_cls_map
16
14
 
17
15
 
18
- @dataclass
19
- class ProviderMeta:
20
- id: str
21
- model: str
22
- type: str
23
- provider_type: ProviderType
24
-
25
-
26
16
  class AbstractProvider(abc.ABC):
27
17
  """Provider Abstract Class"""
28
18
 
@@ -43,15 +33,15 @@ class AbstractProvider(abc.ABC):
43
33
  """Get the provider metadata"""
44
34
  provider_type_name = self.provider_config["type"]
45
35
  meta_data = provider_cls_map.get(provider_type_name)
46
- provider_type = meta_data.provider_type if meta_data else None
47
- if provider_type is None:
48
- raise ValueError(f"Cannot find provider type: {provider_type_name}")
49
- return ProviderMeta(
50
- id=self.provider_config["id"],
36
+ if not meta_data:
37
+ raise ValueError(f"Provider type {provider_type_name} not registered")
38
+ meta = ProviderMeta(
39
+ id=self.provider_config.get("id", "default"),
51
40
  model=self.get_model(),
52
41
  type=provider_type_name,
53
- provider_type=provider_type,
42
+ provider_type=meta_data.provider_type,
54
43
  )
44
+ return meta
55
45
 
56
46
 
57
47
  class Provider(AbstractProvider):
@@ -61,15 +51,10 @@ class Provider(AbstractProvider):
61
51
  self,
62
52
  provider_config: dict,
63
53
  provider_settings: dict,
64
- default_persona: Personality | None = None,
65
54
  ) -> None:
66
55
  super().__init__(provider_config)
67
-
68
56
  self.provider_settings = provider_settings
69
57
 
70
- self.curr_personality = default_persona
71
- """维护了当前的使用的 persona,即人格。可能为 None"""
72
-
73
58
  @abc.abstractmethod
74
59
  def get_current_key(self) -> str:
75
60
  raise NotImplementedError
@@ -36,6 +36,8 @@ def register_provider_adapter(
36
36
  default_config_tmpl["id"] = provider_type_name
37
37
 
38
38
  pm = ProviderMetaData(
39
+ id="default", # will be replaced when instantiated
40
+ model=None,
39
41
  type=provider_type_name,
40
42
  desc=desc,
41
43
  provider_type=provider_type,
@@ -25,12 +25,10 @@ class ProviderAnthropic(Provider):
25
25
  self,
26
26
  provider_config,
27
27
  provider_settings,
28
- default_persona=None,
29
28
  ) -> None:
30
29
  super().__init__(
31
30
  provider_config,
32
31
  provider_settings,
33
- default_persona,
34
32
  )
35
33
 
36
34
  self.chosen_api_key: str = ""
@@ -20,12 +20,10 @@ class ProviderCoze(Provider):
20
20
  self,
21
21
  provider_config,
22
22
  provider_settings,
23
- default_persona=None,
24
23
  ) -> None:
25
24
  super().__init__(
26
25
  provider_config,
27
26
  provider_settings,
28
- default_persona,
29
27
  )
30
28
  self.api_key = provider_config.get("coze_api_key", "")
31
29
  if not self.api_key:
@@ -8,7 +8,7 @@ from dashscope.app.application_response import ApplicationResponse
8
8
  from astrbot.core import logger, sp
9
9
  from astrbot.core.message.message_event_result import MessageChain
10
10
 
11
- from .. import Personality, Provider
11
+ from .. import Provider
12
12
  from ..entities import LLMResponse
13
13
  from ..register import register_provider_adapter
14
14
  from .openai_source import ProviderOpenAIOfficial
@@ -20,13 +20,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
20
20
  self,
21
21
  provider_config: dict,
22
22
  provider_settings: dict,
23
- default_persona: Personality | None = None,
24
23
  ) -> None:
25
24
  Provider.__init__(
26
25
  self,
27
26
  provider_config,
28
27
  provider_settings,
29
- default_persona,
30
28
  )
31
29
  self.api_key = provider_config.get("dashscope_api_key", "")
32
30
  if not self.api_key:
@@ -18,12 +18,10 @@ class ProviderDify(Provider):
18
18
  self,
19
19
  provider_config,
20
20
  provider_settings,
21
- default_persona=None,
22
21
  ) -> None:
23
22
  super().__init__(
24
23
  provider_config,
25
24
  provider_settings,
26
- default_persona,
27
25
  )
28
26
  self.api_key = provider_config.get("dify_api_key", "")
29
27
  if not self.api_key:
@@ -53,12 +53,10 @@ class ProviderGoogleGenAI(Provider):
53
53
  self,
54
54
  provider_config,
55
55
  provider_settings,
56
- default_persona=None,
57
56
  ) -> None:
58
57
  super().__init__(
59
58
  provider_config,
60
59
  provider_settings,
61
- default_persona,
62
60
  )
63
61
  self.api_keys: list = super().get_keys()
64
62
  self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
@@ -326,8 +324,18 @@ class ProviderGoogleGenAI(Provider):
326
324
 
327
325
  return gemini_contents
328
326
 
329
- @staticmethod
327
+ def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
328
+ """Extract reasoning content from candidate parts"""
329
+ if not candidate.content or not candidate.content.parts:
330
+ return ""
331
+
332
+ thought_buf: list[str] = [
333
+ (p.text or "") for p in candidate.content.parts if p.thought
334
+ ]
335
+ return "".join(thought_buf).strip()
336
+
330
337
  def _process_content_parts(
338
+ self,
331
339
  candidate: types.Candidate,
332
340
  llm_response: LLMResponse,
333
341
  ) -> MessageChain:
@@ -358,6 +366,11 @@ class ProviderGoogleGenAI(Provider):
358
366
  logger.warning(f"收到的 candidate.content.parts 为空: {candidate}")
359
367
  raise Exception("API 返回的 candidate.content.parts 为空。")
360
368
 
369
+ # 提取 reasoning content
370
+ reasoning = self._extract_reasoning_content(candidate)
371
+ if reasoning:
372
+ llm_response.reasoning_content = reasoning
373
+
361
374
  chain = []
362
375
  part: types.Part
363
376
 
@@ -515,6 +528,7 @@ class ProviderGoogleGenAI(Provider):
515
528
 
516
529
  # Accumulate the complete response text for the final response
517
530
  accumulated_text = ""
531
+ accumulated_reasoning = ""
518
532
  final_response = None
519
533
 
520
534
  async for chunk in result:
@@ -539,9 +553,19 @@ class ProviderGoogleGenAI(Provider):
539
553
  yield llm_response
540
554
  return
541
555
 
556
+ _f = False
557
+
558
+ # 提取 reasoning content
559
+ reasoning = self._extract_reasoning_content(chunk.candidates[0])
560
+ if reasoning:
561
+ _f = True
562
+ accumulated_reasoning += reasoning
563
+ llm_response.reasoning_content = reasoning
542
564
  if chunk.text:
565
+ _f = True
543
566
  accumulated_text += chunk.text
544
567
  llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
568
+ if _f:
545
569
  yield llm_response
546
570
 
547
571
  if chunk.candidates[0].finish_reason:
@@ -559,6 +583,10 @@ class ProviderGoogleGenAI(Provider):
559
583
  if not final_response:
560
584
  final_response = LLMResponse("assistant", is_chunk=False)
561
585
 
586
+ # Set the complete accumulated reasoning in the final response
587
+ if accumulated_reasoning:
588
+ final_response.reasoning_content = accumulated_reasoning
589
+
562
590
  # Set the complete accumulated text in the final response
563
591
  if accumulated_text:
564
592
  final_response.result_chain = MessageChain(
@@ -0,0 +1,15 @@
1
+ from ..register import register_provider_adapter
2
+ from .openai_source import ProviderOpenAIOfficial
3
+
4
+
5
+ @register_provider_adapter(
6
+ "groq_chat_completion", "Groq Chat Completion Provider Adapter"
7
+ )
8
+ class ProviderGroq(ProviderOpenAIOfficial):
9
+ def __init__(
10
+ self,
11
+ provider_config: dict,
12
+ provider_settings: dict,
13
+ ) -> None:
14
+ super().__init__(provider_config, provider_settings)
15
+ self.reasoning_key = "reasoning"
@@ -4,12 +4,14 @@ import inspect
4
4
  import json
5
5
  import os
6
6
  import random
7
+ import re
7
8
  from collections.abc import AsyncGenerator
8
9
 
9
10
  from openai import AsyncAzureOpenAI, AsyncOpenAI
10
11
  from openai._exceptions import NotFoundError, UnprocessableEntityError
11
12
  from openai.lib.streaming.chat._completions import ChatCompletionStreamState
12
13
  from openai.types.chat.chat_completion import ChatCompletion
14
+ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
13
15
 
14
16
  import astrbot.core.message.components as Comp
15
17
  from astrbot import logger
@@ -28,37 +30,37 @@ from ..register import register_provider_adapter
28
30
  "OpenAI API Chat Completion 提供商适配器",
29
31
  )
30
32
  class ProviderOpenAIOfficial(Provider):
31
- def __init__(
32
- self,
33
- provider_config,
34
- provider_settings,
35
- default_persona=None,
36
- ) -> None:
37
- super().__init__(
38
- provider_config,
39
- provider_settings,
40
- default_persona,
41
- )
33
+ def __init__(self, provider_config, provider_settings) -> None:
34
+ super().__init__(provider_config, provider_settings)
42
35
  self.chosen_api_key = None
43
36
  self.api_keys: list = super().get_keys()
44
37
  self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
45
38
  self.timeout = provider_config.get("timeout", 120)
39
+ self.custom_headers = provider_config.get("custom_headers", {})
46
40
  if isinstance(self.timeout, str):
47
41
  self.timeout = int(self.timeout)
48
- # 适配 azure openai #332
42
+
43
+ if not isinstance(self.custom_headers, dict) or not self.custom_headers:
44
+ self.custom_headers = None
45
+ else:
46
+ for key in self.custom_headers:
47
+ self.custom_headers[key] = str(self.custom_headers[key])
48
+
49
49
  if "api_version" in provider_config:
50
- # 使用 azure api
50
+ # Using Azure OpenAI API
51
51
  self.client = AsyncAzureOpenAI(
52
52
  api_key=self.chosen_api_key,
53
53
  api_version=provider_config.get("api_version", None),
54
+ default_headers=self.custom_headers,
54
55
  base_url=provider_config.get("api_base", ""),
55
56
  timeout=self.timeout,
56
57
  )
57
58
  else:
58
- # 使用 openai api
59
+ # Using OpenAI Official API
59
60
  self.client = AsyncOpenAI(
60
61
  api_key=self.chosen_api_key,
61
62
  base_url=provider_config.get("api_base", None),
63
+ default_headers=self.custom_headers,
62
64
  timeout=self.timeout,
63
65
  )
64
66
 
@@ -70,6 +72,8 @@ class ProviderOpenAIOfficial(Provider):
70
72
  model = model_config.get("model", "unknown")
71
73
  self.set_model(model)
72
74
 
75
+ self.reasoning_key = "reasoning_content"
76
+
73
77
  def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
74
78
  """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
75
79
 
@@ -147,7 +151,7 @@ class ProviderOpenAIOfficial(Provider):
147
151
 
148
152
  logger.debug(f"completion: {completion}")
149
153
 
150
- llm_response = await self.parse_openai_completion(completion, tools)
154
+ llm_response = await self._parse_openai_completion(completion, tools)
151
155
 
152
156
  return llm_response
153
157
 
@@ -200,36 +204,78 @@ class ProviderOpenAIOfficial(Provider):
200
204
  if len(chunk.choices) == 0:
201
205
  continue
202
206
  delta = chunk.choices[0].delta
203
- # 处理文本内容
207
+ # logger.debug(f"chunk delta: {delta}")
208
+ # handle the content delta
209
+ reasoning = self._extract_reasoning_content(chunk)
210
+ _y = False
211
+ if reasoning:
212
+ llm_response.reasoning_content = reasoning
213
+ _y = True
204
214
  if delta.content:
205
215
  completion_text = delta.content
206
216
  llm_response.result_chain = MessageChain(
207
217
  chain=[Comp.Plain(completion_text)],
208
218
  )
219
+ _y = True
220
+ if _y:
209
221
  yield llm_response
210
222
 
211
223
  final_completion = state.get_final_completion()
212
- llm_response = await self.parse_openai_completion(final_completion, tools)
224
+ llm_response = await self._parse_openai_completion(final_completion, tools)
213
225
 
214
226
  yield llm_response
215
227
 
216
- async def parse_openai_completion(
228
+ def _extract_reasoning_content(
229
+ self,
230
+ completion: ChatCompletion | ChatCompletionChunk,
231
+ ) -> str:
232
+ """Extract reasoning content from OpenAI ChatCompletion if available."""
233
+ reasoning_text = ""
234
+ if len(completion.choices) == 0:
235
+ return reasoning_text
236
+ if isinstance(completion, ChatCompletion):
237
+ choice = completion.choices[0]
238
+ reasoning_attr = getattr(choice.message, self.reasoning_key, None)
239
+ if reasoning_attr:
240
+ reasoning_text = str(reasoning_attr)
241
+ elif isinstance(completion, ChatCompletionChunk):
242
+ delta = completion.choices[0].delta
243
+ reasoning_attr = getattr(delta, self.reasoning_key, None)
244
+ if reasoning_attr:
245
+ reasoning_text = str(reasoning_attr)
246
+ return reasoning_text
247
+
248
+ async def _parse_openai_completion(
217
249
  self, completion: ChatCompletion, tools: ToolSet | None
218
250
  ) -> LLMResponse:
219
- """解析 OpenAI ChatCompletion 响应"""
251
+ """Parse OpenAI ChatCompletion into LLMResponse"""
220
252
  llm_response = LLMResponse("assistant")
221
253
 
222
254
  if len(completion.choices) == 0:
223
255
  raise Exception("API 返回的 completion 为空。")
224
256
  choice = completion.choices[0]
225
257
 
258
+ # parse the text completion
226
259
  if choice.message.content is not None:
227
260
  # text completion
228
261
  completion_text = str(choice.message.content).strip()
262
+ # specially, some providers may set <think> tags around reasoning content in the completion text,
263
+ # we use regex to remove them, and store then in reasoning_content field
264
+ reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
265
+ matches = reasoning_pattern.findall(completion_text)
266
+ if matches:
267
+ llm_response.reasoning_content = "\n".join(
268
+ [match.strip() for match in matches],
269
+ )
270
+ completion_text = reasoning_pattern.sub("", completion_text).strip()
229
271
  llm_response.result_chain = MessageChain().message(completion_text)
230
272
 
273
+ # parse the reasoning content if any
274
+ # the priority is higher than the <think> tag extraction
275
+ llm_response.reasoning_content = self._extract_reasoning_content(completion)
276
+
277
+ # parse tool calls if any
231
278
  if choice.message.tool_calls and tools is not None:
232
- # tools call (function calling)
233
279
  args_ls = []
234
280
  func_name_ls = []
235
281
  tool_call_ids = []
@@ -255,11 +301,11 @@ class ProviderOpenAIOfficial(Provider):
255
301
  llm_response.tools_call_name = func_name_ls
256
302
  llm_response.tools_call_ids = tool_call_ids
257
303
 
304
+ # specially handle finish reason
258
305
  if choice.finish_reason == "content_filter":
259
306
  raise Exception(
260
307
  "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
261
308
  )
262
-
263
309
  if llm_response.completion_text is None and not llm_response.tools_call_args:
264
310
  logger.error(f"API 返回的 completion 无法解析:{completion}。")
265
311
  raise Exception(f"API 返回的 completion 无法解析:{completion}。")
@@ -12,10 +12,5 @@ class ProviderZhipu(ProviderOpenAIOfficial):
12
12
  self,
13
13
  provider_config: dict,
14
14
  provider_settings: dict,
15
- default_persona=None,
16
15
  ) -> None:
17
- super().__init__(
18
- provider_config,
19
- provider_settings,
20
- default_persona,
21
- )
16
+ super().__init__(provider_config, provider_settings)