AstrBot 4.5.1__py3-none-any.whl → 4.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. astrbot/api/__init__.py +10 -11
  2. astrbot/api/event/__init__.py +5 -6
  3. astrbot/api/event/filter/__init__.py +37 -36
  4. astrbot/api/platform/__init__.py +7 -8
  5. astrbot/api/provider/__init__.py +7 -7
  6. astrbot/api/star/__init__.py +3 -4
  7. astrbot/api/util/__init__.py +2 -2
  8. astrbot/cli/__main__.py +5 -5
  9. astrbot/cli/commands/__init__.py +3 -3
  10. astrbot/cli/commands/cmd_conf.py +19 -16
  11. astrbot/cli/commands/cmd_init.py +3 -2
  12. astrbot/cli/commands/cmd_plug.py +8 -10
  13. astrbot/cli/commands/cmd_run.py +5 -6
  14. astrbot/cli/utils/__init__.py +6 -6
  15. astrbot/cli/utils/basic.py +14 -14
  16. astrbot/cli/utils/plugin.py +24 -15
  17. astrbot/cli/utils/version_comparator.py +10 -12
  18. astrbot/core/__init__.py +8 -6
  19. astrbot/core/agent/agent.py +3 -2
  20. astrbot/core/agent/handoff.py +6 -2
  21. astrbot/core/agent/hooks.py +9 -6
  22. astrbot/core/agent/mcp_client.py +50 -15
  23. astrbot/core/agent/message.py +168 -0
  24. astrbot/core/agent/response.py +2 -1
  25. astrbot/core/agent/run_context.py +2 -3
  26. astrbot/core/agent/runners/base.py +10 -13
  27. astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
  28. astrbot/core/agent/tool.py +60 -41
  29. astrbot/core/agent/tool_executor.py +9 -3
  30. astrbot/core/astr_agent_context.py +3 -1
  31. astrbot/core/astrbot_config_mgr.py +29 -9
  32. astrbot/core/config/__init__.py +2 -2
  33. astrbot/core/config/astrbot_config.py +28 -26
  34. astrbot/core/config/default.py +4 -6
  35. astrbot/core/conversation_mgr.py +105 -36
  36. astrbot/core/core_lifecycle.py +68 -54
  37. astrbot/core/db/__init__.py +33 -18
  38. astrbot/core/db/migration/helper.py +12 -10
  39. astrbot/core/db/migration/migra_3_to_4.py +53 -34
  40. astrbot/core/db/migration/migra_45_to_46.py +1 -1
  41. astrbot/core/db/migration/shared_preferences_v3.py +2 -1
  42. astrbot/core/db/migration/sqlite_v3.py +26 -23
  43. astrbot/core/db/po.py +27 -18
  44. astrbot/core/db/sqlite.py +74 -45
  45. astrbot/core/db/vec_db/base.py +10 -14
  46. astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
  47. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
  48. astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
  49. astrbot/core/event_bus.py +8 -6
  50. astrbot/core/file_token_service.py +6 -5
  51. astrbot/core/initial_loader.py +7 -5
  52. astrbot/core/knowledge_base/chunking/__init__.py +1 -3
  53. astrbot/core/knowledge_base/chunking/base.py +1 -0
  54. astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
  55. astrbot/core/knowledge_base/chunking/recursive.py +16 -10
  56. astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
  57. astrbot/core/knowledge_base/kb_helper.py +30 -17
  58. astrbot/core/knowledge_base/kb_mgr.py +6 -7
  59. astrbot/core/knowledge_base/models.py +10 -4
  60. astrbot/core/knowledge_base/parsers/__init__.py +3 -5
  61. astrbot/core/knowledge_base/parsers/base.py +1 -0
  62. astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
  63. astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
  64. astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
  65. astrbot/core/knowledge_base/parsers/util.py +1 -1
  66. astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
  67. astrbot/core/knowledge_base/retrieval/manager.py +17 -14
  68. astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
  69. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
  70. astrbot/core/log.py +21 -13
  71. astrbot/core/message/components.py +123 -217
  72. astrbot/core/message/message_event_result.py +24 -24
  73. astrbot/core/persona_mgr.py +20 -11
  74. astrbot/core/pipeline/__init__.py +7 -7
  75. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  76. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  77. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
  78. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
  79. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  80. astrbot/core/pipeline/context.py +4 -1
  81. astrbot/core/pipeline/context_utils.py +77 -7
  82. astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
  83. astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
  84. astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
  85. astrbot/core/pipeline/process_stage/stage.py +13 -10
  86. astrbot/core/pipeline/process_stage/utils.py +6 -5
  87. astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
  88. astrbot/core/pipeline/respond/stage.py +23 -20
  89. astrbot/core/pipeline/result_decorate/stage.py +31 -23
  90. astrbot/core/pipeline/scheduler.py +12 -8
  91. astrbot/core/pipeline/session_status_check/stage.py +12 -8
  92. astrbot/core/pipeline/stage.py +10 -4
  93. astrbot/core/pipeline/waking_check/stage.py +24 -18
  94. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  95. astrbot/core/platform/__init__.py +6 -6
  96. astrbot/core/platform/astr_message_event.py +76 -110
  97. astrbot/core/platform/astrbot_message.py +11 -13
  98. astrbot/core/platform/manager.py +16 -15
  99. astrbot/core/platform/message_session.py +5 -3
  100. astrbot/core/platform/platform.py +16 -24
  101. astrbot/core/platform/platform_metadata.py +4 -4
  102. astrbot/core/platform/register.py +8 -8
  103. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
  104. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
  105. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
  106. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
  107. astrbot/core/platform/sources/discord/client.py +9 -6
  108. astrbot/core/platform/sources/discord/components.py +18 -14
  109. astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
  110. astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
  111. astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
  112. astrbot/core/platform/sources/lark/lark_event.py +21 -14
  113. astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
  114. astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
  115. astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
  116. astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
  117. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
  118. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  119. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  120. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  121. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
  122. astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
  123. astrbot/core/platform/sources/satori/satori_event.py +34 -25
  124. astrbot/core/platform/sources/slack/client.py +11 -9
  125. astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
  126. astrbot/core/platform/sources/slack/slack_event.py +34 -24
  127. astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
  128. astrbot/core/platform/sources/telegram/tg_event.py +32 -18
  129. astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
  130. astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
  131. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
  132. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
  133. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
  134. astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
  135. astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
  136. astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
  137. astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
  138. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
  139. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
  140. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
  141. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
  142. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
  143. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
  144. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
  145. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
  146. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
  147. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
  148. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
  149. astrbot/core/platform_message_history_mgr.py +5 -3
  150. astrbot/core/provider/__init__.py +2 -3
  151. astrbot/core/provider/entites.py +8 -8
  152. astrbot/core/provider/entities.py +61 -75
  153. astrbot/core/provider/func_tool_manager.py +59 -55
  154. astrbot/core/provider/manager.py +32 -22
  155. astrbot/core/provider/provider.py +72 -46
  156. astrbot/core/provider/register.py +7 -7
  157. astrbot/core/provider/sources/anthropic_source.py +48 -30
  158. astrbot/core/provider/sources/azure_tts_source.py +17 -13
  159. astrbot/core/provider/sources/coze_api_client.py +27 -17
  160. astrbot/core/provider/sources/coze_source.py +104 -87
  161. astrbot/core/provider/sources/dashscope_source.py +18 -11
  162. astrbot/core/provider/sources/dashscope_tts.py +36 -23
  163. astrbot/core/provider/sources/dify_source.py +25 -20
  164. astrbot/core/provider/sources/edge_tts_source.py +21 -17
  165. astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
  166. astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
  167. astrbot/core/provider/sources/gemini_source.py +72 -58
  168. astrbot/core/provider/sources/gemini_tts_source.py +8 -6
  169. astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
  170. astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
  171. astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
  172. astrbot/core/provider/sources/openai_embedding_source.py +6 -8
  173. astrbot/core/provider/sources/openai_source.py +77 -69
  174. astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
  175. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  176. astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
  177. astrbot/core/provider/sources/volcengine_tts.py +38 -31
  178. astrbot/core/provider/sources/whisper_api_source.py +14 -12
  179. astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
  180. astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
  181. astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
  182. astrbot/core/star/__init__.py +16 -11
  183. astrbot/core/star/config.py +10 -15
  184. astrbot/core/star/context.py +97 -75
  185. astrbot/core/star/filter/__init__.py +4 -3
  186. astrbot/core/star/filter/command.py +30 -28
  187. astrbot/core/star/filter/command_group.py +27 -24
  188. astrbot/core/star/filter/custom_filter.py +6 -5
  189. astrbot/core/star/filter/event_message_type.py +4 -2
  190. astrbot/core/star/filter/permission.py +4 -2
  191. astrbot/core/star/filter/platform_adapter_type.py +4 -2
  192. astrbot/core/star/filter/regex.py +4 -2
  193. astrbot/core/star/register/__init__.py +19 -19
  194. astrbot/core/star/register/star.py +6 -2
  195. astrbot/core/star/register/star_handler.py +96 -73
  196. astrbot/core/star/session_llm_manager.py +48 -14
  197. astrbot/core/star/session_plugin_manager.py +29 -15
  198. astrbot/core/star/star.py +1 -2
  199. astrbot/core/star/star_handler.py +13 -8
  200. astrbot/core/star/star_manager.py +151 -59
  201. astrbot/core/star/star_tools.py +44 -37
  202. astrbot/core/star/updator.py +10 -10
  203. astrbot/core/umop_config_router.py +10 -4
  204. astrbot/core/updator.py +13 -5
  205. astrbot/core/utils/astrbot_path.py +3 -5
  206. astrbot/core/utils/dify_api_client.py +33 -15
  207. astrbot/core/utils/io.py +66 -42
  208. astrbot/core/utils/log_pipe.py +1 -1
  209. astrbot/core/utils/metrics.py +7 -7
  210. astrbot/core/utils/path_util.py +15 -16
  211. astrbot/core/utils/pip_installer.py +5 -5
  212. astrbot/core/utils/session_waiter.py +19 -20
  213. astrbot/core/utils/shared_preferences.py +45 -20
  214. astrbot/core/utils/t2i/__init__.py +4 -1
  215. astrbot/core/utils/t2i/network_strategy.py +35 -26
  216. astrbot/core/utils/t2i/renderer.py +11 -5
  217. astrbot/core/utils/t2i/template_manager.py +14 -15
  218. astrbot/core/utils/tencent_record_helper.py +19 -13
  219. astrbot/core/utils/version_comparator.py +10 -13
  220. astrbot/core/zip_updator.py +43 -40
  221. astrbot/dashboard/routes/__init__.py +18 -18
  222. astrbot/dashboard/routes/auth.py +10 -8
  223. astrbot/dashboard/routes/chat.py +30 -21
  224. astrbot/dashboard/routes/config.py +92 -75
  225. astrbot/dashboard/routes/conversation.py +46 -39
  226. astrbot/dashboard/routes/file.py +4 -2
  227. astrbot/dashboard/routes/knowledge_base.py +47 -40
  228. astrbot/dashboard/routes/log.py +9 -4
  229. astrbot/dashboard/routes/persona.py +19 -16
  230. astrbot/dashboard/routes/plugin.py +69 -55
  231. astrbot/dashboard/routes/route.py +3 -1
  232. astrbot/dashboard/routes/session_management.py +130 -116
  233. astrbot/dashboard/routes/stat.py +34 -34
  234. astrbot/dashboard/routes/t2i.py +15 -12
  235. astrbot/dashboard/routes/tools.py +47 -52
  236. astrbot/dashboard/routes/update.py +32 -28
  237. astrbot/dashboard/server.py +30 -26
  238. astrbot/dashboard/utils.py +8 -4
  239. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/METADATA +2 -1
  240. astrbot-4.5.2.dist-info/RECORD +261 -0
  241. astrbot-4.5.1.dist-info/RECORD +0 -260
  242. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
  243. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
  244. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,17 +1,18 @@
