AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,161 +1,170 @@
1
1
  import abc
2
- from typing import List
3
- from astrbot.core.db import BaseDatabase
4
- from typing import TypedDict, AsyncGenerator
5
- from astrbot.core.provider.func_tool_manager import FuncCall
6
- from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
7
- from dataclasses import dataclass
2
+ import asyncio
3
+ from collections.abc import AsyncGenerator
8
4
 
9
-
10
- class Personality(TypedDict):
11
- prompt: str = ""
12
- name: str = ""
13
- begin_dialogs: List[str] = []
14
- mood_imitation_dialogs: List[str] = []
15
-
16
- # cache
17
- _begin_dialogs_processed: List[dict] = []
18
- _mood_imitation_dialogs_processed: str = ""
19
-
20
-
21
- @dataclass
22
- class ProviderMeta:
23
- id: str
24
- model: str
25
- type: str
5
+ from astrbot.core.agent.message import Message
6
+ from astrbot.core.agent.tool import ToolSet
7
+ from astrbot.core.provider.entities import (
8
+ LLMResponse,
9
+ ProviderMeta,
10
+ RerankResult,
11
+ ToolCallsResult,
12
+ )
13
+ from astrbot.core.provider.register import provider_cls_map
26
14
 
27
15
 
28
16
  class AbstractProvider(abc.ABC):
17
+ """Provider Abstract Class"""
18
+
29
19
  def __init__(self, provider_config: dict) -> None:
30
20
  super().__init__()
31
21
  self.model_name = ""
32
22
  self.provider_config = provider_config
33
23
 
34
24
  def set_model(self, model_name: str):
35
- """设置当前使用的模型名称"""
25
+ """Set the current model name"""
36
26
  self.model_name = model_name
37
27
 
38
28
  def get_model(self) -> str:
39
- """获得当前使用的模型名称"""
29
+ """Get the current model name"""
40
30
  return self.model_name
41
31
 
42
32
  def meta(self) -> ProviderMeta:
43
- """获取 Provider 的元数据"""
44
- return ProviderMeta(
45
- id=self.provider_config["id"],
33
+ """Get the provider metadata"""
34
+ provider_type_name = self.provider_config["type"]
35
+ meta_data = provider_cls_map.get(provider_type_name)
36
+ if not meta_data:
37
+ raise ValueError(f"Provider type {provider_type_name} not registered")
38
+ meta = ProviderMeta(
39
+ id=self.provider_config.get("id", "default"),
46
40
  model=self.get_model(),
47
- type=self.provider_config["type"],
41
+ type=provider_type_name,
42
+ provider_type=meta_data.provider_type,
48
43
  )
44
+ return meta
49
45
 
50
46
 
51
47
  class Provider(AbstractProvider):
48
+ """Chat Provider"""
49
+
52
50
  def __init__(
53
51
  self,
54
52
  provider_config: dict,
55
53
  provider_settings: dict,
56
- persistant_history: bool = True,
57
- db_helper: BaseDatabase = None,
58
- default_persona: Personality = None,
59
54
  ) -> None:
60
55
  super().__init__(provider_config)
61
-
62
56
  self.provider_settings = provider_settings
63
57
 
64
- self.curr_personality: Personality = default_persona
65
- """维护了当前的使用的 persona,即人格。可能为 None"""
66
-
67
58
  @abc.abstractmethod
68
59
  def get_current_key(self) -> str:
69
- raise NotImplementedError()
60
+ raise NotImplementedError
70
61
 
71
- def get_keys(self) -> List[str]:
62
+ def get_keys(self) -> list[str]:
72
63
  """获得提供商 Key"""
73
- return self.provider_config.get("key", [])
64
+ keys = self.provider_config.get("key", [""])
65
+ return keys or [""]
74
66
 
75
67
  @abc.abstractmethod
76
68
  def set_key(self, key: str):
77
- raise NotImplementedError()
69
+ raise NotImplementedError
78
70
 
79
71
  @abc.abstractmethod
80
- def get_models(self) -> List[str]:
72
+ async def get_models(self) -> list[str]:
81
73
  """获得支持的模型列表"""
82
- raise NotImplementedError()
74
+ raise NotImplementedError
83
75
 
84
76
  @abc.abstractmethod
85
77
  async def text_chat(
86
78
  self,
87
- prompt: str,
88
- session_id: str = None,
89
- image_urls: List[str] = None,
90
- func_tool: FuncCall = None,
91
- contexts: List = None,
92
- system_prompt: str = None,
93
- tool_calls_result: ToolCallsResult = None,
79
+ prompt: str | None = None,
80
+ session_id: str | None = None,
81
+ image_urls: list[str] | None = None,
82
+ func_tool: ToolSet | None = None,
83
+ contexts: list[Message] | list[dict] | None = None,
84
+ system_prompt: str | None = None,
85
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
86
+ model: str | None = None,
94
87
  **kwargs,
95
88
  ) -> LLMResponse:
96
89
  """获得 LLM 的文本对话结果。会使用当前的模型进行对话。
97
90
 
98
91
  Args:
99
- prompt: 提示词
92
+ prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
100
93
  session_id: 会话 ID(此属性已经被废弃)
101
94
  image_urls: 图片 URL 列表
102
- tools: Function-calling 工具
103
- contexts: 上下文
95
+ tools: tool set
96
+ contexts: 上下文,和 prompt 二选一使用
104
97
  tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
105
98
  kwargs: 其他参数
106
99
 
107
100
  Notes:
108
101
  - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
109
102
  - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
103
+
110
104
  """
111
105
  ...
112
106
 
113
107
  async def text_chat_stream(
114
108
  self,
115
- prompt: str,
116
- session_id: str = None,
117
- image_urls: List[str] = None,
118
- func_tool: FuncCall = None,
119
- contexts: List = None,
120
- system_prompt: str = None,
121
- tool_calls_result: ToolCallsResult = None,
109
+ prompt: str | None = None,
110
+ session_id: str | None = None,
111
+ image_urls: list[str] | None = None,
112
+ func_tool: ToolSet | None = None,
113
+ contexts: list[Message] | list[dict] | None = None,
114
+ system_prompt: str | None = None,
115
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
116
+ model: str | None = None,
122
117
  **kwargs,
123
118
  ) -> AsyncGenerator[LLMResponse, None]:
124
119
  """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
125
120
 
126
121
  Args:
127
- prompt: 提示词
122
+ prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中
128
123
  session_id: 会话 ID(此属性已经被废弃)
129
124
  image_urls: 图片 URL 列表
130
- tools: Function-calling 工具
131
- contexts: 上下文
125
+ tools: tool set
126
+ contexts: 上下文,和 prompt 二选一使用
132
127
  tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
133
128
  kwargs: 其他参数
134
129
 
135
130
  Notes:
136
131
  - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
137
132
  - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
133
+
138
134
  """
139
135
  ...
140
136
 
141
- async def pop_record(self, context: List):
142
- """
143
- 弹出 context 第一条非系统提示词对话记录
144
- """
137
+ async def pop_record(self, context: list):
138
+ """弹出 context 第一条非系统提示词对话记录"""
145
139
  poped = 0
146
140
  indexs_to_pop = []
147
141
  for idx, record in enumerate(context):
148
142
  if record["role"] == "system":
149
143
  continue
150
- else:
151
- indexs_to_pop.append(idx)
152
- poped += 1
153
- if poped == 2:
154
- break
144
+ indexs_to_pop.append(idx)
145
+ poped += 1
146
+ if poped == 2:
147
+ break
155
148
 
156
149
  for idx in reversed(indexs_to_pop):
157
150
  context.pop(idx)
158
151
 
152
+ def _ensure_message_to_dicts(
153
+ self,
154
+ messages: list[dict] | list[Message] | None,
155
+ ) -> list[dict]:
156
+ """Convert a list of Message objects to a list of dictionaries."""
157
+ if not messages:
158
+ return []
159
+ dicts: list[dict] = []
160
+ for message in messages:
161
+ if isinstance(message, Message):
162
+ dicts.append(message.model_dump())
163
+ else:
164
+ dicts.append(message)
165
+
166
+ return dicts
167
+
159
168
 
160
169
  class STTProvider(AbstractProvider):
161
170
  def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -166,7 +175,7 @@ class STTProvider(AbstractProvider):
166
175
  @abc.abstractmethod
167
176
  async def get_text(self, audio_url: str) -> str:
168
177
  """获取音频的文本"""
169
- raise NotImplementedError()
178
+ raise NotImplementedError
170
179
 
171
180
 
172
181
  class TTSProvider(AbstractProvider):
@@ -178,4 +187,110 @@ class TTSProvider(AbstractProvider):
178
187
  @abc.abstractmethod
179
188
  async def get_audio(self, text: str) -> str:
180
189
  """获取文本的音频,返回音频文件路径"""
181
- raise NotImplementedError()
190
+ raise NotImplementedError
191
+
192
+
193
+ class EmbeddingProvider(AbstractProvider):
194
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
195
+ super().__init__(provider_config)
196
+ self.provider_config = provider_config
197
+ self.provider_settings = provider_settings
198
+
199
+ @abc.abstractmethod
200
+ async def get_embedding(self, text: str) -> list[float]:
201
+ """获取文本的向量"""
202
+ ...
203
+
204
+ @abc.abstractmethod
205
+ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
206
+ """批量获取文本的向量"""
207
+ ...
208
+
209
+ @abc.abstractmethod
210
+ def get_dim(self) -> int:
211
+ """获取向量的维度"""
212
+ ...
213
+
214
+ async def get_embeddings_batch(
215
+ self,
216
+ texts: list[str],
217
+ batch_size: int = 16,
218
+ tasks_limit: int = 3,
219
+ max_retries: int = 3,
220
+ progress_callback=None,
221
+ ) -> list[list[float]]:
222
+ """批量获取文本的向量,分批处理以节省内存
223
+
224
+ Args:
225
+ texts: 文本列表
226
+ batch_size: 每批处理的文本数量
227
+ tasks_limit: 并发任务数量限制
228
+ max_retries: 失败时的最大重试次数
229
+ progress_callback: 进度回调函数,接收参数 (current, total)
230
+
231
+ Returns:
232
+ 向量列表
233
+
234
+ """
235
+ semaphore = asyncio.Semaphore(tasks_limit)
236
+ all_embeddings: list[list[float]] = []
237
+ failed_batches: list[tuple[int, list[str]]] = []
238
+ completed_count = 0
239
+ total_count = len(texts)
240
+
241
+ async def process_batch(batch_idx: int, batch_texts: list[str]):
242
+ nonlocal completed_count
243
+ async with semaphore:
244
+ for attempt in range(max_retries):
245
+ try:
246
+ batch_embeddings = await self.get_embeddings(batch_texts)
247
+ all_embeddings.extend(batch_embeddings)
248
+ completed_count += len(batch_texts)
249
+ if progress_callback:
250
+ await progress_callback(completed_count, total_count)
251
+ return
252
+ except Exception as e:
253
+ if attempt == max_retries - 1:
254
+ # 最后一次重试失败,记录失败的批次
255
+ failed_batches.append((batch_idx, batch_texts))
256
+ raise Exception(
257
+ f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}",
258
+ )
259
+ # 等待一段时间后重试,使用指数退避
260
+ await asyncio.sleep(2**attempt)
261
+
262
+ tasks = []
263
+ for i in range(0, len(texts), batch_size):
264
+ batch_texts = texts[i : i + batch_size]
265
+ batch_idx = i // batch_size
266
+ tasks.append(process_batch(batch_idx, batch_texts))
267
+
268
+ # 收集所有任务的结果,包括失败的任务
269
+ results = await asyncio.gather(*tasks, return_exceptions=True)
270
+
271
+ # 检查是否有失败的任务
272
+ errors = [r for r in results if isinstance(r, Exception)]
273
+ if errors:
274
+ error_msg = (
275
+ f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
276
+ )
277
+ raise Exception(error_msg)
278
+
279
+ return all_embeddings
280
+
281
+
282
+ class RerankProvider(AbstractProvider):
283
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
284
+ super().__init__(provider_config)
285
+ self.provider_config = provider_config
286
+ self.provider_settings = provider_settings
287
+
288
+ @abc.abstractmethod
289
+ async def rerank(
290
+ self,
291
+ query: str,
292
+ documents: list[str],
293
+ top_n: int | None = None,
294
+ ) -> list[RerankResult]:
295
+ """获取查询和文档的重排序分数"""
296
+ ...
@@ -1,11 +1,11 @@
1
- from typing import List, Dict
2
- from .entities import ProviderMetaData, ProviderType
3
1
  from astrbot.core import logger
2
+
3
+ from .entities import ProviderMetaData, ProviderType
4
4
  from .func_tool_manager import FuncCall
5
5
 
6
- provider_registry: List[ProviderMetaData] = []
6
+ provider_registry: list[ProviderMetaData] = []
7
7
  """维护了通过装饰器注册的 Provider"""
8
- provider_cls_map: Dict[str, ProviderMetaData] = {}
8
+ provider_cls_map: dict[str, ProviderMetaData] = {}
9
9
  """维护了 Provider 类型名称和 ProviderMetadata 的映射"""
10
10
 
11
11
  llm_tools = FuncCall()
@@ -15,15 +15,15 @@ def register_provider_adapter(
15
15
  provider_type_name: str,
16
16
  desc: str,
17
17
  provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
18
- default_config_tmpl: dict = None,
19
- provider_display_name: str = None,
18
+ default_config_tmpl: dict | None = None,
19
+ provider_display_name: str | None = None,
20
20
  ):
21
21
  """用于注册平台适配器的带参装饰器"""
22
22
 
23
23
  def decorator(cls):
24
24
  if provider_type_name in provider_cls_map:
25
25
  raise ValueError(
26
- f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。"
26
+ f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。",
27
27
  )
28
28
 
29
29
  # 添加必备选项
@@ -36,6 +36,8 @@ def register_provider_adapter(
36
36
  default_config_tmpl["id"] = provider_type_name
37
37
 
38
38
  pm = ProviderMetaData(
39
+ id="default", # will be replaced when instantiated
40
+ model=None,
39
41
  type=provider_type_name,
40
42
  desc=desc,
41
43
  provider_type=provider_type,