AstrBot 4.5.0-py3-none-any.whl → 4.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244):
  1. astrbot/api/__init__.py +10 -11
  2. astrbot/api/event/__init__.py +5 -6
  3. astrbot/api/event/filter/__init__.py +37 -36
  4. astrbot/api/platform/__init__.py +7 -8
  5. astrbot/api/provider/__init__.py +7 -7
  6. astrbot/api/star/__init__.py +3 -4
  7. astrbot/api/util/__init__.py +2 -2
  8. astrbot/cli/__main__.py +5 -5
  9. astrbot/cli/commands/__init__.py +3 -3
  10. astrbot/cli/commands/cmd_conf.py +19 -16
  11. astrbot/cli/commands/cmd_init.py +3 -2
  12. astrbot/cli/commands/cmd_plug.py +8 -10
  13. astrbot/cli/commands/cmd_run.py +5 -6
  14. astrbot/cli/utils/__init__.py +6 -6
  15. astrbot/cli/utils/basic.py +14 -14
  16. astrbot/cli/utils/plugin.py +24 -15
  17. astrbot/cli/utils/version_comparator.py +10 -12
  18. astrbot/core/__init__.py +8 -6
  19. astrbot/core/agent/agent.py +3 -2
  20. astrbot/core/agent/handoff.py +6 -2
  21. astrbot/core/agent/hooks.py +9 -6
  22. astrbot/core/agent/mcp_client.py +50 -15
  23. astrbot/core/agent/message.py +168 -0
  24. astrbot/core/agent/response.py +2 -1
  25. astrbot/core/agent/run_context.py +2 -3
  26. astrbot/core/agent/runners/base.py +10 -13
  27. astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
  28. astrbot/core/agent/tool.py +60 -41
  29. astrbot/core/agent/tool_executor.py +9 -3
  30. astrbot/core/astr_agent_context.py +3 -1
  31. astrbot/core/astrbot_config_mgr.py +29 -9
  32. astrbot/core/config/__init__.py +2 -2
  33. astrbot/core/config/astrbot_config.py +28 -26
  34. astrbot/core/config/default.py +44 -6
  35. astrbot/core/conversation_mgr.py +105 -36
  36. astrbot/core/core_lifecycle.py +68 -54
  37. astrbot/core/db/__init__.py +33 -18
  38. astrbot/core/db/migration/helper.py +18 -13
  39. astrbot/core/db/migration/migra_3_to_4.py +53 -34
  40. astrbot/core/db/migration/migra_45_to_46.py +1 -1
  41. astrbot/core/db/migration/shared_preferences_v3.py +2 -1
  42. astrbot/core/db/migration/sqlite_v3.py +26 -23
  43. astrbot/core/db/po.py +27 -18
  44. astrbot/core/db/sqlite.py +74 -45
  45. astrbot/core/db/vec_db/base.py +10 -14
  46. astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
  47. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
  48. astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
  49. astrbot/core/event_bus.py +8 -6
  50. astrbot/core/file_token_service.py +6 -5
  51. astrbot/core/initial_loader.py +7 -5
  52. astrbot/core/knowledge_base/chunking/__init__.py +1 -3
  53. astrbot/core/knowledge_base/chunking/base.py +1 -0
  54. astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
  55. astrbot/core/knowledge_base/chunking/recursive.py +16 -10
  56. astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
  57. astrbot/core/knowledge_base/kb_helper.py +30 -17
  58. astrbot/core/knowledge_base/kb_mgr.py +6 -7
  59. astrbot/core/knowledge_base/models.py +10 -4
  60. astrbot/core/knowledge_base/parsers/__init__.py +3 -5
  61. astrbot/core/knowledge_base/parsers/base.py +1 -0
  62. astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
  63. astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
  64. astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
  65. astrbot/core/knowledge_base/parsers/util.py +1 -1
  66. astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
  67. astrbot/core/knowledge_base/retrieval/manager.py +17 -14
  68. astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
  69. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
  70. astrbot/core/log.py +21 -13
  71. astrbot/core/message/components.py +123 -217
  72. astrbot/core/message/message_event_result.py +24 -24
  73. astrbot/core/persona_mgr.py +20 -11
  74. astrbot/core/pipeline/__init__.py +7 -7
  75. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  76. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  77. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
  78. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
  79. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  80. astrbot/core/pipeline/context.py +4 -1
  81. astrbot/core/pipeline/context_utils.py +77 -7
  82. astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
  83. astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
  84. astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
  85. astrbot/core/pipeline/process_stage/stage.py +13 -10
  86. astrbot/core/pipeline/process_stage/utils.py +6 -5
  87. astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
  88. astrbot/core/pipeline/respond/stage.py +23 -20
  89. astrbot/core/pipeline/result_decorate/stage.py +31 -23
  90. astrbot/core/pipeline/scheduler.py +12 -8
  91. astrbot/core/pipeline/session_status_check/stage.py +12 -8
  92. astrbot/core/pipeline/stage.py +10 -4
  93. astrbot/core/pipeline/waking_check/stage.py +24 -18
  94. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  95. astrbot/core/platform/__init__.py +6 -6
  96. astrbot/core/platform/astr_message_event.py +76 -110
  97. astrbot/core/platform/astrbot_message.py +11 -13
  98. astrbot/core/platform/manager.py +16 -15
  99. astrbot/core/platform/message_session.py +5 -3
  100. astrbot/core/platform/platform.py +16 -24
  101. astrbot/core/platform/platform_metadata.py +4 -4
  102. astrbot/core/platform/register.py +8 -8
  103. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
  104. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
  105. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +47 -29
  106. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
  107. astrbot/core/platform/sources/discord/client.py +9 -6
  108. astrbot/core/platform/sources/discord/components.py +18 -14
  109. astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
  110. astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
  111. astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
  112. astrbot/core/platform/sources/lark/lark_event.py +21 -14
  113. astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
  114. astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
  115. astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
  116. astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
  117. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
  118. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  119. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  120. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  121. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
  122. astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
  123. astrbot/core/platform/sources/satori/satori_event.py +34 -25
  124. astrbot/core/platform/sources/slack/client.py +11 -9
  125. astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
  126. astrbot/core/platform/sources/slack/slack_event.py +34 -24
  127. astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
  128. astrbot/core/platform/sources/telegram/tg_event.py +32 -18
  129. astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
  130. astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
  131. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
  132. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
  133. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
  134. astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
  135. astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
  136. astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
  137. astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
  138. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
  139. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
  140. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
  141. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
  142. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
  143. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
  144. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
  145. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
  146. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
  147. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
  148. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
  149. astrbot/core/platform_message_history_mgr.py +5 -3
  150. astrbot/core/provider/__init__.py +2 -3
  151. astrbot/core/provider/entites.py +8 -8
  152. astrbot/core/provider/entities.py +61 -75
  153. astrbot/core/provider/func_tool_manager.py +59 -55
  154. astrbot/core/provider/manager.py +40 -22
  155. astrbot/core/provider/provider.py +72 -46
  156. astrbot/core/provider/register.py +7 -7
  157. astrbot/core/provider/sources/anthropic_source.py +48 -30
  158. astrbot/core/provider/sources/azure_tts_source.py +17 -13
  159. astrbot/core/provider/sources/coze_api_client.py +27 -17
  160. astrbot/core/provider/sources/coze_source.py +104 -87
  161. astrbot/core/provider/sources/dashscope_source.py +18 -11
  162. astrbot/core/provider/sources/dashscope_tts.py +36 -23
  163. astrbot/core/provider/sources/dify_source.py +25 -20
  164. astrbot/core/provider/sources/edge_tts_source.py +21 -17
  165. astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
  166. astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
  167. astrbot/core/provider/sources/gemini_source.py +72 -58
  168. astrbot/core/provider/sources/gemini_tts_source.py +8 -6
  169. astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
  170. astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
  171. astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
  172. astrbot/core/provider/sources/openai_embedding_source.py +6 -8
  173. astrbot/core/provider/sources/openai_source.py +102 -69
  174. astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
  175. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  176. astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
  177. astrbot/core/provider/sources/volcengine_tts.py +38 -31
  178. astrbot/core/provider/sources/whisper_api_source.py +14 -12
  179. astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
  180. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  181. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  182. astrbot/core/star/__init__.py +16 -11
  183. astrbot/core/star/config.py +10 -15
  184. astrbot/core/star/context.py +109 -84
  185. astrbot/core/star/filter/__init__.py +4 -3
  186. astrbot/core/star/filter/command.py +30 -28
  187. astrbot/core/star/filter/command_group.py +27 -24
  188. astrbot/core/star/filter/custom_filter.py +6 -5
  189. astrbot/core/star/filter/event_message_type.py +4 -2
  190. astrbot/core/star/filter/permission.py +4 -2
  191. astrbot/core/star/filter/platform_adapter_type.py +4 -2
  192. astrbot/core/star/filter/regex.py +4 -2
  193. astrbot/core/star/register/__init__.py +19 -19
  194. astrbot/core/star/register/star.py +6 -2
  195. astrbot/core/star/register/star_handler.py +96 -73
  196. astrbot/core/star/session_llm_manager.py +48 -14
  197. astrbot/core/star/session_plugin_manager.py +29 -15
  198. astrbot/core/star/star.py +1 -2
  199. astrbot/core/star/star_handler.py +13 -8
  200. astrbot/core/star/star_manager.py +151 -59
  201. astrbot/core/star/star_tools.py +44 -37
  202. astrbot/core/star/updator.py +10 -10
  203. astrbot/core/umop_config_router.py +10 -4
  204. astrbot/core/updator.py +13 -5
  205. astrbot/core/utils/astrbot_path.py +3 -5
  206. astrbot/core/utils/dify_api_client.py +33 -15
  207. astrbot/core/utils/io.py +66 -42
  208. astrbot/core/utils/log_pipe.py +1 -1
  209. astrbot/core/utils/metrics.py +7 -7
  210. astrbot/core/utils/path_util.py +15 -16
  211. astrbot/core/utils/pip_installer.py +5 -5
  212. astrbot/core/utils/session_waiter.py +19 -20
  213. astrbot/core/utils/shared_preferences.py +45 -20
  214. astrbot/core/utils/t2i/__init__.py +4 -1
  215. astrbot/core/utils/t2i/network_strategy.py +35 -26
  216. astrbot/core/utils/t2i/renderer.py +11 -5
  217. astrbot/core/utils/t2i/template_manager.py +14 -15
  218. astrbot/core/utils/tencent_record_helper.py +19 -13
  219. astrbot/core/utils/version_comparator.py +10 -13
  220. astrbot/core/zip_updator.py +43 -40
  221. astrbot/dashboard/routes/__init__.py +18 -18
  222. astrbot/dashboard/routes/auth.py +10 -8
  223. astrbot/dashboard/routes/chat.py +30 -21
  224. astrbot/dashboard/routes/config.py +92 -75
  225. astrbot/dashboard/routes/conversation.py +46 -39
  226. astrbot/dashboard/routes/file.py +4 -2
  227. astrbot/dashboard/routes/knowledge_base.py +47 -40
  228. astrbot/dashboard/routes/log.py +9 -4
  229. astrbot/dashboard/routes/persona.py +19 -16
  230. astrbot/dashboard/routes/plugin.py +69 -55
  231. astrbot/dashboard/routes/route.py +3 -1
  232. astrbot/dashboard/routes/session_management.py +130 -116
  233. astrbot/dashboard/routes/stat.py +34 -34
  234. astrbot/dashboard/routes/t2i.py +15 -12
  235. astrbot/dashboard/routes/tools.py +47 -52
  236. astrbot/dashboard/routes/update.py +32 -28
  237. astrbot/dashboard/server.py +30 -26
  238. astrbot/dashboard/utils.py +8 -4
  239. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/METADATA +4 -2
  240. astrbot-4.5.2.dist-info/RECORD +261 -0
  241. astrbot-4.5.0.dist-info/RECORD +0 -258
  242. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
  243. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
  244. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,22 @@