1
1
  import abc
2
2
  import asyncio
3
- from typing import List
4
- from typing import AsyncGenerator
3
+ from collections.abc import AsyncGenerator
4
+ from dataclasses import dataclass
5
+
6
+ from astrbot.core.agent.message import Message
5
7
  from astrbot.core.agent.tool import ToolSet
8
+ from astrbot.core.db.po import Personality
6
9
  from astrbot.core.provider.entities import (
7
10
  LLMResponse,
8
- ToolCallsResult,
9
11
  ProviderType,
10
12
  RerankResult,
13
+ ToolCallsResult,
11
14
  )
12
15
  from astrbot.core.provider.register import provider_cls_map
13
- from astrbot.core.db.po import Personality
14
- from dataclasses import dataclass
15
16
 
16
17
 
17
18
  @dataclass
@@ -23,24 +24,28 @@ class ProviderMeta:
23
24
 
24
25
 
25
26
  class AbstractProvider(abc.ABC):
27
+ """Provider Abstract Class"""
28
+
26
29
  def __init__(self, provider_config: dict) -> None:
27
30
  super().__init__()
28
31
  self.model_name = ""
29
32
  self.provider_config = provider_config
30
33
 
31
34
  def set_model(self, model_name: str):
32
- """设置当前使用的模型名称"""
35
+ """Set the current model name"""
33
36
  self.model_name = model_name
