AstrBot 4.5.0__py3-none-any.whl → 4.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. astrbot/api/__init__.py +10 -11
  2. astrbot/api/event/__init__.py +5 -6
  3. astrbot/api/event/filter/__init__.py +37 -36
  4. astrbot/api/platform/__init__.py +7 -8
  5. astrbot/api/provider/__init__.py +7 -7
  6. astrbot/api/star/__init__.py +3 -4
  7. astrbot/api/util/__init__.py +2 -2
  8. astrbot/cli/__main__.py +5 -5
  9. astrbot/cli/commands/__init__.py +3 -3
  10. astrbot/cli/commands/cmd_conf.py +19 -16
  11. astrbot/cli/commands/cmd_init.py +3 -2
  12. astrbot/cli/commands/cmd_plug.py +8 -10
  13. astrbot/cli/commands/cmd_run.py +5 -6
  14. astrbot/cli/utils/__init__.py +6 -6
  15. astrbot/cli/utils/basic.py +14 -14
  16. astrbot/cli/utils/plugin.py +24 -15
  17. astrbot/cli/utils/version_comparator.py +10 -12
  18. astrbot/core/__init__.py +8 -6
  19. astrbot/core/agent/agent.py +3 -2
  20. astrbot/core/agent/handoff.py +6 -2
  21. astrbot/core/agent/hooks.py +9 -6
  22. astrbot/core/agent/mcp_client.py +50 -15
  23. astrbot/core/agent/message.py +168 -0
  24. astrbot/core/agent/response.py +2 -1
  25. astrbot/core/agent/run_context.py +2 -3
  26. astrbot/core/agent/runners/base.py +10 -13
  27. astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
  28. astrbot/core/agent/tool.py +60 -41
  29. astrbot/core/agent/tool_executor.py +9 -3
  30. astrbot/core/astr_agent_context.py +3 -1
  31. astrbot/core/astrbot_config_mgr.py +29 -9
  32. astrbot/core/config/__init__.py +2 -2
  33. astrbot/core/config/astrbot_config.py +28 -26
  34. astrbot/core/config/default.py +44 -6
  35. astrbot/core/conversation_mgr.py +105 -36
  36. astrbot/core/core_lifecycle.py +68 -54
  37. astrbot/core/db/__init__.py +33 -18
  38. astrbot/core/db/migration/helper.py +18 -13
  39. astrbot/core/db/migration/migra_3_to_4.py +53 -34
  40. astrbot/core/db/migration/migra_45_to_46.py +1 -1
  41. astrbot/core/db/migration/shared_preferences_v3.py +2 -1
  42. astrbot/core/db/migration/sqlite_v3.py +26 -23
  43. astrbot/core/db/po.py +27 -18
  44. astrbot/core/db/sqlite.py +74 -45
  45. astrbot/core/db/vec_db/base.py +10 -14
  46. astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
  47. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
  48. astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
  49. astrbot/core/event_bus.py +8 -6
  50. astrbot/core/file_token_service.py +6 -5
  51. astrbot/core/initial_loader.py +7 -5
  52. astrbot/core/knowledge_base/chunking/__init__.py +1 -3
  53. astrbot/core/knowledge_base/chunking/base.py +1 -0
  54. astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
  55. astrbot/core/knowledge_base/chunking/recursive.py +16 -10
  56. astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
  57. astrbot/core/knowledge_base/kb_helper.py +30 -17
  58. astrbot/core/knowledge_base/kb_mgr.py +6 -7
  59. astrbot/core/knowledge_base/models.py +10 -4
  60. astrbot/core/knowledge_base/parsers/__init__.py +3 -5
  61. astrbot/core/knowledge_base/parsers/base.py +1 -0
  62. astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
  63. astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
  64. astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
  65. astrbot/core/knowledge_base/parsers/util.py +1 -1
  66. astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
  67. astrbot/core/knowledge_base/retrieval/manager.py +17 -14
  68. astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
  69. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
  70. astrbot/core/log.py +21 -13
  71. astrbot/core/message/components.py +123 -217
  72. astrbot/core/message/message_event_result.py +24 -24
  73. astrbot/core/persona_mgr.py +20 -11
  74. astrbot/core/pipeline/__init__.py +7 -7
  75. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  76. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  77. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
  78. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
  79. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  80. astrbot/core/pipeline/context.py +4 -1
  81. astrbot/core/pipeline/context_utils.py +77 -7
  82. astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
  83. astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
  84. astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
  85. astrbot/core/pipeline/process_stage/stage.py +13 -10
  86. astrbot/core/pipeline/process_stage/utils.py +6 -5
  87. astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
  88. astrbot/core/pipeline/respond/stage.py +23 -20
  89. astrbot/core/pipeline/result_decorate/stage.py +31 -23
  90. astrbot/core/pipeline/scheduler.py +12 -8
  91. astrbot/core/pipeline/session_status_check/stage.py +12 -8
  92. astrbot/core/pipeline/stage.py +10 -4
  93. astrbot/core/pipeline/waking_check/stage.py +24 -18
  94. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  95. astrbot/core/platform/__init__.py +6 -6
  96. astrbot/core/platform/astr_message_event.py +76 -110
  97. astrbot/core/platform/astrbot_message.py +11 -13
  98. astrbot/core/platform/manager.py +16 -15
  99. astrbot/core/platform/message_session.py +5 -3
  100. astrbot/core/platform/platform.py +16 -24
  101. astrbot/core/platform/platform_metadata.py +4 -4
  102. astrbot/core/platform/register.py +8 -8
  103. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
  104. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
  105. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +47 -29
  106. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
  107. astrbot/core/platform/sources/discord/client.py +9 -6
  108. astrbot/core/platform/sources/discord/components.py +18 -14
  109. astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
  110. astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
  111. astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
  112. astrbot/core/platform/sources/lark/lark_event.py +21 -14
  113. astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
  114. astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
  115. astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
  116. astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
  117. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
  118. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  119. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  120. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  121. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
  122. astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
  123. astrbot/core/platform/sources/satori/satori_event.py +34 -25
  124. astrbot/core/platform/sources/slack/client.py +11 -9
  125. astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
  126. astrbot/core/platform/sources/slack/slack_event.py +34 -24
  127. astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
  128. astrbot/core/platform/sources/telegram/tg_event.py +32 -18
  129. astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
  130. astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
  131. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
  132. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
  133. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
  134. astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
  135. astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
  136. astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
  137. astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
  138. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
  139. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
  140. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
  141. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
  142. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
  143. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
  144. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
  145. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
  146. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
  147. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
  148. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
  149. astrbot/core/platform_message_history_mgr.py +5 -3
  150. astrbot/core/provider/__init__.py +2 -3
  151. astrbot/core/provider/entites.py +8 -8
  152. astrbot/core/provider/entities.py +61 -75
  153. astrbot/core/provider/func_tool_manager.py +59 -55
  154. astrbot/core/provider/manager.py +40 -22
  155. astrbot/core/provider/provider.py +72 -46
  156. astrbot/core/provider/register.py +7 -7
  157. astrbot/core/provider/sources/anthropic_source.py +48 -30
  158. astrbot/core/provider/sources/azure_tts_source.py +17 -13
  159. astrbot/core/provider/sources/coze_api_client.py +27 -17
  160. astrbot/core/provider/sources/coze_source.py +104 -87
  161. astrbot/core/provider/sources/dashscope_source.py +18 -11
  162. astrbot/core/provider/sources/dashscope_tts.py +36 -23
  163. astrbot/core/provider/sources/dify_source.py +25 -20
  164. astrbot/core/provider/sources/edge_tts_source.py +21 -17
  165. astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
  166. astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
  167. astrbot/core/provider/sources/gemini_source.py +72 -58
  168. astrbot/core/provider/sources/gemini_tts_source.py +8 -6
  169. astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
  170. astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
  171. astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
  172. astrbot/core/provider/sources/openai_embedding_source.py +6 -8
  173. astrbot/core/provider/sources/openai_source.py +102 -69
  174. astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
  175. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  176. astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
  177. astrbot/core/provider/sources/volcengine_tts.py +38 -31
  178. astrbot/core/provider/sources/whisper_api_source.py +14 -12
  179. astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
  180. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  181. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  182. astrbot/core/star/__init__.py +16 -11
  183. astrbot/core/star/config.py +10 -15
  184. astrbot/core/star/context.py +109 -84
  185. astrbot/core/star/filter/__init__.py +4 -3
  186. astrbot/core/star/filter/command.py +30 -28
  187. astrbot/core/star/filter/command_group.py +27 -24
  188. astrbot/core/star/filter/custom_filter.py +6 -5
  189. astrbot/core/star/filter/event_message_type.py +4 -2
  190. astrbot/core/star/filter/permission.py +4 -2
  191. astrbot/core/star/filter/platform_adapter_type.py +4 -2
  192. astrbot/core/star/filter/regex.py +4 -2
  193. astrbot/core/star/register/__init__.py +19 -19
  194. astrbot/core/star/register/star.py +6 -2
  195. astrbot/core/star/register/star_handler.py +96 -73
  196. astrbot/core/star/session_llm_manager.py +48 -14
  197. astrbot/core/star/session_plugin_manager.py +29 -15
  198. astrbot/core/star/star.py +1 -2
  199. astrbot/core/star/star_handler.py +13 -8
  200. astrbot/core/star/star_manager.py +151 -59
  201. astrbot/core/star/star_tools.py +44 -37
  202. astrbot/core/star/updator.py +10 -10
  203. astrbot/core/umop_config_router.py +10 -4
  204. astrbot/core/updator.py +13 -5
  205. astrbot/core/utils/astrbot_path.py +3 -5
  206. astrbot/core/utils/dify_api_client.py +33 -15
  207. astrbot/core/utils/io.py +66 -42
  208. astrbot/core/utils/log_pipe.py +1 -1
  209. astrbot/core/utils/metrics.py +7 -7
  210. astrbot/core/utils/path_util.py +15 -16
  211. astrbot/core/utils/pip_installer.py +5 -5
  212. astrbot/core/utils/session_waiter.py +19 -20
  213. astrbot/core/utils/shared_preferences.py +45 -20
  214. astrbot/core/utils/t2i/__init__.py +4 -1
  215. astrbot/core/utils/t2i/network_strategy.py +35 -26
  216. astrbot/core/utils/t2i/renderer.py +11 -5
  217. astrbot/core/utils/t2i/template_manager.py +14 -15
  218. astrbot/core/utils/tencent_record_helper.py +19 -13
  219. astrbot/core/utils/version_comparator.py +10 -13
  220. astrbot/core/zip_updator.py +43 -40
  221. astrbot/dashboard/routes/__init__.py +18 -18
  222. astrbot/dashboard/routes/auth.py +10 -8
  223. astrbot/dashboard/routes/chat.py +30 -21
  224. astrbot/dashboard/routes/config.py +92 -75
  225. astrbot/dashboard/routes/conversation.py +46 -39
  226. astrbot/dashboard/routes/file.py +4 -2
  227. astrbot/dashboard/routes/knowledge_base.py +47 -40
  228. astrbot/dashboard/routes/log.py +9 -4
  229. astrbot/dashboard/routes/persona.py +19 -16
  230. astrbot/dashboard/routes/plugin.py +69 -55
  231. astrbot/dashboard/routes/route.py +3 -1
  232. astrbot/dashboard/routes/session_management.py +130 -116
  233. astrbot/dashboard/routes/stat.py +34 -34
  234. astrbot/dashboard/routes/t2i.py +15 -12
  235. astrbot/dashboard/routes/tools.py +47 -52
  236. astrbot/dashboard/routes/update.py +32 -28
  237. astrbot/dashboard/server.py +30 -26
  238. astrbot/dashboard/utils.py +8 -4
  239. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/METADATA +4 -2
  240. astrbot-4.5.2.dist-info/RECORD +261 -0
  241. astrbot-4.5.0.dist-info/RECORD +0 -258
  242. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
  243. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
  244. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,22 +1,24 @@