1
1
  import json
2
2
  import os
3
3
  import uuid
4
+ from collections.abc import AsyncIterator
5
+
4
6
  import aiohttp
5
- from typing import Dict, List, Union, AsyncIterator
6
- from astrbot.core.utils.astrbot_path import get_astrbot_data_path
7
+
7
8
  from astrbot.api import logger
9
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
10
+
8
11
  from ..entities import ProviderType
9
12
  from ..provider import TTSProvider
10
13
  from ..register import register_provider_adapter
11
14
 
12
15
 
13
16
  @register_provider_adapter(
14
- "minimax_tts_api", "MiniMax TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
17
+ "minimax_tts_api",
18
+ "MiniMax TTS API",
19
+ provider_type=ProviderType.TEXT_TO_SPEECH,
15
20
  )
16
21
  class ProviderMiniMaxTTSAPI(TTSProvider):
17
22
  def __init__(
@@ -22,19 +27,21 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
22
27
  super().__init__(provider_config, provider_settings)
23
28
  self.chosen_api_key: str = provider_config.get("api_key", "")
24
29
  self.api_base: str = provider_config.get(
25
- "api_base", "https://api.minimax.chat/v1/t2a_v2"
30
+ "api_base",
31
+ "https://api.minimax.chat/v1/t2a_v2",
26
32
  )
27
33
  self.group_id: str = provider_config.get("minimax-group-id", "")
28
34
  self.set_model(provider_config.get("model", ""))
29
35
  self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
30
36
  self.is_timber_weight: bool = provider_config.get(
31
- "minimax-is-timber-weight", False
37
+ "minimax-is-timber-weight",
38
+ False,
32
39
  )
33
- self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads(
40
+ self.timber_weight: list[dict[str, str | int]] = json.loads(
34
41
  provider_config.get(
35
42
  "minimax-timber-weight",
36
43
  '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
37
- )
44
+ ),
38
45
  )
39
46
 
40
47
  self.voice_setting: dict = {
@@ -47,7 +54,8 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
47
54
  "emotion": provider_config.get("minimax-voice-emotion", "neutral"),
48
55
  "latex_read": provider_config.get("minimax-voice-latex", False),
49
56
  "english_normalization": provider_config.get(
50
- "minimax-voice-english-normalization", False
57
+ "minimax-voice-english-normalization",
58
+ False,
51
59
  ),
52
60
  }
53
61
 
@@ -66,7 +74,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
66
74
 
67
75
  def _build_tts_stream_body(self, text: str):
68
76
  """构建流式请求体"""
69
- dict_body: Dict[str, object] = {
77
+ dict_body: dict[str, object] = {
70
78
  "model": self.model_name,
71
79
  "text": text,
72
80
  "stream": True,
@@ -82,44 +90,46 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
82
90
  async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
83
91
  """进行流式请求"""
84
92
  try:
85
- async with aiohttp.ClientSession() as session:
86
- async with session.post(
93
+ async with (
94
+ aiohttp.ClientSession() as session,
95
+ session.post(
87
96
  self.concat_base_url,
88
97
  headers=self.headers,
89
98
  data=self._build_tts_stream_body(text),
90
99
  timeout=aiohttp.ClientTimeout(total=60),
91
- ) as response:
92
- response.raise_for_status()
93
-
94
- buffer = b""
95
- while True:
96
- chunk = await response.content.read(8192)
97
- if not chunk:
98
- break
99
-
100
- buffer += chunk
101
-
102
- while b"\n\n" in buffer:
103
- try:
104
- message, buffer = buffer.split(b"\n\n", 1)
105
- if message.startswith(b"data: "):
106
- try:
107
- data = json.loads(message[6:])
108
- if "extra_info" in data:
109
- continue
110
- audio = data.get("data", {}).get("audio")
111
- if audio is not None:
112
- yield audio
113
- except json.JSONDecodeError:
114
- logger.warning(
115
- "Failed to parse JSON data from SSE message"
116
- )
100
+ ) as response,
101
+ ):
102
+ response.raise_for_status()
103
+
104
+ buffer = b""
105
+ while True:
106
+ chunk = await response.content.read(8192)
107
+ if not chunk:
108
+ break
109
+
110
+ buffer += chunk
111
+
112
+ while b"\n\n" in buffer:
113
+ try:
114
+ message, buffer = buffer.split(b"\n\n", 1)
115
+ if message.startswith(b"data: "):
116
+ try:
117
+ data = json.loads(message[6:])
118
+ if "extra_info" in data:
117
119
  continue
118
- except ValueError:
119
- buffer = buffer[-1024:]
120
+ audio = data.get("data", {}).get("audio")
121
+ if audio is not None:
122
+ yield audio
123
+ except json.JSONDecodeError:
124
+ logger.warning(
125
+ "Failed to parse JSON data from SSE message",
126
+ )
127
+ continue
128
+ except ValueError:
129
+ buffer = buffer[-1024:]
120
130
 
121
131
  except aiohttp.ClientError as e:
122
- raise Exception(f"MiniMax TTS API请求失败: {str(e)}")
132
+ raise Exception(f"MiniMax TTS API请求失败: {e!s}")
123
133
 
124
134
  async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
125
135
  """解码数据流到 audio 比特流"""
@@ -1,7 +1,8 @@
1
1
  from openai import AsyncOpenAI
2
+
3
+ from ..entities import ProviderType
2
4
  from ..provider import EmbeddingProvider
3
5
  from ..register import register_provider_adapter
4
- from ..entities import ProviderType
5
6
 
6
7
 
7
8
  @register_provider_adapter(
@@ -17,23 +18,20 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
17
18
  self.client = AsyncOpenAI(
18
19
  api_key=provider_config.get("embedding_api_key"),
19
20
  base_url=provider_config.get(
20
- "embedding_api_base", "https://api.openai.com/v1"
21
+ "embedding_api_base",
22
+ "https://api.openai.com/v1",
21
23
  ),
22
24
  timeout=int(provider_config.get("timeout", 20)),
23
25
  )
24
26
  self.model = provider_config.get("embedding_model", "text-embedding-3-small")
25
27
 
26
28
  async def get_embedding(self, text: str) -> list[float]:
27
- """
28
- 获取文本的嵌入
29
- """
29
+ """获取文本的嵌入"""
30
30
  embedding = await self.client.embeddings.create(input=text, model=self.model)
31
31
  return embedding.data[0].embedding
32
32
 
33
33
  async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
34
- """
35
- 批量获取文本的嵌入
36
- """
34
+ """批量获取文本的嵌入"""
37
35
  embeddings = await self.client.embeddings.create(input=texts, model=self.model)
38
36
  return [item.embedding for item in embeddings.data]
39
37
 
@@ -1,29 +1,31 @@
1
+ import asyncio
1
2
  import base64
3
+ import inspect
2
4
  import json
3
5
  import os
4
- import inspect
5
6
  import random
6
- import asyncio
7
- import astrbot.core.message.components as Comp
8
-
9
- from openai import AsyncOpenAI, AsyncAzureOpenAI
10
- from openai.types.chat.chat_completion import ChatCompletion
7
+ from collections.abc import AsyncGenerator
11
8
 
9
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
12
10
  from openai._exceptions import NotFoundError, UnprocessableEntityError
13
11
  from openai.lib.streaming.chat._completions import ChatCompletionStreamState
14
- from astrbot.core.utils.io import download_image_by_url
15
- from astrbot.core.message.message_event_result import MessageChain
12
+ from openai.types.chat.chat_completion import ChatCompletion
16
13
 
17
- from astrbot.api.provider import Provider
14
+ import astrbot.core.message.components as Comp
18
15
  from astrbot import logger
19
- from astrbot.core.provider.func_tool_manager import ToolSet
20
- from typing import List, AsyncGenerator
21
- from ..register import register_provider_adapter
16
+ from astrbot.api.provider import Provider
17
+ from astrbot.core.agent.message import Message
18
+ from astrbot.core.agent.tool import ToolSet
19
+ from astrbot.core.message.message_event_result import MessageChain
22
20
  from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
21
+ from astrbot.core.utils.io import download_image_by_url
22
+
23
+ from ..register import register_provider_adapter
23
24
 
24
25
 
25
26
  @register_provider_adapter(
26
- "openai_chat_completion", "OpenAI API Chat Completion 提供商适配器"
27
+ "openai_chat_completion",
28
+ "OpenAI API Chat Completion 提供商适配器",
27
29
  )
28
30
  class ProviderOpenAIOfficial(Provider):
29
31
  def __init__(
@@ -38,7 +40,7 @@ class ProviderOpenAIOfficial(Provider):
38
40
  default_persona,
39
41
  )
40
42
  self.chosen_api_key = None
41
- self.api_keys: List = super().get_keys()
43
+ self.api_keys: list = super().get_keys()
42
44
  self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
43
45
  self.timeout = provider_config.get("timeout", 120)
44
46
  if isinstance(self.timeout, str):
@@ -61,13 +63,35 @@ class ProviderOpenAIOfficial(Provider):
61
63
  )
62
64
 
63
65
  self.default_params = inspect.signature(
64
- self.client.chat.completions.create
66
+ self.client.chat.completions.create,
65
67
  ).parameters.keys()
66
68
 
67
69
  model_config = provider_config.get("model_config", {})
68
70
  model = model_config.get("model", "unknown")
69
71
  self.set_model(model)
70
72
 
73
+ def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
74
+ """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
75
+
76
+ - 仅在 provider_config.xai_native_search 为 True 时生效
77
+ - 默认注入 {"mode": "auto"}
78
+ - 允许通过 kwargs 使用 xai_search_mode 覆盖(on/auto/off)
79
+ """
80
+ if not bool(self.provider_config.get("xai_native_search", False)):
81
+ return
82
+
83
+ mode = kwargs.get("xai_search_mode", "auto")
84
+ mode = str(mode).lower()
85
+ if mode not in ("auto", "on", "off"):
86
+ mode = "auto"
87
+
88
+ # off 时不注入,保持与未开启一致
89
+ if mode == "off":
90
+ return
91
+
92
+ # OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
93
+ payloads["search_parameters"] = {"mode": mode}
94
+
71
95
  async def get_models(self):
72
96
  try:
73
97
  models_str = []
@@ -79,12 +103,12 @@ class ProviderOpenAIOfficial(Provider):
79
103
  except NotFoundError as e:
80
104
  raise Exception(f"获取模型列表失败:{e}")
81
105
 
82
- async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
106
+ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
83
107
  if tools:
84
108
  model = payloads.get("model", "").lower()
85
109
  omit_empty_param_field = "gemini" in model
86
110
  tool_list = tools.get_func_desc_openai_style(
87
- omit_empty_parameter_field=omit_empty_param_field
111
+ omit_empty_parameter_field=omit_empty_param_field,
88
112
  )
89
113
  if tool_list:
90
114
  payloads["tools"] = tool_list
@@ -92,7 +116,7 @@ class ProviderOpenAIOfficial(Provider):
92
116
  # 不在默认参数中的参数放在 extra_body 中
93
117
  extra_body = {}
94
118
  to_del = []
95
- for key in payloads.keys():
119
+ for key in payloads:
96
120
  if key not in self.default_params:
97
121
  extra_body[key] = payloads[key]
98
122
  to_del.append(key)
@@ -111,12 +135,14 @@ class ProviderOpenAIOfficial(Provider):
111
135
  del payloads["tools"]
112
136
 
113
137
  completion = await self.client.chat.completions.create(
114
- **payloads, stream=False, extra_body=extra_body
138
+ **payloads,
139
+ stream=False,
140
+ extra_body=extra_body,
115
141
  )
116
142
 
117
143
  if not isinstance(completion, ChatCompletion):
118
144
  raise Exception(
119
- f"API 返回的 completion 类型错误:{type(completion)}: {completion}。"
145
+ f"API 返回的 completion 类型错误:{type(completion)}: {completion}。",
120
146
  )
121
147
 
122
148
  logger.debug(f"completion: {completion}")
@@ -126,14 +152,16 @@ class ProviderOpenAIOfficial(Provider):
126
152
  return llm_response
127
153
 
128
154
  async def _query_stream(
129
- self, payloads: dict, tools: ToolSet
155
+ self,
156
+ payloads: dict,
157
+ tools: ToolSet | None,
130
158
  ) -> AsyncGenerator[LLMResponse, None]:
131
159
  """流式查询API,逐步返回结果"""
132
160
  if tools:
133
161
  model = payloads.get("model", "").lower()
134
162
  omit_empty_param_field = "gemini" in model
135
163
  tool_list = tools.get_func_desc_openai_style(
136
- omit_empty_parameter_field=omit_empty_param_field
164
+ omit_empty_parameter_field=omit_empty_param_field,
137
165
  )
138
166
  if tool_list:
139
167
  payloads["tools"] = tool_list
@@ -147,7 +175,7 @@ class ProviderOpenAIOfficial(Provider):
147
175
  extra_body.update(custom_extra_body)
148
176
 
149
177
  to_del = []
150
- for key in payloads.keys():
178
+ for key in payloads:
151
179
  if key not in self.default_params:
152
180
  extra_body[key] = payloads[key]
153
181
  to_del.append(key)
@@ -155,7 +183,9 @@ class ProviderOpenAIOfficial(Provider):
155
183
  del payloads[key]
156
184
 
157
185
  stream = await self.client.chat.completions.create(
158
- **payloads, stream=True, extra_body=extra_body
186
+ **payloads,
187
+ stream=True,
188
+ extra_body=extra_body,
159
189
  )
160
190
 
161
191
  llm_response = LLMResponse("assistant", is_chunk=True)
@@ -174,7 +204,7 @@ class ProviderOpenAIOfficial(Provider):
174
204
  if delta.content:
175
205
  completion_text = delta.content
176
206
  llm_response.result_chain = MessageChain(
177
- chain=[Comp.Plain(completion_text)]
207
+ chain=[Comp.Plain(completion_text)],
178
208
  )
179
209
  yield llm_response
180
210
 
@@ -183,7 +213,9 @@ class ProviderOpenAIOfficial(Provider):
183
213
 
184
214
  yield llm_response
185
215
 
186
- async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
216
+ async def parse_openai_completion(
217
+ self, completion: ChatCompletion, tools: ToolSet | None
218
+ ) -> LLMResponse:
187
219
  """解析 OpenAI 的 ChatCompletion 响应"""
188
220
  llm_response = LLMResponse("assistant")
189
221
 
@@ -196,7 +228,7 @@ class ProviderOpenAIOfficial(Provider):
196
228
  completion_text = str(choice.message.content).strip()
197
229
  llm_response.result_chain = MessageChain().message(completion_text)
198
230
 
199
- if choice.message.tool_calls:
231
+ if choice.message.tool_calls and tools is not None:
200
232
  # tools call (function calling)
201
233
  args_ls = []
202
234
  func_name_ls = []
@@ -225,7 +257,7 @@ class ProviderOpenAIOfficial(Provider):
225
257
 
226
258
  if choice.finish_reason == "content_filter":
227
259
  raise Exception(
228
- "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
260
+ "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
229
261
  )
230
262
 
231
263
  if llm_response.completion_text is None and not llm_response.tools_call_args:
@@ -238,9 +270,9 @@ class ProviderOpenAIOfficial(Provider):
238
270
 
239
271
  async def _prepare_chat_payload(
240
272
  self,
241
- prompt: str,
273
+ prompt: str | None,
242
274
  image_urls: list[str] | None = None,
243
- contexts: list | None = None,
275
+ contexts: list[dict] | list[Message] | None = None,
244
276
  system_prompt: str | None = None,
245
277
  tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
246
278
  model: str | None = None,
@@ -249,8 +281,12 @@ class ProviderOpenAIOfficial(Provider):
249
281
  """准备聊天所需的有效载荷和上下文"""
250
282
  if contexts is None:
251
283
  contexts = []
252
- new_record = await self.assemble_context(prompt, image_urls)
253
- context_query = [*contexts, new_record]
284
+ new_record = None
285
+ if prompt is not None:
286
+ new_record = await self.assemble_context(prompt, image_urls)
287
+ context_query = self._ensure_message_to_dicts(contexts)
288
+ if new_record:
289
+ context_query.append(new_record)
254
290
  if system_prompt:
255
291
  context_query.insert(0, {"role": "system", "content": system_prompt})
256
292
 
@@ -271,6 +307,9 @@ class ProviderOpenAIOfficial(Provider):
271
307
 
272
308
  payloads = {"messages": context_query, **model_config}
273
309
 
310
+ # xAI 原生搜索参数(最小侵入地在此处注入)
311
+ self._maybe_inject_xai_search(payloads, **kwargs)
312
+
274
313
  return payloads, context_query
275
314
 
276
315
  async def _handle_api_error(
@@ -278,16 +317,16 @@ class ProviderOpenAIOfficial(Provider):
278
317
  e: Exception,
279
318
  payloads: dict,
280
319
  context_query: list,
281
- func_tool: ToolSet,
320
+ func_tool: ToolSet | None,
282
321
  chosen_key: str,
283
- available_api_keys: List[str],
322
+ available_api_keys: list[str],
284
323
  retry_cnt: int,
285
324
  max_retries: int,
286
325
  ) -> tuple:
287
326
  """处理API错误并尝试恢复"""
288
327
  if "429" in str(e):
289
328
  logger.warning(
290
- f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
329
+ f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}",
291
330
  )
292
331
  # 最后一次不等待
293
332
  if retry_cnt < max_retries - 1:
@@ -303,11 +342,10 @@ class ProviderOpenAIOfficial(Provider):
303
342
  context_query,
304
343
  func_tool,
305
344
  )
306
- else:
307
- raise e
308
- elif "maximum context length" in str(e):
345
+ raise e
346
+ if "maximum context length" in str(e):
309
347
  logger.warning(
310
- f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
348
+ f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}",
311
349
  )
312
350
  await self.pop_record(context_query)
313
351
  payloads["messages"] = context_query
@@ -319,7 +357,7 @@ class ProviderOpenAIOfficial(Provider):
319
357
  context_query,
320
358
  func_tool,
321
359
  )