34
37
 
35
38
  def get_model(self) -> str:
36
- """获得当前使用的模型名称"""
39
+ """Get the current model name"""
37
40
  return self.model_name
38
41
 
39
42
  def meta(self) -> ProviderMeta:
40
- """获取 Provider 的元数据"""
43
+ """Get the provider metadata"""
41
44
  provider_type_name = self.provider_config["type"]
42
45
  meta_data = provider_cls_map.get(provider_type_name)
43
46
  provider_type = meta_data.provider_type if meta_data else None
47
+ if provider_type is None:
48
+ raise ValueError(f"Cannot find provider type: {provider_type_name}")
44
49
  return ProviderMeta(
45
50
  id=self.provider_config["id"],
46
51
  model=self.get_model(),
@@ -50,6 +55,8 @@ class AbstractProvider(abc.ABC):
50
55
 
51
56
 
52
57
  class Provider(AbstractProvider):
58
+ """Chat Provider"""
59
+
53
60
  def __init__(
54
61
  self,
55
62
  provider_config: dict,
@@ -65,99 +72,114 @@ class Provider(AbstractProvider):
65
72
 
66
73
  @abc.abstractmethod
67
74
  def get_current_key(self) -> str:
68
- raise NotImplementedError()
75
+ raise NotImplementedError
69
76
 
70
- def get_keys(self) -> List[str]:
77
+ def get_keys(self) -> list[str]:
71
78
  """获得提供商 Key"""
72
79
  keys = self.provider_config.get("key", [""])
73
80
  return keys or [""]
74
81
 
75
82
  @abc.abstractmethod
76
83
  def set_key(self, key: str):
77
- raise NotImplementedError()
84
+ raise NotImplementedError
78
85
 
79
86
  @abc.abstractmethod
80
- async def get_models(self) -> List[str]:
87
+ async def get_models(self) -> list[str]:
81
88
  """获得支持的模型列表"""
82
- raise NotImplementedError()
89
+ raise NotImplementedError
83
90
 
84
91
  @abc.abstractmethod
85
92
  async def text_chat(
86
93
  self,
87
- prompt: str,
88
- session_id: str = None,
89
- image_urls: list[str] = None,
90
- func_tool: ToolSet = None,
91
- contexts: list = None,
92
- system_prompt: str = None,
93
- tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
94
+ prompt: str | None = None,
95
+ session_id: str | None = None,
96
+ image_urls: list[str] | None = None,
97
+ func_tool: ToolSet | None = None,
98
+ contexts: list[Message] | list[dict] | None = None,
99
+ system_prompt: str | None = None,
100
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
94
101
  model: str | None = None,
95
102
  **kwargs,
96
103
  ) -> LLMResponse:
97
104
  """获得 LLM 的文本对话结果。会使用当前的模型进行对话。
98
105
 
99
106
  Args:
100
- prompt: 提示词
107
+ prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
101
108
  session_id: 会话 ID(此属性已经被废弃)
102
109
  image_urls: 图片 URL 列表
103
- tools: Function-calling 工具
104
- contexts: 上下文
110
+ tools: tool set
111
+ contexts: 上下文,和 prompt 二选一使用
105
112
  tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
106
113
  kwargs: 其他参数
107
114
 
108
115
  Notes:
109
116
  - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
110
117
  - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
118
+
111
119
  """
112
120
  ...
113
121
 
114
122
  async def text_chat_stream(
115
123
  self,
116
- prompt: str,
117
- session_id: str = None,
118
- image_urls: list[str] = None,
119
- func_tool: ToolSet = None,
120
- contexts: list = None,
121
- system_prompt: str = None,
122
- tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
124
+ prompt: str | None = None,
125
+ session_id: str | None = None,
126
+ image_urls: list[str] | None = None,
127
+ func_tool: ToolSet | None = None,
128
+ contexts: list[Message] | list[dict] | None = None,
129
+ system_prompt: str | None = None,
130
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
123
131
  model: str | None = None,
124
132
  **kwargs,
125
133
  ) -> AsyncGenerator[LLMResponse, None]:
126
134
  """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
127
135
 
128
136
  Args:
129
- prompt: 提示词
137
+ prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
130
138
  session_id: 会话 ID(此属性已经被废弃)
131
139
  image_urls: 图片 URL 列表
132
- tools: Function-calling 工具
133
- contexts: 上下文
140
+ tools: tool set
141
+ contexts: 上下文,和 prompt 二选一使用
134
142
  tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
135
143
  kwargs: 其他参数
136
144
 
137
145
  Notes:
138
146
  - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
139
147
  - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
148
+
140
149
  """
141
150
  ...
142
151
 
143
- async def pop_record(self, context: List):
144
- """
145
- 弹出 context 第一条非系统提示词对话记录
146
- """
152
+ async def pop_record(self, context: list):
153
+ """弹出 context 第一条非系统提示词对话记录"""
147
154
  poped = 0
148
155
  indexs_to_pop = []
149
156
  for idx, record in enumerate(context):
150
157
  if record["role"] == "system":
151
158
  continue
152
- else:
153
- indexs_to_pop.append(idx)
154
- poped += 1
155
- if poped == 2:
156
- break
159
+ indexs_to_pop.append(idx)
160
+ poped += 1
161
+ if poped == 2:
162
+ break
157
163
 
158
164
  for idx in reversed(indexs_to_pop):
159
165
  context.pop(idx)
160
166
 
167
+ def _ensure_message_to_dicts(
168
+ self,
169
+ messages: list[dict] | list[Message] | None,
170
+ ) -> list[dict]:
171
+ """Convert a list of Message objects to a list of dictionaries."""
172
+ if not messages:
173
+ return []
174
+ dicts: list[dict] = []
175
+ for message in messages:
176
+ if isinstance(message, Message):
177
+ dicts.append(message.model_dump())
178
+ else:
179
+ dicts.append(message)
180
+
181
+ return dicts
182
+
161
183
 
162
184
  class STTProvider(AbstractProvider):
163
185
  def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -168,7 +190,7 @@ class STTProvider(AbstractProvider):
168
190
  @abc.abstractmethod
169
191
  async def get_text(self, audio_url: str) -> str:
170
192
  """获取音频的文本"""
171
- raise NotImplementedError()
193
+ raise NotImplementedError
172
194
 
173
195
 
174
196
  class TTSProvider(AbstractProvider):
@@ -180,7 +202,7 @@ class TTSProvider(AbstractProvider):
180
202
  @abc.abstractmethod
181
203
  async def get_audio(self, text: str) -> str:
182
204
  """获取文本的音频,返回音频文件路径"""
183
- raise NotImplementedError()
205
+ raise NotImplementedError
184
206
 
185
207
 
186
208
  class EmbeddingProvider(AbstractProvider):
@@ -223,6 +245,7 @@ class EmbeddingProvider(AbstractProvider):
223
245
 
224
246
  Returns:
225
247
  向量列表
248
+
226
249
  """
227
250
  semaphore = asyncio.Semaphore(tasks_limit)
228
251
  all_embeddings: list[list[float]] = []
@@ -246,7 +269,7 @@ class EmbeddingProvider(AbstractProvider):
246
269
  # 最后一次重试失败,记录失败的批次
247
270
  failed_batches.append((batch_idx, batch_texts))
248
271
  raise Exception(
249
- f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {str(e)}"
272
+ f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}",
250
273
  )
251
274
  # 等待一段时间后重试,使用指数退避
252
275
  await asyncio.sleep(2**attempt)
@@ -279,7 +302,10 @@ class RerankProvider(AbstractProvider):
279
302
 
280
303
  @abc.abstractmethod
281
304
  async def rerank(
282
- self, query: str, documents: list[str], top_n: int | None = None
305
+ self,
306
+ query: str,
307
+ documents: list[str],
308
+ top_n: int | None = None,
283
309
  ) -> list[RerankResult]:
284
310
  """获取查询和文档的重排序分数"""
285
311
  ...
@@ -1,11 +1,11 @@
1
- from typing import List, Dict
2
- from .entities import ProviderMetaData, ProviderType
3
1
  from astrbot.core import logger
2
+
3
+ from .entities import ProviderMetaData, ProviderType
4
4
  from .func_tool_manager import FuncCall
5
5
 
6
- provider_registry: List[ProviderMetaData] = []
6
+ provider_registry: list[ProviderMetaData] = []
7
7
  """维护了通过装饰器注册的 Provider"""
8
- provider_cls_map: Dict[str, ProviderMetaData] = {}
8
+ provider_cls_map: dict[str, ProviderMetaData] = {}
9
9
  """维护了 Provider 类型名称和 ProviderMetadata 的映射"""
10
10
 
11
11
  llm_tools = FuncCall()
@@ -15,15 +15,15 @@ def register_provider_adapter(
15
15
  provider_type_name: str,
16
16
  desc: str,
17
17
  provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
18
- default_config_tmpl: dict = None,
19
- provider_display_name: str = None,
18
+ default_config_tmpl: dict | None = None,
19
+ provider_display_name: str | None = None,
20
20
  ):
21
21
  """用于注册平台适配器的带参装饰器"""
22
22
 
23
23
  def decorator(cls):
24
24
  if provider_type_name in provider_cls_map:
25
25
  raise ValueError(
26
- f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。"
26
+ f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。",
27
27
  )
28
28
 
29
29
  # 添加必备选项
@@ -1,23 +1,24 @@
1
- import json
2
- import anthropic
3
1
  import base64
4
- from typing import List
2
+ import json
3
+ from collections.abc import AsyncGenerator
5
4
  from mimetypes import guess_type
6
5
 
6
+ import anthropic
7
7
  from anthropic import AsyncAnthropic
8
8
  from anthropic.types import Message
9
9
 
10
- from astrbot.core.utils.io import download_image_by_url
11
- from astrbot.api.provider import Provider
12
10
  from astrbot import logger
11
+ from astrbot.api.provider import Provider
12
+ from astrbot.core.provider.entities import LLMResponse
13
13
  from astrbot.core.provider.func_tool_manager import ToolSet
14
+ from astrbot.core.utils.io import download_image_by_url
15
+
14
16
  from ..register import register_provider_adapter
15
- from astrbot.core.provider.entities import LLMResponse
16
- from typing import AsyncGenerator
17
17
 
18
18
 
19
19
  @register_provider_adapter(
20
- "anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
20
+ "anthropic_chat_completion",
21
+ "Anthropic Claude API 提供商适配器",
21
22
  )
22
23
  class ProviderAnthropic(Provider):
23
24
  def __init__(
@@ -33,7 +34,7 @@ class ProviderAnthropic(Provider):
33
34
  )
34
35
 
35
36
  self.chosen_api_key: str = ""
36
- self.api_keys: List = super().get_keys()
37
+ self.api_keys: list = super().get_keys()
37
38
  self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
38
39
  self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
39
40
  self.timeout = provider_config.get("timeout", 120)
@@ -41,7 +42,9 @@ class ProviderAnthropic(Provider):
41
42
  self.timeout = int(self.timeout)
42
43
 
43
44
  self.client = AsyncAnthropic(
44
- api_key=self.chosen_api_key, timeout=self.timeout, base_url=self.base_url
45
+ api_key=self.chosen_api_key,
46
+ timeout=self.timeout,
47
+ base_url=self.base_url,
45
48
  )
46
49
 
47
50
  self.set_model(provider_config["model_config"]["model"])
@@ -54,6 +57,7 @@ class ProviderAnthropic(Provider):
54
57
  Returns:
55
58
  system_prompt: 系统提示内容
56
59
  new_messages: 处理后的消息列表,去除系统提示
60
+
57
61
  """
58
62
  system_prompt = ""
59
63
  new_messages = []
@@ -73,18 +77,19 @@ class ProviderAnthropic(Provider):
73
77
  "input": (
74
78
  json.loads(tool_call["function"]["arguments"])
75
79
  if isinstance(
76
- tool_call["function"]["arguments"], str
80
+ tool_call["function"]["arguments"],
81
+ str,
77
82
  )
78
83
  else tool_call["function"]["arguments"]
79
84
  ),
80
85
  "id": tool_call["id"],
81
- }
86
+ },
82
87
  )