1
- """
2
- Author: diudiu62
1
+ """Author: diudiu62
3
2
  Date: 2025-02-24 18:04:18
4
3
  LastEditTime: 2025-02-25 14:06:30
5
4
  """
6
5
 
7
6
  import asyncio
8
- from datetime import datetime
9
7
  import os
10
8
  import re
9
+ from datetime import datetime
10
+
11
11
  from funasr_onnx import SenseVoiceSmall
12
12
  from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
13
- from ..provider import STTProvider
14
- from ..entities import ProviderType
15
- from astrbot.core.utils.io import download_file
16
- from ..register import register_provider_adapter
13
+
17
14
  from astrbot.core import logger
15
+ from astrbot.core.utils.io import download_file
18
16
  from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
19
17
 
18
+ from ..entities import ProviderType
19
+ from ..provider import STTProvider
20
+ from ..register import register_provider_adapter
21
+
20
22
 
21
23
  @register_provider_adapter(
22
24
  "sensevoice_stt_selfhost",
@@ -30,7 +32,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
30
32
  provider_settings: dict,
31
33
  ) -> None:
32
34
  super().__init__(provider_config, provider_settings)
33
- self.set_model(provider_config.get("stt_model", None))
35
+ self.set_model(provider_config.get("stt_model"))
34
36
  self.model = None
35
37
  self.is_emotion = provider_config.get("is_emotion", False)
36
38
 
@@ -39,7 +41,8 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
39
41
 
40
42
  # 将模型加载放到线程池中执行
41
43
  self.model = await asyncio.get_event_loop().run_in_executor(
42
- None, lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16)
44
+ None,
45
+ lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
43
46
  )
44
47
 
45
48
  logger.info("SenseVoice 模型加载完成。")
@@ -55,8 +58,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
55
58
 
56
59
  if silk_header in file_header:
57
60
  return True
58
- else:
59
- return False
61
+ return False
60
62
 
61
63
  async def get_text(self, audio_url: str) -> str:
62
64
  try:
@@ -1,8 +1,10 @@
1
1
  import aiohttp
2
+
2
3
  from astrbot import logger
4
+
5
+ from ..entities import ProviderType, RerankResult
3
6
  from ..provider import RerankProvider
4
7
  from ..register import register_provider_adapter
5
- from ..entities import ProviderType, RerankResult
6
8
 
7
9
 
8
10
  @register_provider_adapter(
@@ -30,7 +32,10 @@ class VLLMRerankProvider(RerankProvider):
30
32
  )
31
33
 
32
34
  async def rerank(
33
- self, query: str, documents: list[str], top_n: int | None = None
35
+ self,
36
+ query: str,
37
+ documents: list[str],
38
+ top_n: int | None = None,
34
39
  ) -> list[RerankResult]:
35
40
  payload = {
36
41
  "query": query,
@@ -40,14 +45,15 @@ class VLLMRerankProvider(RerankProvider):
40
45
  if top_n is not None:
41
46
  payload["top_n"] = top_n
42
47
  async with self.client.post(
43
- f"{self.base_url}/v1/rerank", json=payload
48
+ f"{self.base_url}/v1/rerank",
49
+ json=payload,
44
50
  ) as response:
45
51
  response_data = await response.json()
46
52
  results = response_data.get("results", [])
47
53
 
48
54
  if not results:
49
55
  logger.warning(
50
- f"Rerank API 返回了空的列表数据。原始响应: {response_data}"
56
+ f"Rerank API 返回了空的列表数据。原始响应: {response_data}",
51
57
  )
52
58
 
53
59
  return [
@@ -1,18 +1,23 @@
1
- import uuid
1
+ import asyncio
2
2
  import base64
3
3
  import json
4
4
  import os
5
5
  import traceback
6
- import asyncio
6
+ import uuid
7
+
7
8
  import aiohttp
8
- from ..provider import TTSProvider
9
+
10
+ from astrbot import logger
11
+
9
12
  from ..entities import ProviderType
13
+ from ..provider import TTSProvider
10
14
  from ..register import register_provider_adapter
11
- from astrbot import logger
12
15
 
13
16
 
14
17
  @register_provider_adapter(
15
- "volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH
18
+ "volcengine_tts",
19
+ "火山引擎 TTS",
20
+ provider_type=ProviderType.TEXT_TO_SPEECH,
16
21
  )
17
22
  class ProviderVolcengineTTS(TTSProvider):
18
23
  def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -23,7 +28,8 @@ class ProviderVolcengineTTS(TTSProvider):
23
28
  self.voice_type = provider_config.get("volcengine_voice_type", "")
24
29
  self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0)
25
30
  self.api_base = provider_config.get(
26
- "api_base", "https://openspeech.bytedance.com/api/v1/tts"
31
+ "api_base",
32
+ "https://openspeech.bytedance.com/api/v1/tts",
27
33
  )
28
34
  self.timeout = provider_config.get("timeout", 20)
29
35
 
@@ -66,43 +72,44 @@ class ProviderVolcengineTTS(TTSProvider):
66
72
  logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...")
67
73
 
68
74
  try:
69
- async with aiohttp.ClientSession() as session:
70
- async with session.post(
75
+ async with (
76
+ aiohttp.ClientSession() as session,
77
+ session.post(
71
78
  self.api_base,
72
79
  data=json.dumps(payload),
73
80
  headers=headers,
74
81
  timeout=self.timeout,
75
- ) as response:
76
- logger.debug(f"响应状态码: {response.status}")
77
-
78
- response_text = await response.text()
79
- logger.debug(f"响应内容: {response_text[:200]}...")
82
+ ) as response,
83
+ ):
84
+ logger.debug(f"响应状态码: {response.status}")
80
85
 
81
- if response.status == 200:
82
- resp_data = json.loads(response_text)
86
+ response_text = await response.text()
87
+ logger.debug(f"响应内容: {response_text[:200]}...")
83
88
 
84
- if "data" in resp_data:
85
- audio_data = base64.b64decode(resp_data["data"])
89
+ if response.status == 200:
90
+ resp_data = json.loads(response_text)
86
91
 
87
- os.makedirs("data/temp", exist_ok=True)
92
+ if "data" in resp_data:
93
+ audio_data = base64.b64decode(resp_data["data"])
88
94
 
89
- file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
95
+ os.makedirs("data/temp", exist_ok=True)
90
96
 
91
- loop = asyncio.get_running_loop()
92
- await loop.run_in_executor(
93
- None, lambda: open(file_path, "wb").write(audio_data)
94
- )
97
+ file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
95
98
 
96
- return file_path
97
- else:
98
- error_msg = resp_data.get("message", "未知错误")
99
- raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
100
- else:
101
- raise Exception(
102
- f"火山引擎 TTS API 请求失败: {response.status}, {response_text}"
99
+ loop = asyncio.get_running_loop()
100
+ await loop.run_in_executor(
101
+ None,
102
+ lambda: open(file_path, "wb").write(audio_data),
103
103
  )
104
104
 
105
+ return file_path
106
+ error_msg = resp_data.get("message", "未知错误")
107
+ raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}")
108
+ raise Exception(
109
+ f"火山引擎 TTS API 请求失败: {response.status}, {response_text}",
110
+ )
111
+
105
112
  except Exception as e:
106
113
  error_details = traceback.format_exc()
107
114
  logger.debug(f"火山引擎 TTS 异常详情: {error_details}")
108
- raise Exception(f"火山引擎 TTS 异常: {str(e)}")
115
+ raise Exception(f"火山引擎 TTS 异常: {e!s}")
@@ -1,13 +1,16 @@
1
- import uuid
2
1
  import os
3
- from openai import AsyncOpenAI, NOT_GIVEN
4
- from ..provider import STTProvider
5
- from ..entities import ProviderType
6
- from astrbot.core.utils.io import download_file
7
- from ..register import register_provider_adapter
2
+ import uuid
3
+
4
+ from openai import NOT_GIVEN, AsyncOpenAI
5
+
8
6
  from astrbot.core import logger
9
- from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
10
7
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
8
+ from astrbot.core.utils.io import download_file
9
+ from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
10
+
11
+ from ..entities import ProviderType
12
+ from ..provider import STTProvider
13
+ from ..register import register_provider_adapter
11
14
 
12
15
 
13
16
  @register_provider_adapter(
@@ -26,11 +29,11 @@ class ProviderOpenAIWhisperAPI(STTProvider):
26
29
 
27
30
  self.client = AsyncOpenAI(
28
31
  api_key=self.chosen_api_key,
29
- base_url=provider_config.get("api_base", None),
32
+ base_url=provider_config.get("api_base"),
30
33
  timeout=provider_config.get("timeout", NOT_GIVEN),
31
34
  )
32
35
 
33
- self.set_model(provider_config.get("model", None))
36
+ self.set_model(provider_config.get("model"))
34
37
 
35
38
  async def _is_silk_file(self, file_path):
36
39
  silk_header = b"SILK"
@@ -39,11 +42,10 @@ class ProviderOpenAIWhisperAPI(STTProvider):
39
42
 
40
43
  if silk_header in file_header:
41
44
  return True
42
- else:
43
- return False
45
+ return False
44
46
 
45
47
  async def get_text(self, audio_url: str) -> str:
46
- """only supports mp3, mp4, mpeg, m4a, wav, webm"""
48
+ """Only supports mp3, mp4, mpeg, m4a, wav, webm"""
47
49
  is_tencent = False
48
50
 
49
51
  if audio_url.startswith("http"):
@@ -1,14 +1,17 @@
1
- import uuid
2
- import os
3
1
  import asyncio
2
+ import os
3
+ import uuid
4
+
4
5
  import whisper
5
- from ..provider import STTProvider
6
- from ..entities import ProviderType
7
- from astrbot.core.utils.io import download_file
8
- from ..register import register_provider_adapter
6
+
9
7
  from astrbot.core import logger
10
- from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
11
8
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
9
+ from astrbot.core.utils.io import download_file
10
+ from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
11
+
12
+ from ..entities import ProviderType
13
+ from ..provider import STTProvider
14
+ from ..register import register_provider_adapter
12
15
 
13
16
 
14
17
  @register_provider_adapter(
@@ -23,14 +26,16 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
23
26
  provider_settings: dict,
24
27
  ) -> None:
25
28
  super().__init__(provider_config, provider_settings)
26
- self.set_model(provider_config.get("model", None))
29
+ self.set_model(provider_config.get("model"))
27
30
  self.model = None
28
31
 
29
32
  async def initialize(self):
30
33
  loop = asyncio.get_event_loop()
31
34
  logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
32
35
  self.model = await loop.run_in_executor(
33
- None, whisper.load_model, self.model_name
36
+ None,
37
+ whisper.load_model,
38
+ self.model_name,
34
39
  )
35
40
  logger.info("Whisper 模型加载完成。")
36
41
 
@@ -41,8 +46,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
41
46
 
42
47
  if silk_header in file_header:
43
48
  return True
44
- else:
45
- return False
49
+ return False
46
50
 
47
51
  async def get_text(self, audio_url: str) -> str:
48
52
  loop = asyncio.get_event_loop()
@@ -0,0 +1,116 @@
1
+ from xinference_client.client.restful.async_restful_client import (
2
+ AsyncClient as Client,
3
+ )
4
+
5
+ from astrbot import logger
6
+
7
+ from ..entities import ProviderType, RerankResult
8
+ from ..provider import RerankProvider
9
+ from ..register import register_provider_adapter
10
+
11
+
12
+ @register_provider_adapter(
13
+ "xinference_rerank",
14
+ "Xinference Rerank 适配器",
15
+ provider_type=ProviderType.RERANK,
16
+ )
17
+ class XinferenceRerankProvider(RerankProvider):
18
+ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
19
+ super().__init__(provider_config, provider_settings)
20
+ self.provider_config = provider_config
21
+ self.provider_settings = provider_settings
22
+ self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
23
+ self.base_url = self.base_url.rstrip("/")
24
+ self.timeout = provider_config.get("timeout", 20)
25
+ self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
26
+ self.api_key = provider_config.get("rerank_api_key")
27
+ self.launch_model_if_not_running = provider_config.get(
28
+ "launch_model_if_not_running",
29
+ False,
30
+ )
31
+ self.client = None
32
+ self.model = None
33
+ self.model_uid = None
34
+
35
+ async def initialize(self):
36
+ if self.api_key:
37
+ logger.info("Xinference Rerank: Using API key for authentication.")
38
+ self.client = Client(self.base_url, api_key=self.api_key)
39
+ else:
40
+ logger.info("Xinference Rerank: No API key provided.")
41
+ self.client = Client(self.base_url)
42
+
43
+ try:
44
+ running_models = await self.client.list_models()
45
+ for uid, model_spec in running_models.items():
46
+ if model_spec.get("model_name") == self.model_name:
47
+ logger.info(
48
+ f"Model '{self.model_name}' is already running with UID: {uid}",
49
+ )
50
+ self.model_uid = uid
51
+ break
52
+
53
+ if self.model_uid is None:
54
+ if self.launch_model_if_not_running:
55
+ logger.info(f"Launching {self.model_name} model...")
56
+ self.model_uid = await self.client.launch_model(
57
+ model_name=self.model_name,
58
+ model_type="rerank",
59
+ )
60
+ logger.info("Model launched.")
61
+ else:
62
+ logger.warning(
63
+ f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available.",
64
+ )
65
+ return
66
+
67
+ if self.model_uid:
68
+ self.model = await self.client.get_model(self.model_uid)
69
+
70
+ except Exception as e:
71
+ logger.error(f"Failed to initialize Xinference model: {e}")
72
+ logger.debug(
73
+ f"Xinference initialization failed with exception: {e}",
74
+ exc_info=True,
75
+ )
76
+ self.model = None
77
+
78
+ async def rerank(
79
+ self,
80
+ query: str,
81
+ documents: list[str],
82
+ top_n: int | None = None,
83
+ ) -> list[RerankResult]:
84
+ if not self.model:
85
+ logger.error("Xinference rerank model is not initialized.")
86
+ return []
87
+ try:
88
+ response = await self.model.rerank(documents, query, top_n)
89
+ results = response.get("results", [])
90
+ logger.debug(f"Rerank API response: {response}")
91
+
92
+ if not results:
93
+ logger.warning(
94
+ f"Rerank API returned an empty list. Original response: {response}",
95
+ )
96
+
97
+ return [
98
+ RerankResult(
99
+ index=result["index"],
100
+ relevance_score=result["relevance_score"],
101
+ )
102
+ for result in results
103
+ ]
104
+ except Exception as e:
105
+ logger.error(f"Xinference rerank failed: {e}")
106
+ logger.debug(f"Xinference rerank failed with exception: {e}", exc_info=True)
107
+ return []
108
+
109
+ async def terminate(self) -> None:
110
+ """关闭客户端会话"""
111
+ if self.client:
112
+ logger.info("Closing Xinference rerank client...")
113
+ try:
114
+ await self.client.close()
115
+ except Exception as e:
116
+ logger.error(f"Failed to close Xinference client: {e}", exc_info=True)
@@ -0,0 +1,197 @@
1
+ import os
2
+ import uuid
3
+
4
+ import aiohttp
5
+ from xinference_client.client.restful.async_restful_client import (
6
+ AsyncClient as Client,
7
+ )
8
+
9
+ from astrbot.core import logger
10
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
11
+ from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
12
+
13
+ from ..entities import ProviderType
14
+ from ..provider import STTProvider
15
+ from ..register import register_provider_adapter
16
+
17
+
18
@register_provider_adapter(
    "xinference_stt",
    "Xinference STT",
    provider_type=ProviderType.SPEECH_TO_TEXT,
)
class ProviderXinferenceSTT(STTProvider):
    """Speech-to-text provider backed by an Xinference server.

    Audio is uploaded to the server's OpenAI-compatible
    ``/v1/audio/transcriptions`` endpoint. SILK/AMR input (and any audio from
    the Tencent multimedia CDN) is first converted to WAV locally via
    ``tencent_silk_to_wav``.
    """

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings
        # Fall back to the default when the key is missing OR set to an empty
        # value; a None here would otherwise crash on .rstrip().
        self.base_url = (
            provider_config.get("api_base") or "http://127.0.0.1:9997"
        ).rstrip("/")
        self.timeout = provider_config.get("timeout", 180)
        self.model_name = provider_config.get("model", "whisper-large-v3")
        self.api_key = provider_config.get("api_key")
        self.launch_model_if_not_running = provider_config.get(
            "launch_model_if_not_running",
            False,
        )
        # Created in initialize(); dropped again in terminate().
        self.client = None
        # UID of the Xinference model to use; get_text refuses to run while unset.
        self.model_uid = None

    async def initialize(self):
        """Create the REST client and find (or optionally launch) the model.

        Looks for a running model whose ``model_name`` matches the configured
        one. If none is found and ``launch_model_if_not_running`` is set, the
        model is launched as an ``audio`` model; otherwise a warning is logged.
        Exceptions are logged, not raised.
        """
        if self.api_key:
            logger.info("Xinference STT: Using API key for authentication.")
            self.client = Client(self.base_url, api_key=self.api_key)
        else:
            logger.info("Xinference STT: No API key provided.")
            self.client = Client(self.base_url)

        try:
            running_models = await self.client.list_models()
            for uid, model_spec in running_models.items():
                if model_spec.get("model_name") == self.model_name:
                    logger.info(
                        f"Model '{self.model_name}' is already running with UID: {uid}",
                    )
                    self.model_uid = uid
                    break

            if self.model_uid is None:
                if self.launch_model_if_not_running:
                    logger.info(f"Launching {self.model_name} model...")
                    self.model_uid = await self.client.launch_model(
                        model_name=self.model_name,
                        model_type="audio",
                    )
                    logger.info("Model launched.")
                else:
                    logger.warning(
                        f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available.",
                    )
                    return

        except Exception as e:
            logger.error(f"Failed to initialize Xinference model: {e}")
            logger.debug(
                f"Xinference initialization failed with exception: {e}",
                exc_info=True,
            )

    async def get_text(self, audio_url: str) -> str:
        """Transcribe the audio at ``audio_url`` (an ``http`` URL or a local path).

        Returns:
            The transcribed text, or ``""`` if the provider is not initialized
            or any step fails (the cause is logged). Temporary conversion files
            are always removed.
        """
        if not self.model_uid or self.client is None or self.client.session is None:
            logger.error("Xinference STT model is not initialized.")
            return ""

        temp_files: list[str] = []
        try:
            audio_bytes = await self._load_audio_bytes(audio_url)
            if audio_bytes is None:
                # The loader already logged why the audio is unavailable.
                return ""
            if not audio_bytes:
                logger.error("Audio bytes are empty.")
                return ""

            if self._needs_conversion(audio_url, audio_bytes):
                audio_bytes = await self._convert_to_wav(audio_bytes, temp_files)

            return await self._transcribe(audio_bytes)

        except Exception as e:
            logger.error(f"Xinference STT failed: {e}")
            logger.debug(f"Xinference STT failed with exception: {e}", exc_info=True)
            return ""
        finally:
            self._cleanup_temp_files(temp_files)

    def _request_timeout(self) -> aiohttp.ClientTimeout:
        """Wrap the configured ``timeout`` as an explicit aiohttp total timeout."""
        return aiohttp.ClientTimeout(total=self.timeout)

    async def _load_audio_bytes(self, audio_url: str) -> bytes | None:
        """Fetch raw audio from an ``http`` URL or read it from a local file.

        Returns ``None`` (after logging) when the download fails or the path
        does not exist; otherwise the bytes read, which may be empty.
        """
        if audio_url.startswith("http"):
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    audio_url,
                    timeout=self._request_timeout(),
                ) as resp:
                    if resp.status != 200:
                        logger.error(
                            f"Failed to download audio from {audio_url}, status: {resp.status}",
                        )
                        return None
                    return await resp.read()

        if os.path.exists(audio_url):
            with open(audio_url, "rb") as f:
                return f.read()

        logger.error(f"File not found: {audio_url}")
        return None

    @staticmethod
    def _needs_conversion(audio_url: str, audio_bytes: bytes) -> bool:
        """Decide whether the audio must go through the SILK/AMR -> WAV converter.

        True for ``.amr``/``.silk`` paths, for URLs on the Tencent multimedia
        CDN, or when ``SILK`` appears in the first 8 bytes of the payload.
        """
        is_tencent = (
            audio_url.startswith("http") and "multimedia.nt.qq.com.cn" in audio_url
        )
        return (
            audio_url.endswith((".amr", ".silk"))
            or is_tencent
            or b"SILK" in audio_bytes[:8]
        )

    async def _convert_to_wav(self, audio_bytes: bytes, temp_files: list[str]) -> bytes:
        """Convert ``audio_bytes`` to WAV through temporary files and return the result.

        Both temp paths are appended to ``temp_files`` before anything is written,
        so the caller's cleanup removes them even if conversion fails.
        """
        logger.info("Audio requires conversion, using temporary files...")
        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        os.makedirs(temp_dir, exist_ok=True)

        input_path = os.path.join(temp_dir, str(uuid.uuid4()))
        output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
        temp_files.extend([input_path, output_path])

        with open(input_path, "wb") as f:
            f.write(audio_bytes)

        logger.info("Converting silk/amr file to wav ...")
        await tencent_silk_to_wav(input_path, output_path)

        with open(output_path, "rb") as f:
            return f.read()

    async def _transcribe(self, audio_bytes: bytes) -> str:
        """POST the audio to the OpenAI-compatible transcription endpoint.

        Returns the ``text`` field of a 200 response, or ``""`` (after logging
        the status and body) for any other status.
        """
        # The official async client's implementation seems to be buggy, so we
        # send an OpenAI-compatible request directly with aiohttp. An issue has
        # been filed upstream; switch back once it is fixed.
        url = f"{self.base_url}/v1/audio/transcriptions"
        headers = {
            "accept": "application/json",
        }
        # NOTE(review): _headers is a private attribute of the Xinference
        # client (it carries the auth header when an API key is set); it may
        # change without notice in future client versions.
        if self.client and self.client._headers:
            headers.update(self.client._headers)

        data = aiohttp.FormData()
        data.add_field("model", self.model_uid)
        data.add_field(
            "file",
            audio_bytes,
            filename="audio.wav",
            content_type="audio/wav",
        )

        async with self.client.session.post(
            url,
            data=data,
            headers=headers,
            timeout=self._request_timeout(),
        ) as resp:
            if resp.status == 200:
                result = await resp.json()
                text = result.get("text", "")
                logger.debug(f"Xinference STT result: {text}")
                return text
            error_text = await resp.text()
            logger.error(
                f"Xinference STT transcription failed with status {resp.status}: {error_text}",
            )
            return ""

    @staticmethod
    def _cleanup_temp_files(paths: list[str]) -> None:
        """Best-effort removal of temporary files; failures are logged, not raised."""
        for temp_file in paths:
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
                    logger.debug(f"Removed temporary file: {temp_file}")
            except Exception as e:
                logger.error(f"Failed to remove temporary file {temp_file}: {e}")

    async def terminate(self) -> None:
        """Close the client session.

        Safe to call more than once: the client reference is dropped afterwards,
        so later calls are no-ops and ``get_text`` reports the provider as not
        initialized.
        """
        if self.client is None:
            return
        logger.info("Closing Xinference STT client...")
        try:
            await self.client.close()
        except Exception as e:
            logger.error(f"Failed to close Xinference client: {e}", exc_info=True)
        finally:
            self.client = None
@@ -1,10 +1,11 @@
1
- from .star import StarMetadata, star_map, star_registry
2
- from .star_manager import PluginManager
3
- from .context import Context
4
- from astrbot.core.provider import Provider
5
- from astrbot.core.utils.command_parser import CommandParserMixin
6
1
  from astrbot.core import html_renderer
2
+ from astrbot.core.provider import Provider
7
3
  from astrbot.core.star.star_tools import StarTools
4
+ from astrbot.core.utils.command_parser import CommandParserMixin
5
+
6
+ from .context import Context
7
+ from .star import StarMetadata, star_map, star_registry
8
+ from .star_manager import PluginManager
8
9
 
9
10
 
10
11
  class Star(CommandParserMixin):
@@ -36,24 +37,28 @@ class Star(CommandParserMixin):
36
37
  )
37
38
 
38
39
async def html_render(
    self,
    tmpl: str,
    data: dict,
    return_url=True,
    options: dict | None = None,
) -> str:
    """Render a custom HTML template.

    Forwards every argument unchanged to
    ``html_renderer.render_custom_template`` and returns its result as-is.
    """
    render_template = html_renderer.render_custom_template
    rendered = await render_template(
        tmpl,
        data,
        return_url=return_url,
        options=options,
    )
    return rendered
45
53
 
46
54
async def initialize(self):
    """Hook called when the plugin is activated.

    The base implementation does nothing; override it to run setup logic.
    """
    return None
49
56
 
50
57
async def terminate(self):
    """Hook called when the plugin is disabled or reloaded.

    The base implementation does nothing; override it to release resources.
    """
    return None
53
59
 
54
60
def __del__(self):
    """[Deprecated] Hook called when the plugin is disabled or reloaded.

    Kept only for backward compatibility; the base implementation does nothing.
    """
    return None
57
62
 
58
63
 
59
# Public names exported by this module, kept in alphabetical order.
__all__ = [
    "Context",
    "PluginManager",
    "Provider",
    "Star",
    "StarMetadata",
    "StarTools",
]