322
- elif "The model is not a VLM" in str(e): # siliconcloud
360
+ if "The model is not a VLM" in str(e): # siliconcloud
323
361
  # 尝试删除所有 image
324
362
  new_contexts = await self._remove_image_from_context(context_query)
325
363
  payloads["messages"] = new_contexts
@@ -332,36 +370,34 @@ class ProviderOpenAIOfficial(Provider):
332
370
  context_query,
333
371
  func_tool,
334
372
  )
335
- elif (
373
+ if (
336
374
  "Function calling is not enabled" in str(e)
337
375
  or ("tool" in str(e).lower() and "support" in str(e).lower())
338
376
  or ("function" in str(e).lower() and "support" in str(e).lower())
339
377
  ):
340
378
  # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
341
379
  logger.info(
342
- f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
380
+ f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。",
343
381
  )
344
- if "tools" in payloads:
345
- del payloads["tools"]
382
+ payloads.pop("tools", None)
346
383
  return False, chosen_key, available_api_keys, payloads, context_query, None
347
- else:
348
- logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
384
+ logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
349
385
 
350
- if "tool" in str(e).lower() and "support" in str(e).lower():
351
- logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
386
+ if "tool" in str(e).lower() and "support" in str(e).lower():
387
+ logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
352
388
 
353
- if "Connection error." in str(e):
354
- proxy = os.environ.get("http_proxy", None)
355
- if proxy:
356
- logger.error(
357
- f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
358
- )
389
+ if "Connection error." in str(e):
390
+ proxy = os.environ.get("http_proxy", None)
391
+ if proxy:
392
+ logger.error(
393
+ f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}",
394
+ )
359
395
 