83
88
  new_messages.append(
84
89
  {
85
90
  "role": "assistant",
86
91
  "content": blocks,
87
- }
92
+ },
88
93
  )
89
94
  elif message["role"] == "tool":
90
95
  new_messages.append(
@@ -95,9 +100,9 @@ class ProviderAnthropic(Provider):
95
100
  "type": "tool_result",
96
101
  "tool_use_id": message["tool_call_id"],
97
102
  "content": message["content"],
98
- }
103
+ },
99
104
  ],
100
- }
105
+ },
101
106
  )
102
107
  else:
103
108
  new_messages.append(message)
@@ -135,7 +140,9 @@ class ProviderAnthropic(Provider):
135
140
  return llm_response
136
141
 
137
142
  async def _query_stream(
138
- self, payloads: dict, tools: ToolSet | None
143
+ self,
144
+ payloads: dict,
145
+ tools: ToolSet | None,
139
146
  ) -> AsyncGenerator[LLMResponse, None]:
140
147
  if tools:
141
148
  if tool_list := tools.get_func_desc_anthropic_style():
@@ -154,7 +161,9 @@ class ProviderAnthropic(Provider):
154
161
  if event.content_block.type == "text":
155
162
  # 文本块开始
156
163
  yield LLMResponse(
157
- role="assistant", completion_text="", is_chunk=True
164
+ role="assistant",
165
+ completion_text="",
166
+ is_chunk=True,
158
167
  )
