AstrBot 4.5.1__py3-none-any.whl → 4.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. astrbot/api/__init__.py +10 -11
  2. astrbot/api/event/__init__.py +5 -6
  3. astrbot/api/event/filter/__init__.py +37 -36
  4. astrbot/api/platform/__init__.py +7 -8
  5. astrbot/api/provider/__init__.py +7 -7
  6. astrbot/api/star/__init__.py +3 -4
  7. astrbot/api/util/__init__.py +2 -2
  8. astrbot/cli/__main__.py +5 -5
  9. astrbot/cli/commands/__init__.py +3 -3
  10. astrbot/cli/commands/cmd_conf.py +19 -16
  11. astrbot/cli/commands/cmd_init.py +3 -2
  12. astrbot/cli/commands/cmd_plug.py +8 -10
  13. astrbot/cli/commands/cmd_run.py +5 -6
  14. astrbot/cli/utils/__init__.py +6 -6
  15. astrbot/cli/utils/basic.py +14 -14
  16. astrbot/cli/utils/plugin.py +24 -15
  17. astrbot/cli/utils/version_comparator.py +10 -12
  18. astrbot/core/__init__.py +8 -6
  19. astrbot/core/agent/agent.py +3 -2
  20. astrbot/core/agent/handoff.py +6 -2
  21. astrbot/core/agent/hooks.py +9 -6
  22. astrbot/core/agent/mcp_client.py +50 -15
  23. astrbot/core/agent/message.py +168 -0
  24. astrbot/core/agent/response.py +2 -1
  25. astrbot/core/agent/run_context.py +2 -3
  26. astrbot/core/agent/runners/base.py +10 -13
  27. astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
  28. astrbot/core/agent/tool.py +60 -41
  29. astrbot/core/agent/tool_executor.py +9 -3
  30. astrbot/core/astr_agent_context.py +3 -1
  31. astrbot/core/astrbot_config_mgr.py +29 -9
  32. astrbot/core/config/__init__.py +2 -2
  33. astrbot/core/config/astrbot_config.py +28 -26
  34. astrbot/core/config/default.py +4 -6
  35. astrbot/core/conversation_mgr.py +105 -36
  36. astrbot/core/core_lifecycle.py +68 -54
  37. astrbot/core/db/__init__.py +33 -18
  38. astrbot/core/db/migration/helper.py +12 -10
  39. astrbot/core/db/migration/migra_3_to_4.py +53 -34
  40. astrbot/core/db/migration/migra_45_to_46.py +1 -1
  41. astrbot/core/db/migration/shared_preferences_v3.py +2 -1
  42. astrbot/core/db/migration/sqlite_v3.py +26 -23
  43. astrbot/core/db/po.py +27 -18
  44. astrbot/core/db/sqlite.py +74 -45
  45. astrbot/core/db/vec_db/base.py +10 -14
  46. astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
  47. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
  48. astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
  49. astrbot/core/event_bus.py +8 -6
  50. astrbot/core/file_token_service.py +6 -5
  51. astrbot/core/initial_loader.py +7 -5
  52. astrbot/core/knowledge_base/chunking/__init__.py +1 -3
  53. astrbot/core/knowledge_base/chunking/base.py +1 -0
  54. astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
  55. astrbot/core/knowledge_base/chunking/recursive.py +16 -10
  56. astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
  57. astrbot/core/knowledge_base/kb_helper.py +30 -17
  58. astrbot/core/knowledge_base/kb_mgr.py +6 -7
  59. astrbot/core/knowledge_base/models.py +10 -4
  60. astrbot/core/knowledge_base/parsers/__init__.py +3 -5
  61. astrbot/core/knowledge_base/parsers/base.py +1 -0
  62. astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
  63. astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
  64. astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
  65. astrbot/core/knowledge_base/parsers/util.py +1 -1
  66. astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
  67. astrbot/core/knowledge_base/retrieval/manager.py +17 -14
  68. astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
  69. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
  70. astrbot/core/log.py +21 -13
  71. astrbot/core/message/components.py +123 -217
  72. astrbot/core/message/message_event_result.py +24 -24
  73. astrbot/core/persona_mgr.py +20 -11
  74. astrbot/core/pipeline/__init__.py +7 -7
  75. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  76. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  77. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
  78. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
  79. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  80. astrbot/core/pipeline/context.py +4 -1
  81. astrbot/core/pipeline/context_utils.py +77 -7
  82. astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
  83. astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
  84. astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
  85. astrbot/core/pipeline/process_stage/stage.py +13 -10
  86. astrbot/core/pipeline/process_stage/utils.py +6 -5
  87. astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
  88. astrbot/core/pipeline/respond/stage.py +23 -20
  89. astrbot/core/pipeline/result_decorate/stage.py +31 -23
  90. astrbot/core/pipeline/scheduler.py +12 -8
  91. astrbot/core/pipeline/session_status_check/stage.py +12 -8
  92. astrbot/core/pipeline/stage.py +10 -4
  93. astrbot/core/pipeline/waking_check/stage.py +24 -18
  94. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  95. astrbot/core/platform/__init__.py +6 -6
  96. astrbot/core/platform/astr_message_event.py +76 -110
  97. astrbot/core/platform/astrbot_message.py +11 -13
  98. astrbot/core/platform/manager.py +16 -15
  99. astrbot/core/platform/message_session.py +5 -3
  100. astrbot/core/platform/platform.py +16 -24
  101. astrbot/core/platform/platform_metadata.py +4 -4
  102. astrbot/core/platform/register.py +8 -8
  103. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
  104. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
  105. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
  106. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
  107. astrbot/core/platform/sources/discord/client.py +9 -6
  108. astrbot/core/platform/sources/discord/components.py +18 -14
  109. astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
  110. astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
  111. astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
  112. astrbot/core/platform/sources/lark/lark_event.py +21 -14
  113. astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
  114. astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
  115. astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
  116. astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
  117. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
  118. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  119. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  120. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  121. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
  122. astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
  123. astrbot/core/platform/sources/satori/satori_event.py +34 -25
  124. astrbot/core/platform/sources/slack/client.py +11 -9
  125. astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
  126. astrbot/core/platform/sources/slack/slack_event.py +34 -24
  127. astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
  128. astrbot/core/platform/sources/telegram/tg_event.py +32 -18
  129. astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
  130. astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
  131. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
  132. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
  133. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
  134. astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
  135. astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
  136. astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
  137. astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
  138. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
  139. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
  140. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
  141. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
  142. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
  143. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
  144. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
  145. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
  146. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
  147. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
  148. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
  149. astrbot/core/platform_message_history_mgr.py +5 -3
  150. astrbot/core/provider/__init__.py +2 -3
  151. astrbot/core/provider/entites.py +8 -8
  152. astrbot/core/provider/entities.py +61 -75
  153. astrbot/core/provider/func_tool_manager.py +59 -55
  154. astrbot/core/provider/manager.py +32 -22
  155. astrbot/core/provider/provider.py +72 -46
  156. astrbot/core/provider/register.py +7 -7
  157. astrbot/core/provider/sources/anthropic_source.py +48 -30
  158. astrbot/core/provider/sources/azure_tts_source.py +17 -13
  159. astrbot/core/provider/sources/coze_api_client.py +27 -17
  160. astrbot/core/provider/sources/coze_source.py +104 -87
  161. astrbot/core/provider/sources/dashscope_source.py +18 -11
  162. astrbot/core/provider/sources/dashscope_tts.py +36 -23
  163. astrbot/core/provider/sources/dify_source.py +25 -20
  164. astrbot/core/provider/sources/edge_tts_source.py +21 -17
  165. astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
  166. astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
  167. astrbot/core/provider/sources/gemini_source.py +72 -58
  168. astrbot/core/provider/sources/gemini_tts_source.py +8 -6
  169. astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
  170. astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
  171. astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
  172. astrbot/core/provider/sources/openai_embedding_source.py +6 -8
  173. astrbot/core/provider/sources/openai_source.py +77 -69
  174. astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
  175. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  176. astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
  177. astrbot/core/provider/sources/volcengine_tts.py +38 -31
  178. astrbot/core/provider/sources/whisper_api_source.py +14 -12
  179. astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
  180. astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
  181. astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
  182. astrbot/core/star/__init__.py +16 -11
  183. astrbot/core/star/config.py +10 -15
  184. astrbot/core/star/context.py +97 -75
  185. astrbot/core/star/filter/__init__.py +4 -3
  186. astrbot/core/star/filter/command.py +30 -28
  187. astrbot/core/star/filter/command_group.py +27 -24
  188. astrbot/core/star/filter/custom_filter.py +6 -5
  189. astrbot/core/star/filter/event_message_type.py +4 -2
  190. astrbot/core/star/filter/permission.py +4 -2
  191. astrbot/core/star/filter/platform_adapter_type.py +4 -2
  192. astrbot/core/star/filter/regex.py +4 -2
  193. astrbot/core/star/register/__init__.py +19 -19
  194. astrbot/core/star/register/star.py +6 -2
  195. astrbot/core/star/register/star_handler.py +96 -73
  196. astrbot/core/star/session_llm_manager.py +48 -14
  197. astrbot/core/star/session_plugin_manager.py +29 -15
  198. astrbot/core/star/star.py +1 -2
  199. astrbot/core/star/star_handler.py +13 -8
  200. astrbot/core/star/star_manager.py +151 -59
  201. astrbot/core/star/star_tools.py +44 -37
  202. astrbot/core/star/updator.py +10 -10
  203. astrbot/core/umop_config_router.py +10 -4
  204. astrbot/core/updator.py +13 -5
  205. astrbot/core/utils/astrbot_path.py +3 -5
  206. astrbot/core/utils/dify_api_client.py +33 -15
  207. astrbot/core/utils/io.py +66 -42
  208. astrbot/core/utils/log_pipe.py +1 -1
  209. astrbot/core/utils/metrics.py +7 -7
  210. astrbot/core/utils/path_util.py +15 -16
  211. astrbot/core/utils/pip_installer.py +5 -5
  212. astrbot/core/utils/session_waiter.py +19 -20
  213. astrbot/core/utils/shared_preferences.py +45 -20
  214. astrbot/core/utils/t2i/__init__.py +4 -1
  215. astrbot/core/utils/t2i/network_strategy.py +35 -26
  216. astrbot/core/utils/t2i/renderer.py +11 -5
  217. astrbot/core/utils/t2i/template_manager.py +14 -15
  218. astrbot/core/utils/tencent_record_helper.py +19 -13
  219. astrbot/core/utils/version_comparator.py +10 -13
  220. astrbot/core/zip_updator.py +43 -40
  221. astrbot/dashboard/routes/__init__.py +18 -18
  222. astrbot/dashboard/routes/auth.py +10 -8
  223. astrbot/dashboard/routes/chat.py +30 -21
  224. astrbot/dashboard/routes/config.py +92 -75
  225. astrbot/dashboard/routes/conversation.py +46 -39
  226. astrbot/dashboard/routes/file.py +4 -2
  227. astrbot/dashboard/routes/knowledge_base.py +47 -40
  228. astrbot/dashboard/routes/log.py +9 -4
  229. astrbot/dashboard/routes/persona.py +19 -16
  230. astrbot/dashboard/routes/plugin.py +69 -55
  231. astrbot/dashboard/routes/route.py +3 -1
  232. astrbot/dashboard/routes/session_management.py +130 -116
  233. astrbot/dashboard/routes/stat.py +34 -34
  234. astrbot/dashboard/routes/t2i.py +15 -12
  235. astrbot/dashboard/routes/tools.py +56 -53
  236. astrbot/dashboard/routes/update.py +32 -28
  237. astrbot/dashboard/server.py +30 -26
  238. astrbot/dashboard/utils.py +8 -4
  239. {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/METADATA +2 -1
  240. astrbot-4.5.3.dist-info/RECORD +261 -0
  241. astrbot-4.5.1.dist-info/RECORD +0 -260
  242. {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/WHEEL +0 -0
  243. {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/entry_points.txt +0 -0
  244. {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,22 @@
1
1
  import json
2
2
  import os
3
3
  import uuid
4
+ from collections.abc import AsyncIterator
5
+
4
6
  import aiohttp
5
- from typing import Dict, List, Union, AsyncIterator
6
- from astrbot.core.utils.astrbot_path import get_astrbot_data_path
7
+
7
8
  from astrbot.api import logger
9
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
10
+
8
11
  from ..entities import ProviderType
9
12
  from ..provider import TTSProvider
10
13
  from ..register import register_provider_adapter
11
14
 
12
15
 
13
16
  @register_provider_adapter(
14
- "minimax_tts_api", "MiniMax TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
17
+ "minimax_tts_api",
18
+ "MiniMax TTS API",
19
+ provider_type=ProviderType.TEXT_TO_SPEECH,
15
20
  )
16
21
  class ProviderMiniMaxTTSAPI(TTSProvider):
17
22
  def __init__(
@@ -22,19 +27,21 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
22
27
  super().__init__(provider_config, provider_settings)
23
28
  self.chosen_api_key: str = provider_config.get("api_key", "")
24
29
  self.api_base: str = provider_config.get(
25
- "api_base", "https://api.minimax.chat/v1/t2a_v2"
30
+ "api_base",
31
+ "https://api.minimax.chat/v1/t2a_v2",
26
32
  )
27
33
  self.group_id: str = provider_config.get("minimax-group-id", "")
28
34
  self.set_model(provider_config.get("model", ""))
29
35
  self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
30
36
  self.is_timber_weight: bool = provider_config.get(
31
- "minimax-is-timber-weight", False
37
+ "minimax-is-timber-weight",
38
+ False,
32
39
  )
33
- self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads(
40
+ self.timber_weight: list[dict[str, str | int]] = json.loads(
34
41
  provider_config.get(
35
42
  "minimax-timber-weight",
36
43
  '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
37
- )
44
+ ),
38
45
  )
39
46
 
40
47
  self.voice_setting: dict = {
@@ -47,7 +54,8 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
47
54
  "emotion": provider_config.get("minimax-voice-emotion", "neutral"),
48
55
  "latex_read": provider_config.get("minimax-voice-latex", False),
49
56
  "english_normalization": provider_config.get(
50
- "minimax-voice-english-normalization", False
57
+ "minimax-voice-english-normalization",
58
+ False,
51
59
  ),
52
60
  }
53
61
 
@@ -66,7 +74,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
66
74
 
67
75
  def _build_tts_stream_body(self, text: str):
68
76
  """构建流式请求体"""
69
- dict_body: Dict[str, object] = {
77
+ dict_body: dict[str, object] = {
70
78
  "model": self.model_name,
71
79
  "text": text,
72
80
  "stream": True,
@@ -82,44 +90,46 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
82
90
  async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
83
91
  """进行流式请求"""
84
92
  try:
85
- async with aiohttp.ClientSession() as session:
86
- async with session.post(
93
+ async with (
94
+ aiohttp.ClientSession() as session,
95
+ session.post(
87
96
  self.concat_base_url,
88
97
  headers=self.headers,
89
98
  data=self._build_tts_stream_body(text),
90
99
  timeout=aiohttp.ClientTimeout(total=60),
91
- ) as response:
92
- response.raise_for_status()
93
-
94
- buffer = b""
95
- while True:
96
- chunk = await response.content.read(8192)
97
- if not chunk:
98
- break
99
-
100
- buffer += chunk
101
-
102
- while b"\n\n" in buffer:
103
- try:
104
- message, buffer = buffer.split(b"\n\n", 1)
105
- if message.startswith(b"data: "):
106
- try:
107
- data = json.loads(message[6:])
108
- if "extra_info" in data:
109
- continue
110
- audio = data.get("data", {}).get("audio")
111
- if audio is not None:
112
- yield audio
113
- except json.JSONDecodeError:
114
- logger.warning(
115
- "Failed to parse JSON data from SSE message"
116
- )
100
+ ) as response,
101
+ ):
102
+ response.raise_for_status()
103
+
104
+ buffer = b""
105
+ while True:
106
+ chunk = await response.content.read(8192)
107
+ if not chunk:
108
+ break
109
+
110
+ buffer += chunk
111
+
112
+ while b"\n\n" in buffer:
113
+ try:
114
+ message, buffer = buffer.split(b"\n\n", 1)
115
+ if message.startswith(b"data: "):
116
+ try:
117
+ data = json.loads(message[6:])
118
+ if "extra_info" in data:
117
119
  continue
118
- except ValueError:
119
- buffer = buffer[-1024:]
120
+ audio = data.get("data", {}).get("audio")
121
+ if audio is not None:
122
+ yield audio
123
+ except json.JSONDecodeError:
124
+ logger.warning(
125
+ "Failed to parse JSON data from SSE message",
126
+ )
127
+ continue
128
+ except ValueError:
129
+ buffer = buffer[-1024:]
120
130
 
121
131
  except aiohttp.ClientError as e:
122
- raise Exception(f"MiniMax TTS API请求失败: {str(e)}")
132
+ raise Exception(f"MiniMax TTS API请求失败: {e!s}")
123
133
 
124
134
  async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
125
135
  """解码数据流到 audio 比特流"""
@@ -1,7 +1,8 @@
1
1
  from openai import AsyncOpenAI
2
+
3
+ from ..entities import ProviderType
2
4
  from ..provider import EmbeddingProvider
3
5
  from ..register import register_provider_adapter
4
- from ..entities import ProviderType
5
6
 
6
7
 
7
8
  @register_provider_adapter(
@@ -17,23 +18,20 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
17
18
  self.client = AsyncOpenAI(
18
19
  api_key=provider_config.get("embedding_api_key"),
19
20
  base_url=provider_config.get(
20
- "embedding_api_base", "https://api.openai.com/v1"
21
+ "embedding_api_base",
22
+ "https://api.openai.com/v1",
21
23
  ),
22
24
  timeout=int(provider_config.get("timeout", 20)),
23
25
  )
24
26
  self.model = provider_config.get("embedding_model", "text-embedding-3-small")
25
27
 
26
28
  async def get_embedding(self, text: str) -> list[float]:
27
- """
28
- 获取文本的嵌入
29
- """
29
+ """获取文本的嵌入"""
30
30
  embedding = await self.client.embeddings.create(input=text, model=self.model)
31
31
  return embedding.data[0].embedding
32
32
 
33
33
  async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
34
- """
35
- 批量获取文本的嵌入
36
- """
34
+ """批量获取文本的嵌入"""
37
35
  embeddings = await self.client.embeddings.create(input=texts, model=self.model)
38
36
  return [item.embedding for item in embeddings.data]
39
37
 
@@ -1,29 +1,31 @@
1
+ import asyncio
1
2
  import base64
3
+ import inspect
2
4
  import json
3
5
  import os
4
- import inspect
5
6
  import random
6
- import asyncio
7
- import astrbot.core.message.components as Comp
8
-
9
- from openai import AsyncOpenAI, AsyncAzureOpenAI
10
- from openai.types.chat.chat_completion import ChatCompletion
7
+ from collections.abc import AsyncGenerator
11
8
 
9
+ from openai import AsyncAzureOpenAI, AsyncOpenAI
12
10
  from openai._exceptions import NotFoundError, UnprocessableEntityError
13
11
  from openai.lib.streaming.chat._completions import ChatCompletionStreamState
14
- from astrbot.core.utils.io import download_image_by_url
15
- from astrbot.core.message.message_event_result import MessageChain
12
+ from openai.types.chat.chat_completion import ChatCompletion
16
13
 
17
- from astrbot.api.provider import Provider
14
+ import astrbot.core.message.components as Comp
18
15
  from astrbot import logger
19
- from astrbot.core.provider.func_tool_manager import ToolSet
20
- from typing import List, AsyncGenerator
21
- from ..register import register_provider_adapter
16
+ from astrbot.api.provider import Provider
17
+ from astrbot.core.agent.message import Message
18
+ from astrbot.core.agent.tool import ToolSet
19
+ from astrbot.core.message.message_event_result import MessageChain
22
20
  from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
21
+ from astrbot.core.utils.io import download_image_by_url
22
+
23
+ from ..register import register_provider_adapter
23
24
 
24
25
 
25
26
  @register_provider_adapter(
26
- "openai_chat_completion", "OpenAI API Chat Completion 提供商适配器"
27
+ "openai_chat_completion",
28
+ "OpenAI API Chat Completion 提供商适配器",
27
29
  )
28
30
  class ProviderOpenAIOfficial(Provider):
29
31
  def __init__(
@@ -38,7 +40,7 @@ class ProviderOpenAIOfficial(Provider):
38
40
  default_persona,
39
41
  )
40
42
  self.chosen_api_key = None
41
- self.api_keys: List = super().get_keys()
43
+ self.api_keys: list = super().get_keys()
42
44
  self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
43
45
  self.timeout = provider_config.get("timeout", 120)
44
46
  if isinstance(self.timeout, str):
@@ -61,7 +63,7 @@ class ProviderOpenAIOfficial(Provider):
61
63
  )
62
64
 
63
65
  self.default_params = inspect.signature(
64
- self.client.chat.completions.create
66
+ self.client.chat.completions.create,
65
67
  ).parameters.keys()
66
68
 
67
69
  model_config = provider_config.get("model_config", {})
@@ -101,12 +103,12 @@ class ProviderOpenAIOfficial(Provider):
101
103
  except NotFoundError as e:
102
104
  raise Exception(f"获取模型列表失败:{e}")
103
105
 
104
- async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
106
+ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
105
107
  if tools:
106
108
  model = payloads.get("model", "").lower()
107
109
  omit_empty_param_field = "gemini" in model
108
110
  tool_list = tools.get_func_desc_openai_style(
109
- omit_empty_parameter_field=omit_empty_param_field
111
+ omit_empty_parameter_field=omit_empty_param_field,
110
112
  )
111
113
  if tool_list:
112
114
  payloads["tools"] = tool_list
@@ -114,7 +116,7 @@ class ProviderOpenAIOfficial(Provider):
114
116
  # 不在默认参数中的参数放在 extra_body 中
115
117
  extra_body = {}
116
118
  to_del = []
117
- for key in payloads.keys():
119
+ for key in payloads:
118
120
  if key not in self.default_params:
119
121
  extra_body[key] = payloads[key]
120
122
  to_del.append(key)
@@ -133,12 +135,14 @@ class ProviderOpenAIOfficial(Provider):
133
135
  del payloads["tools"]
134
136
 
135
137
  completion = await self.client.chat.completions.create(
136
- **payloads, stream=False, extra_body=extra_body
138
+ **payloads,
139
+ stream=False,
140
+ extra_body=extra_body,
137
141
  )
138
142
 
139
143
  if not isinstance(completion, ChatCompletion):
140
144
  raise Exception(
141
- f"API 返回的 completion 类型错误:{type(completion)}: {completion}。"
145
+ f"API 返回的 completion 类型错误:{type(completion)}: {completion}。",
142
146
  )
143
147
 
144
148
  logger.debug(f"completion: {completion}")
@@ -148,14 +152,16 @@ class ProviderOpenAIOfficial(Provider):
148
152
  return llm_response
149
153
 
150
154
  async def _query_stream(
151
- self, payloads: dict, tools: ToolSet
155
+ self,
156
+ payloads: dict,
157
+ tools: ToolSet | None,
152
158
  ) -> AsyncGenerator[LLMResponse, None]:
153
159
  """流式查询API,逐步返回结果"""
154
160
  if tools:
155
161
  model = payloads.get("model", "").lower()
156
162
  omit_empty_param_field = "gemini" in model
157
163
  tool_list = tools.get_func_desc_openai_style(
158
- omit_empty_parameter_field=omit_empty_param_field
164
+ omit_empty_parameter_field=omit_empty_param_field,
159
165
  )
160
166
  if tool_list:
161
167
  payloads["tools"] = tool_list
@@ -169,7 +175,7 @@ class ProviderOpenAIOfficial(Provider):
169
175
  extra_body.update(custom_extra_body)
170
176
 
171
177
  to_del = []
172
- for key in payloads.keys():
178
+ for key in payloads:
173
179
  if key not in self.default_params:
174
180
  extra_body[key] = payloads[key]
175
181
  to_del.append(key)
@@ -177,7 +183,9 @@ class ProviderOpenAIOfficial(Provider):
177
183
  del payloads[key]
178
184
 
179
185
  stream = await self.client.chat.completions.create(
180
- **payloads, stream=True, extra_body=extra_body
186
+ **payloads,
187
+ stream=True,
188
+ extra_body=extra_body,
181
189
  )
182
190
 
183
191
  llm_response = LLMResponse("assistant", is_chunk=True)
@@ -196,7 +204,7 @@ class ProviderOpenAIOfficial(Provider):
196
204
  if delta.content:
197
205
  completion_text = delta.content
198
206
  llm_response.result_chain = MessageChain(
199
- chain=[Comp.Plain(completion_text)]
207
+ chain=[Comp.Plain(completion_text)],
200
208
  )
201
209
  yield llm_response
202
210
 
@@ -205,7 +213,9 @@ class ProviderOpenAIOfficial(Provider):
205
213
 
206
214
  yield llm_response
207
215
 
208
- async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
216
+ async def parse_openai_completion(
217
+ self, completion: ChatCompletion, tools: ToolSet | None
218
+ ) -> LLMResponse:
209
219
  """解析 OpenAI 的 ChatCompletion 响应"""
210
220
  llm_response = LLMResponse("assistant")
211
221
 
@@ -218,7 +228,7 @@ class ProviderOpenAIOfficial(Provider):
218
228
  completion_text = str(choice.message.content).strip()
219
229
  llm_response.result_chain = MessageChain().message(completion_text)
220
230
 
221
- if choice.message.tool_calls:
231
+ if choice.message.tool_calls and tools is not None:
222
232
  # tools call (function calling)
223
233
  args_ls = []
224
234
  func_name_ls = []
@@ -247,7 +257,7 @@ class ProviderOpenAIOfficial(Provider):
247
257
 
248
258
  if choice.finish_reason == "content_filter":
249
259
  raise Exception(
250
- "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。"
260
+ "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
251
261
  )
252
262
 
253
263
  if llm_response.completion_text is None and not llm_response.tools_call_args:
@@ -260,9 +270,9 @@ class ProviderOpenAIOfficial(Provider):
260
270
 
261
271
  async def _prepare_chat_payload(
262
272
  self,
263
- prompt: str,
273
+ prompt: str | None,
264
274
  image_urls: list[str] | None = None,
265
- contexts: list | None = None,
275
+ contexts: list[dict] | list[Message] | None = None,
266
276
  system_prompt: str | None = None,
267
277
  tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
268
278
  model: str | None = None,
@@ -271,8 +281,12 @@ class ProviderOpenAIOfficial(Provider):
271
281
  """准备聊天所需的有效载荷和上下文"""
272
282
  if contexts is None:
273
283
  contexts = []
274
- new_record = await self.assemble_context(prompt, image_urls)
275
- context_query = [*contexts, new_record]
284
+ new_record = None
285
+ if prompt is not None:
286
+ new_record = await self.assemble_context(prompt, image_urls)
287
+ context_query = self._ensure_message_to_dicts(contexts)
288
+ if new_record:
289
+ context_query.append(new_record)
276
290
  if system_prompt:
277
291
  context_query.insert(0, {"role": "system", "content": system_prompt})
278
292
 
@@ -303,16 +317,16 @@ class ProviderOpenAIOfficial(Provider):
303
317
  e: Exception,
304
318
  payloads: dict,
305
319
  context_query: list,
306
- func_tool: ToolSet,
320
+ func_tool: ToolSet | None,
307
321
  chosen_key: str,
308
- available_api_keys: List[str],
322
+ available_api_keys: list[str],
309
323
  retry_cnt: int,
310
324
  max_retries: int,
311
325
  ) -> tuple:
312
326
  """处理API错误并尝试恢复"""
313
327
  if "429" in str(e):
314
328
  logger.warning(
315
- f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
329
+ f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}",
316
330
  )
317
331
  # 最后一次不等待
318
332
  if retry_cnt < max_retries - 1:
@@ -328,11 +342,10 @@ class ProviderOpenAIOfficial(Provider):
328
342
  context_query,
329
343
  func_tool,
330
344
  )
331
- else:
332
- raise e
333
- elif "maximum context length" in str(e):
345
+ raise e
346
+ if "maximum context length" in str(e):
334
347
  logger.warning(
335
- f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
348
+ f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}",
336
349
  )
337
350
  await self.pop_record(context_query)
338
351
  payloads["messages"] = context_query
@@ -344,7 +357,7 @@ class ProviderOpenAIOfficial(Provider):
344
357
  context_query,
345
358
  func_tool,
346
359
  )
347
- elif "The model is not a VLM" in str(e): # siliconcloud
360
+ if "The model is not a VLM" in str(e): # siliconcloud
348
361
  # 尝试删除所有 image
349
362
  new_contexts = await self._remove_image_from_context(context_query)
350
363
  payloads["messages"] = new_contexts
@@ -357,36 +370,34 @@ class ProviderOpenAIOfficial(Provider):
357
370
  context_query,
358
371
  func_tool,
359
372
  )
360
- elif (
373
+ if (
361
374
  "Function calling is not enabled" in str(e)
362
375
  or ("tool" in str(e).lower() and "support" in str(e).lower())
363
376
  or ("function" in str(e).lower() and "support" in str(e).lower())
364
377
  ):
365
378
  # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
366
379
  logger.info(
367
- f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
380
+ f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。",
368
381
  )
369
- if "tools" in payloads:
370
- del payloads["tools"]
382
+ payloads.pop("tools", None)
371
383
  return False, chosen_key, available_api_keys, payloads, context_query, None
372
- else:
373
- logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
384
+ logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
374
385
 
375
- if "tool" in str(e).lower() and "support" in str(e).lower():
376
- logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
386
+ if "tool" in str(e).lower() and "support" in str(e).lower():
387
+ logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
377
388
 
378
- if "Connection error." in str(e):
379
- proxy = os.environ.get("http_proxy", None)
380
- if proxy:
381
- logger.error(
382
- f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
383
- )
389
+ if "Connection error." in str(e):
390
+ proxy = os.environ.get("http_proxy", None)
391
+ if proxy:
392
+ logger.error(
393
+ f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}",
394
+ )
384
395
 
385
- raise e
396
+ raise e
386
397
 
387
398
  async def text_chat(
388
399
  self,
389
- prompt,
400
+ prompt=None,
390
401
  session_id=None,
391
402
  image_urls=None,
392
403
  func_tool=None,
@@ -455,7 +466,7 @@ class ProviderOpenAIOfficial(Provider):
455
466
 
456
467
  async def text_chat_stream(
457
468
  self,
458
- prompt: str,
469
+ prompt=None,
459
470
  session_id=None,
460
471
  image_urls=None,
461
472
  func_tool=None,
@@ -522,10 +533,8 @@ class ProviderOpenAIOfficial(Provider):
522
533
  raise Exception("未知错误")
523
534
  raise last_exception
524
535
 
525
- async def _remove_image_from_context(self, contexts: List):
526
- """
527
- 从上下文中删除所有带有 image 的记录
528
- """
536
+ async def _remove_image_from_context(self, contexts: list):
537
+ """从上下文中删除所有带有 image 的记录"""
529
538
  new_contexts = []
530
539
 
531
540
  for context in contexts:
@@ -546,14 +555,16 @@ class ProviderOpenAIOfficial(Provider):
546
555
  def get_current_key(self) -> str:
547
556
  return self.client.api_key
548
557
 
549
- def get_keys(self) -> List[str]:
558
+ def get_keys(self) -> list[str]:
550
559
  return self.api_keys
551
560
 
552
561
  def set_key(self, key):
553
562
  self.client.api_key = key
554
563
 
555
564
  async def assemble_context(
556
- self, text: str, image_urls: List[str] | None = None
565
+ self,
566
+ text: str,
567
+ image_urls: list[str] | None = None,
557
568
  ) -> dict:
558
569
  """组装成符合 OpenAI 格式的 role 为 user 的消息段"""
559
570
  if image_urls:
@@ -577,16 +588,13 @@ class ProviderOpenAIOfficial(Provider):
577
588
  {
578
589
  "type": "image_url",
579
590
  "image_url": {"url": image_data},
580
- }
591
+ },
581
592
  )
582
593
  return user_content
583
- else:
584
- return {"role": "user", "content": text}
594
+ return {"role": "user", "content": text}
585
595
 
586
596
  async def encode_image_bs64(self, image_url: str) -> str:
587
- """
588
- 将图片转换为 base64
589
- """
597
+ """将图片转换为 base64"""
590
598
  if image_url.startswith("base64://"):
591
599
  return image_url.replace("base64://", "data:image/jpeg;base64,")
592
600
  with open(image_url, "rb") as f:
@@ -1,14 +1,19 @@
1
1
  import os
2
2
  import uuid
3
- from openai import AsyncOpenAI, NOT_GIVEN
4
- from ..provider import TTSProvider
3
+
4
+ from openai import NOT_GIVEN, AsyncOpenAI
5
+
6
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
7
+
5
8
  from ..entities import ProviderType
9
+ from ..provider import TTSProvider
6
10
  from ..register import register_provider_adapter
7
- from astrbot.core.utils.astrbot_path import get_astrbot_data_path
8
11
 
9
12
 
10
13
  @register_provider_adapter(
11
- "openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
14
+ "openai_tts_api",
15
+ "OpenAI TTS API",
16
+ provider_type=ProviderType.TEXT_TO_SPEECH,
12
17
  )
13
18
  class ProviderOpenAITTSAPI(TTSProvider):
14
19
  def __init__(
@@ -26,7 +31,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
26
31
 
27
32
  self.client = AsyncOpenAI(
28
33
  api_key=self.chosen_api_key,
29
- base_url=provider_config.get("api_base", None),
34
+ base_url=provider_config.get("api_base"),
30
35
  timeout=timeout,
31
36
  )
32
37
 
@@ -36,7 +41,10 @@ class ProviderOpenAITTSAPI(TTSProvider):
36
41
  temp_dir = os.path.join(get_astrbot_data_path(), "temp")
37
42
  path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
38
43
  async with self.client.audio.speech.with_streaming_response.create(
39
- model=self.model_name, voice=self.voice, response_format="wav", input=text
44
+ model=self.model_name,
45
+ voice=self.voice,
46
+ response_format="wav",
47
+ input=text,
40
48
  ) as response:
41
49
  with open(path, "wb") as f:
42
50
  async for chunk in response.iter_bytes(chunk_size=1024):
@@ -1,22 +1,24 @@
1
- """
2
- Author: diudiu62
1
+ """Author: diudiu62
3
2
  Date: 2025-02-24 18:04:18
4
3
  LastEditTime: 2025-02-25 14:06:30
5
4
  """
6
5
 
7
6
  import asyncio
8
- from datetime import datetime
9
7
  import os
10
8
  import re
9
+ from datetime import datetime
10
+
11
11
  from funasr_onnx import SenseVoiceSmall
12
12
  from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
13
- from ..provider import STTProvider
14
- from ..entities import ProviderType
15
- from astrbot.core.utils.io import download_file
16
- from ..register import register_provider_adapter
13
+
17
14
  from astrbot.core import logger
15
+ from astrbot.core.utils.io import download_file
18
16
  from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
19
17
 
18
+ from ..entities import ProviderType
19
+ from ..provider import STTProvider
20
+ from ..register import register_provider_adapter
21
+
20
22
 
21
23
  @register_provider_adapter(
22
24
  "sensevoice_stt_selfhost",
@@ -30,7 +32,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
30
32
  provider_settings: dict,
31
33
  ) -> None:
32
34
  super().__init__(provider_config, provider_settings)
33
- self.set_model(provider_config.get("stt_model", None))
35
+ self.set_model(provider_config.get("stt_model"))
34
36
  self.model = None
35
37
  self.is_emotion = provider_config.get("is_emotion", False)
36
38
 
@@ -39,7 +41,8 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
39
41
 
40
42
  # 将模型加载放到线程池中执行
41
43
  self.model = await asyncio.get_event_loop().run_in_executor(
42
- None, lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16)
44
+ None,
45
+ lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
43
46
  )
44
47
 
45
48
  logger.info("SenseVoice 模型加载完成。")
@@ -55,8 +58,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
55
58
 
56
59
  if silk_header in file_header:
57
60
  return True
58
- else:
59
- return False
61
+ return False
60
62
 
61
63
  async def get_text(self, audio_url: str) -> str:
62
64
  try: