AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,12 +1,19 @@
1
+ import os
1
2
  import uuid
2
- from openai import AsyncOpenAI, NOT_GIVEN
3
- from ..provider import TTSProvider
3
+
4
+ from openai import NOT_GIVEN, AsyncOpenAI
5
+
6
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
7
+
4
8
  from ..entities import ProviderType
9
+ from ..provider import TTSProvider
5
10
  from ..register import register_provider_adapter
6
11
 
7
12
 
8
13
  @register_provider_adapter(
9
- "openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
14
+ "openai_tts_api",
15
+ "OpenAI TTS API",
16
+ provider_type=ProviderType.TEXT_TO_SPEECH,
10
17
  )
11
18
  class ProviderOpenAITTSAPI(TTSProvider):
12
19
  def __init__(
@@ -24,16 +31,20 @@ class ProviderOpenAITTSAPI(TTSProvider):
24
31
 
25
32
  self.client = AsyncOpenAI(
26
33
  api_key=self.chosen_api_key,
27
- base_url=provider_config.get("api_base", None),
34
+ base_url=provider_config.get("api_base"),
28
35
  timeout=timeout,
29
36
  )
30
37
 
31
- self.set_model(provider_config.get("model", None))
38
+ self.set_model(provider_config.get("model", ""))
32
39
 
33
40
  async def get_audio(self, text: str) -> str:
34
- path = f"data/temp/openai_tts_api_{uuid.uuid4()}.wav"
41
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
42
+ path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
35
43
  async with self.client.audio.speech.with_streaming_response.create(
36
- model=self.model_name, voice=self.voice, response_format="wav", input=text
44
+ model=self.model_name,
45
+ voice=self.voice,
46
+ response_format="wav",
47
+ input=text,
37
48
  ) as response:
38
49
  with open(path, "wb") as f:
39
50
  async for chunk in response.iter_bytes(chunk_size=1024):
@@ -1,22 +1,24 @@
1
- """
2
- Author: diudiu62
1
+ """Author: diudiu62
3
2
  Date: 2025-02-24 18:04:18
4
3
  LastEditTime: 2025-02-25 14:06:30
5
4
  """
6
5
 
7
6
  import asyncio
8
- from datetime import datetime
9
7
  import os
10
8
  import re
9
+ from datetime import datetime
10
+
11
11
  from funasr_onnx import SenseVoiceSmall
12
12
  from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
13
- from ..provider import STTProvider
14
- from ..entities import ProviderType
15
- from astrbot.core.utils.io import download_file
16
- from ..register import register_provider_adapter
13
+
17
14
  from astrbot.core import logger
15
+ from astrbot.core.utils.io import download_file
18
16
  from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
19
17
 
18
+ from ..entities import ProviderType
19
+ from ..provider import STTProvider
20
+ from ..register import register_provider_adapter
21
+
20
22
 
21
23
  @register_provider_adapter(
22
24
  "sensevoice_stt_selfhost",
@@ -30,7 +32,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
30
32
  provider_settings: dict,
31
33
  ) -> None:
32
34
  super().__init__(provider_config, provider_settings)
33
- self.set_model(provider_config.get("stt_model", None))
35
+ self.set_model(provider_config.get("stt_model"))
34
36
  self.model = None
35
37
  self.is_emotion = provider_config.get("is_emotion", False)
36
38
 
@@ -39,7 +41,8 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
39
41
 
40
42
  # 将模型加载放到线程池中执行
41
43
  self.model = await asyncio.get_event_loop().run_in_executor(
42
- None, lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16)
44
+ None,
45
+ lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
43
46
  )
44
47
 
45
48
  logger.info("SenseVoice 模型加载完成。")
@@ -55,8 +58,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
55
58
 
56
59
  if silk_header in file_header:
57
60
  return True
58
- else:
59
- return False
61
+ return False
60
62
 
61
63
  async def get_text(self, audio_url: str) -> str:
62
64
  try:
@@ -0,0 +1,71 @@
1
+ import aiohttp
2
+
3
+ from astrbot import logger
4
+
5
+ from ..entities import ProviderType, RerankResult
6
+ from ..provider import RerankProvider
7
+ from ..register import register_provider_adapter
8
+
9
+
10
+ @register_provider_adapter(
11
+ "vllm_rerank",
12
+ "VLLM Rerank 适配器",
13
+ provider_type=ProviderType.RERANK,
14
+ )
15
+ class VLLMRerankProvider(RerankProvider):
16
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
17
+ super().__init__(provider_config, provider_settings)
18
+ self.provider_config = provider_config
19
+ self.provider_settings = provider_settings
20
+ self.auth_key = provider_config.get("rerank_api_key", "")
21
+ self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
22
+ self.base_url = self.base_url.rstrip("/")
23
+ self.timeout = provider_config.get("timeout", 20)
24
+ self.model = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
25
+
26
+ h = {}
27
+ if self.auth_key:
28
+ h["Authorization"] = f"Bearer {self.auth_key}"
29
+ self.client = aiohttp.ClientSession(
30
+ headers=h,
31
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
32
+ )
33
+
34
+ async def rerank(
35
+ self,
36
+ query: str,
37
+ documents: list[str],
38
+ top_n: int | None = None,
39
+ ) -> list[RerankResult]:
40
+ payload = {
41
+ "query": query,
42
+ "documents": documents,
43
+ "model": self.model,
44
+ }
45
+ if top_n is not None:
46
+ payload["top_n"] = top_n
47
+ async with self.client.post(
48
+ f"{self.base_url}/v1/rerank",
49
+ json=payload,
50
+ ) as response:
51
+ response_data = await response.json()
52
+ results = response_data.get("results", [])
53
+
54
+ if not results:
55
+ logger.warning(
56
+ f"Rerank API 返回了空的列表数据。原始响应: {response_data}",
57
+ )
58
+
59
+ return [
60
+ RerankResult(
61
+ index=result["index"],
62
+ relevance_score=result["relevance_score"],
63
+ )
64
+ for result in results
65
+ ]
66
+
67
+ async def terminate(self) -> None:
68
+ """关闭客户端会话"""
69
+ if self.client:
70
+ await self.client.close()
71
+ self.client = None
@@ -0,0 +1,115 @@
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ import traceback
6
+ import uuid
7
+
8
+ import aiohttp
9
+
10
+ from astrbot import logger
11
+
12
+ from ..entities import ProviderType
13
+ from ..provider import TTSProvider
14
+ from ..register import register_provider_adapter
15
+
16
+
17
+ @register_provider_adapter(
18
+ "volcengine_tts",
19
+ "火山引擎 TTS",
20
+ provider_type=ProviderType.TEXT_TO_SPEECH,
21
+ )
22
+ class ProviderVolcengineTTS(TTSProvider):
23
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
24
+ super().__init__(provider_config, provider_settings)
25
+ self.api_key = provider_config.get("api_key", "")
26
+ self.appid = provider_config.get("appid", "")
27
+ self.cluster = provider_config.get("volcengine_cluster", "")
28
+ self.voice_type = provider_config.get("volcengine_voice_type", "")
29
+ self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
30
+ self.api_base = provider_config.get(
31
+ "api_base",
32
+ "https://openspeech.bytedance.com/api/v1/tts",
33
+ )
34
+ self.timeout = provider_config.get("timeout", 20)
35
+
36
+ def _build_request_payload(self, text: str) -> dict:
37
+ return {
38
+ "app": {
39
+ "appid": self.appid,
40
+ "token": self.api_key,
41
+ "cluster": self.cluster,
42
+ },
43
+ "user": {"uid": str(uuid.uuid4())},
44
+ "audio": {
45
+ "voice_type": self.voice_type,
46
+ "encoding": "mp3",
47
+ "speed_ratio": self.speed_ratio,
48
+ "volume_ratio": 1.0,
49
+ "pitch_ratio": 1.0,
50
+ },
51
+ "request": {
52
+ "reqid": str(uuid.uuid4()),
53
+ "text": text,
54
+ "text_type": "plain",
55
+ "operation": "query",
56
+ "with_frontend": 1,
57
+ "frontend_type": "unitTson",
58
+ },
59
+ }
60
+
61
+ async def get_audio(self, text: str) -> str:
62
+ """异步方法获取语音文件路径"""
63
+ headers = {
64
+ "Content-Type": "application/json",
65
+ "Authorization": f"Bearer; {self.api_key}",
66
+ }
67
+
68
+ payload = self._build_request_payload(text)
69
+
70
+ logger.debug(f"请求头: {headers}")
71
+ logger.debug(f"请求 URL: {self.api_base}")
72
+ logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
73
+
74
+ try:
75
+ async with (
76
+ aiohttp.ClientSession() as session,
77
+ session.post(
78
+ self.api_base,
79
+ data=json.dumps(payload),
80
+ headers=headers,
81
+ timeout=self.timeout,
82
+ ) as response,
83
+ ):
84
+ logger.debug(f"响应状态码: {response.status}")
85
+
86
+ response_text = await response.text()
87
+ logger.debug(f"响应内容: {response_text[:200]}...")
88
+
89
+ if response.status == 200:
90
+ resp_data = json.loads(response_text)
91
+
92
+ if "data" in resp_data:
93
+ audio_data = base64.b64decode(resp_data["data"])
94
+
95
+ os.makedirs("data/temp", exist_ok=True)
96
+
97
+ file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
98
+
99
+ loop = asyncio.get_running_loop()
100
+ await loop.run_in_executor(
101
+ None,
102
+ lambda: open(file_path, "wb").write(audio_data),
103
+ )
104
+
105
+ return file_path
106
+ error_msg = resp_data.get("message", "未知错误")
107
+ raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
108
+ raise Exception(
109
+ f"火山引擎 TTS API 请求失败: {response.status}, {response_text}",
110
+ )
111
+
112
+ except Exception as e:
113
+ error_details = traceback.format_exc()
114
+ logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
115
+ raise Exception(f"火山引擎 TTS 异常: {e!s}")
@@ -1,13 +1,17 @@
1
- import uuid
2
1
  import os
3
- from openai import AsyncOpenAI, NOT_GIVEN
4
- from ..provider import STTProvider
5
- from ..entities import ProviderType
6
- from astrbot.core.utils.io import download_file
7
- from ..register import register_provider_adapter
2
+ import uuid
3
+
4
+ from openai import NOT_GIVEN, AsyncOpenAI
5
+
8
6
  from astrbot.core import logger
7
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
8
+ from astrbot.core.utils.io import download_file
9
9
  from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
10
10
 
11
+ from ..entities import ProviderType
12
+ from ..provider import STTProvider
13
+ from ..register import register_provider_adapter
14
+
11
15
 
12
16
  @register_provider_adapter(
13
17
  "openai_whisper_api",
@@ -25,11 +29,11 @@ class ProviderOpenAIWhisperAPI(STTProvider):
25
29
 
26
30
  self.client = AsyncOpenAI(
27
31
  api_key=self.chosen_api_key,
28
- base_url=provider_config.get("api_base", None),
32
+ base_url=provider_config.get("api_base"),
29
33
  timeout=provider_config.get("timeout", NOT_GIVEN),
30
34
  )
31
35
 
32
- self.set_model(provider_config.get("model", None))
36
+ self.set_model(provider_config.get("model"))
33
37
 
34
38
  async def _is_silk_file(self, file_path):
35
39
  silk_header = b"SILK"
@@ -38,11 +42,10 @@ class ProviderOpenAIWhisperAPI(STTProvider):
38
42
 
39
43
  if silk_header in file_header:
40
44
  return True
41
- else:
42
- return False
45
+ return False
43
46
 
44
47
  async def get_text(self, audio_url: str) -> str:
45
- """only supports mp3, mp4, mpeg, m4a, wav, webm"""
48
+ """Only supports mp3, mp4, mpeg, m4a, wav, webm"""
46
49
  is_tencent = False
47
50
 
48
51
  if audio_url.startswith("http"):
@@ -50,7 +53,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
50
53
  is_tencent = True
51
54
 
52
55
  name = str(uuid.uuid4())
53
- path = os.path.join("data/temp", name)
56
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
57
+ path = os.path.join(temp_dir, name)
54
58
  await download_file(audio_url, path)
55
59
  audio_url = path
56
60
 
@@ -61,7 +65,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
61
65
  is_silk = await self._is_silk_file(audio_url)
62
66
  if is_silk:
63
67
  logger.info("Converting silk file to wav ...")
64
- output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
68
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
69
+ output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
65
70
  await tencent_silk_to_wav(audio_url, output_path)
66
71
  audio_url = output_path
67
72
 
@@ -1,14 +1,18 @@
1
- import uuid
2
- import os
3
1
  import asyncio
2
+ import os
3
+ import uuid
4
+
4
5
  import whisper
5
- from ..provider import STTProvider
6
- from ..entities import ProviderType
7
- from astrbot.core.utils.io import download_file
8
- from ..register import register_provider_adapter
6
+
9
7
  from astrbot.core import logger
8
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
9
+ from astrbot.core.utils.io import download_file
10
10
  from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
11
11
 
12
+ from ..entities import ProviderType
13
+ from ..provider import STTProvider
14
+ from ..register import register_provider_adapter
15
+
12
16
 
13
17
  @register_provider_adapter(
14
18
  "openai_whisper_selfhost",
@@ -22,14 +26,16 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
22
26
  provider_settings: dict,
23
27
  ) -> None:
24
28
  super().__init__(provider_config, provider_settings)
25
- self.set_model(provider_config.get("model", None))
29
+ self.set_model(provider_config.get("model"))
26
30
  self.model = None
27
31
 
28
32
  async def initialize(self):
29
33
  loop = asyncio.get_event_loop()
30
34
  logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
31
35
  self.model = await loop.run_in_executor(
32
- None, whisper.load_model, self.model_name
36
+ None,
37
+ whisper.load_model,
38
+ self.model_name,
33
39
  )
34
40
  logger.info("Whisper 模型加载完成。")
35
41
 
@@ -40,8 +46,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
40
46
 
41
47
  if silk_header in file_header:
42
48
  return True
43
- else:
44
- return False
49
+ return False
45
50
 
46
51
  async def get_text(self, audio_url: str) -> str:
47
52
  loop = asyncio.get_event_loop()
@@ -53,7 +58,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
53
58
  is_tencent = True
54
59
 
55
60
  name = str(uuid.uuid4())
56
- path = os.path.join("data/temp", name)
61
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
62
+ path = os.path.join(temp_dir, name)
57
63
  await download_file(audio_url, path)
58
64
  audio_url = path
59
65
 
@@ -64,7 +70,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
64
70
  is_silk = await self._is_silk_file(audio_url)
65
71
  if is_silk:
66
72
  logger.info("Converting silk file to wav ...")
67
- output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav")
73
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
74
+ output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
68
75
  await tencent_silk_to_wav(audio_url, output_path)
69
76
  audio_url = output_path
70
77
 
@@ -0,0 +1,116 @@
1
+ from xinference_client.client.restful.async_restful_client import (
2
+ AsyncClient as Client,
3
+ )
4
+
5
+ from astrbot import logger
6
+
7
+ from ..entities import ProviderType, RerankResult
8
+ from ..provider import RerankProvider
9
+ from ..register import register_provider_adapter
10
+
11
+
12
+ @register_provider_adapter(
13
+ "xinference_rerank",
14
+ "Xinference Rerank 适配器",
15
+ provider_type=ProviderType.RERANK,
16
+ )
17
+ class XinferenceRerankProvider(RerankProvider):
18
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
19
+ super().__init__(provider_config, provider_settings)
20
+ self.provider_config = provider_config
21
+ self.provider_settings = provider_settings
22
+ self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
23
+ self.base_url = self.base_url.rstrip("/")
24
+ self.timeout = provider_config.get("timeout", 20)
25
+ self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
26
+ self.api_key = provider_config.get("rerank_api_key")
27
+ self.launch_model_if_not_running = provider_config.get(
28
+ "launch_model_if_not_running",
29
+ False,
30
+ )
31
+ self.client = None
32
+ self.model = None
33
+ self.model_uid = None
34
+
35
+ async def initialize(self):
36
+ if self.api_key:
37
+ logger.info("Xinference Rerank: Using API key for authentication.")
38
+ self.client = Client(self.base_url, api_key=self.api_key)
39
+ else:
40
+ logger.info("Xinference Rerank: No API key provided.")
41
+ self.client = Client(self.base_url)
42
+
43
+ try:
44
+ running_models = await self.client.list_models()
45
+ for uid, model_spec in running_models.items():
46
+ if model_spec.get("model_name") == self.model_name:
47
+ logger.info(
48
+ f"Model '{self.model_name}' is already running with UID: {uid}",
49
+ )
50
+ self.model_uid = uid
51
+ break
52
+
53
+ if self.model_uid is None:
54
+ if self.launch_model_if_not_running:
55
+ logger.info(f"Launching {self.model_name} model...")
56
+ self.model_uid = await self.client.launch_model(
57
+ model_name=self.model_name,
58
+ model_type="rerank",
59
+ )
60
+ logger.info("Model launched.")
61
+ else:
62
+ logger.warning(
63
+ f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available.",
64
+ )
65
+ return
66
+
67
+ if self.model_uid:
68
+ self.model = await self.client.get_model(self.model_uid)
69
+
70
+ except Exception as e:
71
+ logger.error(f"Failed to initialize Xinference model: {e}")
72
+ logger.debug(
73
+ f"Xinference initialization failed with exception: {e}",
74
+ exc_info=True,
75
+ )
76
+ self.model = None
77
+
78
+ async def rerank(
79
+ self,
80
+ query: str,
81
+ documents: list[str],
82
+ top_n: int | None = None,
83
+ ) -> list[RerankResult]:
84
+ if not self.model:
85
+ logger.error("Xinference rerank model is not initialized.")
86
+ return []
87
+ try:
88
+ response = await self.model.rerank(documents, query, top_n)
89
+ results = response.get("results", [])
90
+ logger.debug(f"Rerank API response: {response}")
91
+
92
+ if not results:
93
+ logger.warning(
94
+ f"Rerank API returned an empty list. Original response: {response}",
95
+ )
96
+
97
+ return [
98
+ RerankResult(
99
+ index=result["index"],
100
+ relevance_score=result["relevance_score"],
101
+ )
102
+ for result in results
103
+ ]
104
+ except Exception as e:
105
+ logger.error(f"Xinference rerank failed: {e}")
106
+ logger.debug(f"Xinference rerank failed with exception: {e}", exc_info=True)
107
+ return []
108
+
109
+ async def terminate(self) -> None:
110
+ """关闭客户端会话"""
111
+ if self.client:
112
+ logger.info("Closing Xinference rerank client...")
113
+ try:
114
+ await self.client.close()
115
+ except Exception as e:
116
+ logger.error(f"Failed to close Xinference client: {e}", exc_info=True)