159
168
  elif event.content_block.type == "tool_use":
160
169
  # 工具使用块开始,初始化缓冲区
@@ -198,7 +207,7 @@ class ProviderAnthropic(Provider):
198
207
  "id": tool_info["id"],
199
208
  "name": tool_info["name"],
200
209
  "input": tool_info["input"],
201
- }
210
+ },
202
211
  )
203
212
 
204
213
  yield LLMResponse(
@@ -218,7 +227,9 @@ class ProviderAnthropic(Provider):
218
227
 
219
228
  # 返回最终的完整结果
220
229
  final_response = LLMResponse(
221
- role="assistant", completion_text=final_text, is_chunk=False
230
+ role="assistant",
231
+ completion_text=final_text,
232
+ is_chunk=False,
222
233
  )
223
234
 
224
235
  if final_tool_calls:
@@ -232,7 +243,7 @@ class ProviderAnthropic(Provider):
232
243
 
233
244
  async def text_chat(
234
245
  self,
235
- prompt,
246
+ prompt=None,
236
247
  session_id=None,
237
248
  image_urls=None,
238
249
  func_tool=None,
@@ -244,8 +255,13 @@ class ProviderAnthropic(Provider):
244
255
  ) -> LLMResponse:
245
256
  if contexts is None:
246
257
  contexts = []
247
- new_record = await self.assemble_context(prompt, image_urls)
248
- context_query = [*contexts, new_record]
258
+ new_record = None
259
+ if prompt is not None:
260
+ new_record = await self.assemble_context(prompt, image_urls)
261
+ context_query = self._ensure_message_to_dicts(contexts)
262
+ if new_record:
263
+ context_query.append(new_record)
264
+
249
265
  if system_prompt:
250
266
  context_query.insert(0, {"role": "system", "content": system_prompt})
251
267
 
@@ -295,8 +311,12 @@ class ProviderAnthropic(Provider):
295
311
  ):
