AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,81 @@
1
+ import os
2
+ import uuid
3
+ import wave
4
+
5
+ from google import genai
6
+ from google.genai import types
7
+
8
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
9
+
10
+ from ..entities import ProviderType
11
+ from ..provider import TTSProvider
12
+ from ..register import register_provider_adapter
13
+
14
+
15
@register_provider_adapter(
    "gemini_tts",
    "Gemini TTS API",
    provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGeminiTTSAPI(TTSProvider):
    """Text-to-speech provider backed by the Gemini API (``google.genai``).

    Audio returned by the model is written to a mono, 16-bit, 24 kHz WAV file
    in the AstrBot temp directory, and the file path is returned.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config, provider_settings)
        api_key: str = provider_config.get("gemini_tts_api_key", "")
        api_base: str | None = provider_config.get("gemini_tts_api_base")
        timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
        # The config value is in seconds; HttpOptions.timeout takes milliseconds.
        http_options = types.HttpOptions(timeout=timeout * 1000)

        if api_base:
            api_base = api_base.removesuffix("/")
            http_options.base_url = api_base

        # Keep only the async facade of the client.
        self.client = genai.Client(api_key=api_key, http_options=http_options).aio
        self.model: str = provider_config.get(
            "gemini_tts_model",
            "gemini-2.5-flash-preview-tts",
        )
        # Optional instruction prepended to the text as "<prefix>: <text>".
        self.prefix: str | None = provider_config.get(
            "gemini_tts_prefix",
        )
        self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and return the path of the generated WAV file.

        Raises:
            Exception: if the API response contains no inline audio data.
        """
        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        # The temp directory may not exist yet (e.g. a fresh data dir);
        # wave.open would otherwise fail with FileNotFoundError.
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
        prompt = f"{self.prefix}: {text}" if self.prefix else text
        response = await self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(
                    voice_config=types.VoiceConfig(
                        prebuilt_voice_config=types.PrebuiltVoiceConfig(
                            voice_name=self.voice_name,
                        ),
                    ),
                ),
            ),
        )

        # Every level of the response is optional in the SDK types, so walk it
        # defensively (this also keeps the type checker quiet).
        candidates = response.candidates
        content = candidates[0].content if candidates else None
        parts = content.parts if content else None
        inline_data = parts[0].inline_data if parts else None
        if not inline_data or not inline_data.data:
            raise Exception("No audio content returned from Gemini TTS API.")

        # The bytes are written as raw frames; this assumes the API returns
        # 16-bit mono PCM at 24 kHz. NOTE(review): confirm if the model changes.
        with wave.open(path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(24000)
            wf.writeframes(inline_data.data)

        return path
@@ -0,0 +1,15 @@
1
+ from ..register import register_provider_adapter
2
+ from .openai_source import ProviderOpenAIOfficial
3
+
4
+
5
@register_provider_adapter(
    "groq_chat_completion",
    "Groq Chat Completion Provider Adapter",
)
class ProviderGroq(ProviderOpenAIOfficial):
    """Groq chat-completion adapter.

    Reuses the OpenAI provider implementation as-is. The only difference is the
    response field that reasoning content is read from.
    """

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)
        # Groq returns reasoning text under "reasoning". NOTE(review): the
        # parent's default key is defined in openai_source — confirm there.
        self.reasoning_key = "reasoning"
@@ -0,0 +1,151 @@
1
+ import asyncio
2
+ import os
3
+ import uuid
4
+
5
+ import aiohttp
6
+
7
+ from astrbot import logger
8
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
9
+
10
+ from ..entities import ProviderType
11
+ from ..provider import TTSProvider
12
+ from ..register import register_provider_adapter
13
+
14
+
15
@register_provider_adapter(
    provider_type_name="gsv_tts_selfhost",
    desc="GPT-SoVITS TTS(本地加载)",
    provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGSVTTS(TTSProvider):
    """TTS provider for a self-hosted GPT-SoVITS HTTP API.

    ``initialize`` opens a shared aiohttp session and optionally asks the server
    to switch GPT / SoVITS weights. ``get_audio`` calls ``{api_base}/tts`` and
    saves the returned bytes as a WAV file in the AstrBot temp directory.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config, provider_settings)

        self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
            "/",
        )
        self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
        self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")

        # Default /tts query parameters, with the "gsv_" config prefix stripped.
        # (The config key really is spelled "gsv_default_parms".)
        self.default_params: dict = {
            key.removeprefix("gsv_"): self._to_query_value(value)
            for key, value in provider_config.get("gsv_default_parms", {}).items()
        }
        self.timeout = provider_config.get("timeout", 60)
        # Created in initialize(); None until then.
        self._session: aiohttp.ClientSession | None = None

    @staticmethod
    def _to_query_value(value) -> str:
        """Convert a config value into a query-string value.

        Booleans become "true"/"false" (``str()`` would give "True"/"False").
        Every other value is stringified as-is. Lowercasing all values would
        corrupt case-sensitive ones such as reference-audio paths or prompt text.
        """
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    async def initialize(self):
        """Async initialization; called by ProviderManager.

        Opens the HTTP session, then applies the configured model weights.
        """
        self._session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        )
        try:
            await self._set_model_weights()
            logger.info("[GSV TTS] 初始化完成")
        except Exception as e:
            # _set_model_weights already logs and swallows ordinary exceptions,
            # so this branch is mostly defensive.
            logger.error(f"[GSV TTS] 初始化失败:{e}")
            raise

    def get_session(self) -> aiohttp.ClientSession:
        """Return the open HTTP session.

        Raises:
            RuntimeError: if initialize() has not run or the session is closed.
        """
        if not self._session or self._session.closed:
            raise RuntimeError(
                "[GSV TTS] Provider HTTP session is not ready or closed.",
            )
        return self._session

    async def _make_request(
        self,
        endpoint: str,
        params=None,
        retries: int = 3,
    ) -> bytes | None:
        """Send a GET request to ``endpoint`` and return the raw response body.

        Non-200 responses and other exceptions are retried, up to ``retries``
        attempts in total, with a 1 s pause between attempts. The last failure
        is re-raised. Values of ``retries`` below 1 are treated as 1, so the
        request is always attempted at least once.
        """
        attempts = max(1, retries)
        for attempt in range(attempts):
            logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
            try:
                async with self.get_session().get(endpoint, params=params) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(
                            f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}",
                        )
                    return await response.read()
            except Exception as e:
                if attempt < attempts - 1:
                    logger.warning(
                        f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中...",
                    )
                    await asyncio.sleep(1)
                else:
                    logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
                    raise

    async def _set_model_weights(self):
        """Ask the server to load the configured GPT / SoVITS weights.

        Unset paths are skipped; the server then keeps its built-in models.
        Errors are logged and not propagated (best effort).
        """
        try:
            if self.gpt_weights_path:
                await self._make_request(
                    f"{self.api_base}/set_gpt_weights",
                    {"weights_path": self.gpt_weights_path},
                )
                logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
            else:
                logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")

            if self.sovits_weights_path:
                await self._make_request(
                    f"{self.api_base}/set_sovits_weights",
                    {"weights_path": self.sovits_weights_path},
                )
                logger.info(
                    f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}",
                )
            else:
                logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
        except aiohttp.ClientError as e:
            logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
        except Exception as e:
            logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and return the path of the saved WAV file.

        Emotion switching based on the text is not implemented yet; see
        build_synthesis_params.

        Raises:
            ValueError: if ``text`` is empty or whitespace only.
        """
        if not text.strip():
            raise ValueError("[GSV TTS] TTS 文本不能为空")

        endpoint = f"{self.api_base}/tts"

        params = self.build_synthesis_params(text)

        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")

        logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")

        result = await self._make_request(endpoint, params)
        if isinstance(result, bytes):
            with open(path, "wb") as f:
                f.write(result)
            return path
        raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")

    def build_synthesis_params(self, text: str) -> dict:
        """Build the query parameters for a synthesis request.

        Currently this is only the default parameters plus the text. It is the
        extension point for semantic controls such as emotion or character.
        """
        params = self.default_params.copy()
        params["text"] = text
        # TODO: add emotion analysis here, e.g. params["emotion"] = detect_emotion(text)
        return params

    async def terminate(self):
        """Release resources; called by ProviderManager."""
        if self._session and not self._session.closed:
            await self._session.close()
            logger.info("[GSV TTS] Session 已关闭")
@@ -1,13 +1,20 @@
1
+ import os
2
+ import urllib.parse
1
3
  import uuid
4
+
2
5
  import aiohttp
3
- import urllib.parse
4
- from ..provider import TTSProvider
6
+
7
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
8
+
5
9
  from ..entities import ProviderType
10
+ from ..provider import TTSProvider
6
11
  from ..register import register_provider_adapter
7
12
 
8
13
 
9
14
  @register_provider_adapter(
10
- "gsvi_tts_api", "GSVI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
15
+ "gsvi_tts_api",
16
+ "GSVI TTS API",
17
+ provider_type=ProviderType.TEXT_TO_SPEECH,
11
18
  )
12
19
  class ProviderGSVITTS(TTSProvider):
13
20
  def __init__(
@@ -17,13 +24,13 @@ class ProviderGSVITTS(TTSProvider):
17
24
  ) -> None:
18
25
  super().__init__(provider_config, provider_settings)
19
26
  self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000")
20
- if self.api_base.endswith("/"):
21
- self.api_base = self.api_base[:-1]
27
+ self.api_base = self.api_base.removesuffix("/")
22
28
  self.character = provider_config.get("character")
23
29
  self.emotion = provider_config.get("emotion")
24
30
 
25
31
  async def get_audio(self, text: str) -> str:
26
- path = f"data/temp/gsvi_tts_{uuid.uuid4()}.wav"
32
+ temp_dir = os.path.join(get_astrbot_data_path(), "temp")
33
+ path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
27
34
  params = {"text": text}
28
35
 
29
36
  if self.character:
@@ -46,7 +53,7 @@ class ProviderGSVITTS(TTSProvider):
46
53
  else:
47
54
  error_text = await response.text()
48
55
  raise Exception(
49
- f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}"
56
+ f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}",
50
57
  )
51
58
 
52
59
  return path
@@ -0,0 +1,159 @@
1
+ import json
2
+ import os
3
+ import uuid
4
+ from collections.abc import AsyncIterator
5
+
6
+ import aiohttp
7
+
8
+ from astrbot.api import logger
9
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
10
+
11
+ from ..entities import ProviderType
12
+ from ..provider import TTSProvider
13
+ from ..register import register_provider_adapter
14
+
15
+
16
@register_provider_adapter(
    "minimax_tts_api",
    "MiniMax TTS API",
    provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderMiniMaxTTSAPI(TTSProvider):
    """TTS provider for the MiniMax ``t2a_v2`` API, using streaming responses.

    Audio arrives as hex-encoded chunks inside ``data: {...}`` events separated
    by blank lines. The chunks are decoded, concatenated and written to an MP3
    file in the AstrBot temp directory.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config, provider_settings)
        self.chosen_api_key: str = provider_config.get("api_key", "")
        self.api_base: str = provider_config.get(
            "api_base",
            "https://api.minimax.chat/v1/t2a_v2",
        )
        self.group_id: str = provider_config.get("minimax-group-id", "")
        self.set_model(provider_config.get("model", ""))
        self.lang_boost: str = provider_config.get("minimax-langboost", "auto")
        self.is_timber_weight: bool = provider_config.get(
            "minimax-is-timber-weight",
            False,
        )
        self.timber_weight: list[dict[str, str | int]] = self._parse_timber_weight(
            provider_config.get(
                "minimax-timber-weight",
                '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]',
            ),
        )

        self.voice_setting: dict = {
            "speed": provider_config.get("minimax-voice-speed", 1.0),
            "vol": provider_config.get("minimax-voice-vol", 1.0),
            "pitch": provider_config.get("minimax-voice-pitch", 0),
            # With timber weights enabled, voices come from "timber_weights" in
            # the request body, so voice_id is left empty.
            "voice_id": ""
            if self.is_timber_weight
            else provider_config.get("minimax-voice-id", ""),
            "emotion": provider_config.get("minimax-voice-emotion", "neutral"),
            "latex_read": provider_config.get("minimax-voice-latex", False),
            "english_normalization": provider_config.get(
                "minimax-voice-english-normalization",
                False,
            ),
        }

        self.audio_setting: dict = {
            "sample_rate": 32000,
            "bitrate": 128000,
            "format": "mp3",
        }

        self.concat_base_url: str = f"{self.api_base}?GroupId={self.group_id}"
        self.headers = {
            "Authorization": f"Bearer {self.chosen_api_key}",
            "accept": "application/json, text/plain, */*",
            "content-type": "application/json",
        }

    @staticmethod
    def _parse_timber_weight(raw) -> list[dict[str, str | int]]:
        """Decode the timber-weight setting.

        Accepts a JSON string or an already-decoded list.
        """
        if isinstance(raw, (str, bytes, bytearray)):
            return json.loads(raw)
        return raw

    def _build_tts_stream_body(self, text: str):
        """Build the JSON request body for a streaming synthesis call."""
        dict_body: dict[str, object] = {
            "model": self.model_name,
            "text": text,
            "stream": True,
            "language_boost": self.lang_boost,
            "voice_setting": self.voice_setting,
            "audio_setting": self.audio_setting,
        }
        if self.is_timber_weight:
            dict_body["timber_weights"] = self.timber_weight

        return json.dumps(dict_body)

    @staticmethod
    def _parse_sse_message(message: bytes) -> str | None:
        """Extract the hex audio chunk from one SSE event, or return None.

        Events without the ``data: `` prefix, undecodable payloads, events that
        carry ``extra_info``, and events without ``data.audio`` are ignored.

        Raises:
            Exception: if the event reports a non-zero ``base_resp.status_code``.
        """
        if not message.startswith(b"data: "):
            return None
        try:
            data = json.loads(message[6:])
        except ValueError:
            # Covers JSONDecodeError and UnicodeDecodeError (invalid UTF-8).
            logger.warning(
                "Failed to parse JSON data from SSE message",
            )
            return None
        if not isinstance(data, dict):
            return None
        # Assumes MiniMax reports errors as base_resp.status_code != 0.
        # NOTE(review): confirm against the t2a_v2 API reference.
        base_resp = data.get("base_resp")
        if isinstance(base_resp, dict) and base_resp.get("status_code", 0) != 0:
            raise Exception(
                f"MiniMax TTS API error {base_resp.get('status_code')}: "
                f"{base_resp.get('status_msg')}",
            )
        if "extra_info" in data:
            # The original implementation also skipped these events; they are
            # presumably the closing summary event — TODO confirm.
            return None
        payload = data.get("data")
        if not isinstance(payload, dict):
            return None
        return payload.get("audio")

    async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
        """Send the streaming request and yield hex-encoded audio chunks.

        Raises:
            Exception: wrapping any aiohttp.ClientError (including HTTP errors).
        """
        try:
            async with (
                aiohttp.ClientSession() as session,
                session.post(
                    self.concat_base_url,
                    headers=self.headers,
                    data=self._build_tts_stream_body(text),
                    timeout=aiohttp.ClientTimeout(total=60),
                ) as response,
            ):
                response.raise_for_status()

                buffer = b""
                while True:
                    chunk = await response.content.read(8192)
                    if not chunk:
                        break

                    buffer += chunk

                    # Emit every complete event; keep the partial tail buffered.
                    while b"\n\n" in buffer:
                        message, buffer = buffer.split(b"\n\n", 1)
                        audio = self._parse_sse_message(message)
                        if audio is not None:
                            yield audio

                # A final event without a trailing blank line would otherwise
                # be dropped.
                tail = buffer.strip()
                if tail:
                    audio = self._parse_sse_message(tail)
                    if audio is not None:
                        yield audio

        except aiohttp.ClientError as e:
            raise Exception(f"MiniMax TTS API请求失败: {e!s}") from e

    async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes:
        """Decode the hex-encoded chunks from ``audio_stream`` into audio bytes."""
        chunks = []
        async for chunk in audio_stream:
            if chunk.strip():
                chunks.append(bytes.fromhex(chunk.strip()))
        return b"".join(chunks)

    async def get_audio(self, text: str) -> str:
        """Synthesize ``text`` and return the path of the saved MP3 file.

        Raises:
            Exception: on request failure, an API-reported error, or when the
                stream contains no audio data.
        """
        temp_dir = os.path.join(get_astrbot_data_path(), "temp")
        os.makedirs(temp_dir, exist_ok=True)
        path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")

        # The async generator is consumed directly by _audio_play.
        audio = await self._audio_play(self._call_tts_stream(text))
        if not audio:
            # Fail loudly instead of writing an empty MP3 file.
            raise Exception("MiniMax TTS API returned no audio data.")

        with open(path, "wb") as file:
            file.write(audio)

        return path
@@ -0,0 +1,40 @@
1
+ from openai import AsyncOpenAI
2
+
3
+ from ..entities import ProviderType
4
+ from ..provider import EmbeddingProvider
5
+ from ..register import register_provider_adapter
6
+
7
+
8
@register_provider_adapter(
    "openai_embedding",
    "OpenAI API Embedding 提供商适配器",
    provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider for OpenAI (or OpenAI-compatible) embedding endpoints."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings
        self.client = AsyncOpenAI(
            api_key=provider_config.get("embedding_api_key"),
            # Use `or` rather than a .get() default so an empty string in the
            # config also falls back to the official endpoint.
            base_url=provider_config.get("embedding_api_base")
            or "https://api.openai.com/v1",
            timeout=int(provider_config.get("timeout", 20)),
        )
        self.model = provider_config.get("embedding_model", "text-embedding-3-small")

    async def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for a single text."""
        embedding = await self.client.embeddings.create(input=text, model=self.model)
        return embedding.data[0].embedding

    async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Return embedding vectors for ``texts``, in the response order.

        An empty input returns an empty list without calling the API.
        """
        if not texts:
            return []
        embeddings = await self.client.embeddings.create(input=texts, model=self.model)
        return [item.embedding for item in embeddings.data]

    def get_dim(self) -> int:
        """Return the configured vector dimension (default 1024).

        NOTE(review): no ``dimensions`` argument is sent with the request, so
        this must match the model's native output size to stay consistent with
        the returned vectors.
        """
        # Config values may arrive as strings; callers expect an int.
        return int(self.provider_config.get("embedding_dimensions", 1024))