360
- raise e
396
+ raise e
361
397
 
362
398
  async def text_chat(
363
399
  self,
364
- prompt,
400
+ prompt=None,
365
401
  session_id=None,
366
402
  image_urls=None,
367
403
  func_tool=None,
@@ -430,7 +466,7 @@ class ProviderOpenAIOfficial(Provider):
430
466
 
431
467
  async def text_chat_stream(
432
468
  self,
433
- prompt: str,
469
+ prompt=None,
434
470
  session_id=None,
435
471
  image_urls=None,
436
472
  func_tool=None,
@@ -497,10 +533,8 @@ class ProviderOpenAIOfficial(Provider):
497
533
  raise Exception("未知错误")
498
534
  raise last_exception
499
535
 
500
- async def _remove_image_from_context(self, contexts: List):
501
- """
502
- 从上下文中删除所有带有 image 的记录
503
- """
536
+ async def _remove_image_from_context(self, contexts: list):
537
+ """从上下文中删除所有带有 image 的记录"""
504
538
  new_contexts = []
505
539
 
506
540
  for context in contexts:
@@ -521,14 +555,16 @@ class ProviderOpenAIOfficial(Provider):
521
555
  def get_current_key(self) -> str:
522
556
  return self.client.api_key
523
557
 
524
- def get_keys(self) -> List[str]:
558
+ def get_keys(self) -> list[str]:
525
559
  return self.api_keys
526
560
 
527
561
  def set_key(self, key):
528
562
  self.client.api_key = key
529
563
 
530
564
  async def assemble_context(
531
- self, text: str, image_urls: List[str] | None = None
565
+ self,
566
+ text: str,
567
+ image_urls: list[str] | None = None,
532
568
  ) -> dict:
533
569
  """组装成符合 OpenAI 格式的 role 为 user 的消息段"""
534
570
  if image_urls:
@@ -552,16 +588,13 @@ class ProviderOpenAIOfficial(Provider):
552
588
  {
553
589
  "type": "image_url",
554
590
  "image_url": {"url": image_data},
555
- }
591
+ },
556
592
  )