296
312
  if contexts is None:
297
313
  contexts = []
298
- new_record = await self.assemble_context(prompt, image_urls)
299
- context_query = [*contexts, new_record]
314
+ new_record = None
315
+ if prompt is not None:
316
+ new_record = await self.assemble_context(prompt, image_urls)
317
+ context_query = self._ensure_message_to_dicts(contexts)
318
+ if new_record:
319
+ context_query.append(new_record)
300
320
  if system_prompt:
301
321
  context_query.insert(0, {"role": "system", "content": system_prompt})
302
322
 
@@ -326,7 +346,7 @@ class ProviderAnthropic(Provider):
326
346
  async for llm_response in self._query_stream(payloads, func_tool):
327
347
  yield llm_response
328
348
 
329
- async def assemble_context(self, text: str, image_urls: List[str] | None = None):
349
+ async def assemble_context(self, text: str, image_urls: list[str] | None = None):
330
350
  """组装上下文,支持文本和图片"""
331
351
  if not image_urls:
332
352
  return {"role": "user", "content": text}
@@ -365,15 +385,13 @@ class ProviderAnthropic(Provider):
365
385
  else image_data
366
386
  ),
367
387
  },
368
- }
388
+ },
369
389
  )
370
390
 
371
391
  return {"role": "user", "content": content}
372
392
 
373
393
  async def encode_image_bs64(self, image_url: str) -> str:
374
- """
375
- 将图片转换为 base64
376
- """
394
+ """将图片转换为 base64"""
377
395
  if image_url.startswith("base64://"):
378
396
  return image_url.replace("base64://", "data:image/jpeg;base64,")
379
397
  with open(image_url, "rb") as f:
@@ -384,7 +402,7 @@ class ProviderAnthropic(Provider):
384
402
  def get_current_key(self) -> str:
385
403
  return self.chosen_api_key
386
404
 
387
- async def get_models(self) -> List[str]:
405
+ async def get_models(self) -> list[str]:
388
406
  models_str = []
389
407
  models = await self.client.models.list()
390
408
  models = sorted(models.data, key=lambda x: x.id)
@@ -1,15 +1,15 @@
1
- import uuid
2
- import time
1
+ import asyncio
2
+ import hashlib
3
3
  import json
4
4
  import re
5
- import hashlib
6
- import random
7
- import asyncio
5
+ import secrets
6
+ import time
7
+ import uuid
8
8
  from pathlib import Path
9
- from typing import Dict
10
9
  from xml.sax.saxutils import escape
11
10
 
12
11
  from httpx import AsyncClient, Timeout
12
+
13
13
  from astrbot.core.config.default import VERSION