557
593
  return user_content
558
- else:
559
- return {"role": "user", "content": text}
594
+ return {"role": "user", "content": text}
560
595
 
561
596
  async def encode_image_bs64(self, image_url: str) -> str:
562
- """
563
- 将图片转换为 base64
564
- """
597
+ """将图片转换为 base64"""
565
598
  if image_url.startswith("base64://"):
566
599
  return image_url.replace("base64://", "data:image/jpeg;base64,")
567
600
  with open(image_url, "rb") as f:
@@ -1,14 +1,19 @@
1
1
  import os
2
2
  import uuid
3
- from openai import AsyncOpenAI, NOT_GIVEN
4
- from ..provider import TTSProvider
3
+
4
+ from openai import NOT_GIVEN, AsyncOpenAI
5
+
6
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
7
+
5
8
  from ..entities import ProviderType
9
+ from ..provider import TTSProvider
6
10
  from ..register import register_provider_adapter
7
- from astrbot.core.utils.astrbot_path import get_astrbot_data_path
8
11
 
9
12
 
10
13
  @register_provider_adapter(
11
- "openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
14
+ "openai_tts_api",
15
+ "OpenAI TTS API",
16
+ provider_type=ProviderType.TEXT_TO_SPEECH,
12
17
  )
13
18
  class ProviderOpenAITTSAPI(TTSProvider):
14
19
  def __init__(
@@ -26,7 +31,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
26
31
 
27
32
  self.client = AsyncOpenAI(
28
33
  api_key=self.chosen_api_key,
29
- base_url=provider_config.get("api_base", None),
34
+ base_url=provider_config.get("api_base"),
30
35
  timeout=timeout,
31
36
  )
32
37
 
@@ -36,7 +41,10 @@ class ProviderOpenAITTSAPI(TTSProvider):
36
41
  temp_dir = os.path.join(get_astrbot_data_path(), "temp")
37
42
  path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
38
43
  async with self.client.audio.speech.with_streaming_response.create(
39
- model=self.model_name, voice=self.voice, response_format="wav", input=text
44
+ model=self.model_name,
45
+ voice=self.voice,
46
+ response_format="wav",
47
+ input=text,
40
48
  ) as response:
41
49
  with open(path, "wb") as f:
42
50
  async for chunk in response.iter_bytes(chunk_size=1024):