14
14
 
15
15
  from ..entities import ProviderType
@@ -21,7 +21,7 @@ TEMP_DIR.mkdir(parents=True, exist_ok=True)
21
21
 
22
22
 
23
23
  class OTTSProvider:
24
- def __init__(self, config: Dict):
24
+ def __init__(self, config: dict):
25
25
  self.skey = config["OTTS_SKEY"]
26
26
  self.api_url = config["OTTS_URL"]
27
27
  self.auth_time_url = config["OTTS_AUTH_TIME"]
@@ -54,11 +54,13 @@ class OTTSProvider:
54
54
  async def _generate_signature(self) -> str:
55
55
  await self._sync_time()
56
56
  timestamp = int(time.time()) + self.time_offset
57
- nonce = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=10))
57
+ nonce = "".join(
58
+ secrets.choice("abcdefghijklmnopqrstuvwxyz0123456789") for _ in range(10)
59
+ )
58
60
  path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/"
59
61
  return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}"
60
62
 
61
- async def get_audio(self, text: str, voice_params: Dict) -> str:
63
+ async def get_audio(self, text: str, voice_params: dict) -> str:
62
64
  file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav"
63
65
  signature = await self._generate_signature()
64
66
  for attempt in range(self.retry_count):
@@ -86,7 +88,7 @@ class OTTSProvider:
86
88
  return str(file_path.resolve())
87
89
  except Exception as e:
88
90
  if attempt == self.retry_count - 1:
89
- raise RuntimeError(f"OTTS请求失败: {str(e)}") from e
91
+ raise RuntimeError(f"OTTS请求失败: {e!s}") from e
90
92
  await asyncio.sleep(0.5 * (attempt + 1))
91
93
 
92
94
 
@@ -94,7 +96,8 @@ class AzureNativeProvider(TTSProvider):
94
96
  def __init__(self, provider_config: dict, provider_settings: dict):
95
97
  super().__init__(provider_config, provider_settings)
96
98
  self.subscription_key = provider_config.get(
97
- "azure_tts_subscription_key", ""
99
+ "azure_tts_subscription_key",
100
+ "",
98
101
  ).strip()
99
102
  if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
100
103
  raise ValueError("无效的Azure订阅密钥")
@@ -119,7 +122,7 @@ class AzureNativeProvider(TTSProvider):
119
122
  "User-Agent": f"AstrBot/{VERSION}",
120
123
  "Content-Type": "application/ssml+xml",
121
124
  "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
122
- }
125
+ },
123
126
  )
124
127
  return self
125
128
 
@@ -132,7 +135,8 @@ class AzureNativeProvider(TTSProvider):
132
135
  f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken"
133
136
  )
134
137
  response = await self.client.post(
135
- token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key}
138
+ token_url,
139
+ headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
136
140
  )
137
141
  response.raise_for_status()
138
142
  self.token = response.text