AstrBot 4.8.0__py3-none-any.whl → 4.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. astrbot/cli/__init__.py +1 -1
  2. astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -1
  3. astrbot/core/agent/tool.py +7 -2
  4. astrbot/core/astr_agent_tool_exec.py +5 -1
  5. astrbot/core/config/astrbot_config.py +4 -0
  6. astrbot/core/config/default.py +72 -1
  7. astrbot/core/config/i18n_utils.py +1 -0
  8. astrbot/core/core_lifecycle.py +1 -1
  9. astrbot/core/db/__init__.py +2 -3
  10. astrbot/core/db/migration/migra_3_to_4.py +2 -0
  11. astrbot/core/db/migration/sqlite_v3.py +6 -4
  12. astrbot/core/db/po.py +16 -15
  13. astrbot/core/db/sqlite.py +4 -3
  14. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +2 -0
  15. astrbot/core/event_bus.py +6 -1
  16. astrbot/core/knowledge_base/retrieval/manager.py +5 -1
  17. astrbot/core/log.py +2 -1
  18. astrbot/core/message/components.py +9 -3
  19. astrbot/core/persona_mgr.py +2 -2
  20. astrbot/core/pipeline/content_safety_check/stage.py +1 -1
  21. astrbot/core/pipeline/context_utils.py +2 -1
  22. astrbot/core/pipeline/process_stage/method/star_request.py +1 -2
  23. astrbot/core/pipeline/process_stage/stage.py +1 -1
  24. astrbot/core/pipeline/respond/stage.py +8 -2
  25. astrbot/core/pipeline/result_decorate/stage.py +89 -22
  26. astrbot/core/pipeline/scheduler.py +5 -1
  27. astrbot/core/pipeline/waking_check/stage.py +10 -0
  28. astrbot/core/platform/astr_message_event.py +5 -3
  29. astrbot/core/platform/astrbot_message.py +2 -2
  30. astrbot/core/platform/manager.py +4 -0
  31. astrbot/core/platform/platform.py +11 -3
  32. astrbot/core/platform/platform_metadata.py +1 -1
  33. astrbot/core/platform/register.py +1 -0
  34. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +8 -6
  35. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +9 -5
  36. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +24 -16
  37. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +5 -2
  38. astrbot/core/platform/sources/discord/client.py +16 -4
  39. astrbot/core/platform/sources/discord/components.py +2 -2
  40. astrbot/core/platform/sources/discord/discord_platform_adapter.py +52 -24
  41. astrbot/core/platform/sources/discord/discord_platform_event.py +29 -8
  42. astrbot/core/platform/sources/lark/lark_adapter.py +183 -20
  43. astrbot/core/platform/sources/lark/lark_event.py +39 -4
  44. astrbot/core/platform/sources/lark/server.py +206 -0
  45. astrbot/core/platform/sources/misskey/misskey_adapter.py +2 -3
  46. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +62 -18
  47. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +13 -7
  48. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +5 -3
  49. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  50. astrbot/core/platform/sources/slack/client.py +9 -2
  51. astrbot/core/platform/sources/slack/slack_adapter.py +15 -9
  52. astrbot/core/platform/sources/slack/slack_event.py +8 -7
  53. astrbot/core/platform/sources/telegram/tg_adapter.py +1 -1
  54. astrbot/core/platform/sources/telegram/tg_event.py +23 -27
  55. astrbot/core/platform/sources/webchat/webchat_adapter.py +2 -2
  56. astrbot/core/platform/sources/webchat/webchat_event.py +2 -2
  57. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +26 -9
  58. astrbot/core/platform/sources/wecom/wecom_adapter.py +25 -28
  59. astrbot/core/platform/sources/wecom/wecom_event.py +2 -2
  60. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +3 -3
  61. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +30 -25
  62. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +10 -7
  63. astrbot/core/provider/func_tool_manager.py +3 -3
  64. astrbot/core/provider/manager.py +130 -74
  65. astrbot/core/provider/provider.py +12 -1
  66. astrbot/core/provider/sources/azure_tts_source.py +31 -9
  67. astrbot/core/provider/sources/bailian_rerank_source.py +4 -0
  68. astrbot/core/provider/sources/dashscope_tts.py +3 -2
  69. astrbot/core/provider/sources/edge_tts_source.py +1 -1
  70. astrbot/core/provider/sources/fishaudio_tts_api_source.py +5 -4
  71. astrbot/core/provider/sources/gemini_embedding_source.py +15 -5
  72. astrbot/core/provider/sources/gemini_source.py +12 -10
  73. astrbot/core/provider/sources/minimax_tts_api_source.py +4 -2
  74. astrbot/core/provider/sources/openai_embedding_source.py +2 -2
  75. astrbot/core/provider/sources/openai_source.py +4 -0
  76. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +5 -2
  77. astrbot/core/provider/sources/vllm_rerank_source.py +1 -0
  78. astrbot/core/provider/sources/whisper_api_source.py +1 -1
  79. astrbot/core/provider/sources/whisper_selfhosted_source.py +6 -2
  80. astrbot/core/provider/sources/xinference_rerank_source.py +10 -2
  81. astrbot/core/star/context.py +2 -2
  82. astrbot/core/star/register/star_handler.py +22 -5
  83. astrbot/core/star/star_handler.py +85 -4
  84. astrbot/core/updator.py +3 -3
  85. astrbot/core/utils/io.py +1 -1
  86. astrbot/core/utils/session_waiter.py +17 -10
  87. astrbot/core/utils/shared_preferences.py +32 -0
  88. astrbot/core/utils/t2i/__init__.py +2 -2
  89. astrbot/core/utils/t2i/local_strategy.py +25 -31
  90. astrbot/core/utils/tencent_record_helper.py +1 -1
  91. astrbot/core/utils/version_comparator.py +6 -3
  92. astrbot/core/utils/webhook_utils.py +19 -0
  93. astrbot/dashboard/routes/chat.py +14 -9
  94. astrbot/dashboard/routes/config.py +10 -20
  95. astrbot/dashboard/routes/conversation.py +91 -1
  96. astrbot/dashboard/routes/knowledge_base.py +253 -78
  97. astrbot/dashboard/routes/log.py +13 -8
  98. astrbot/dashboard/routes/platform.py +1 -1
  99. astrbot/dashboard/routes/plugin.py +113 -52
  100. astrbot/dashboard/routes/route.py +2 -0
  101. astrbot/dashboard/server.py +6 -3
  102. {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/METADATA +9 -1
  103. {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/RECORD +106 -105
  104. {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/WHEEL +0 -0
  105. {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/entry_points.txt +0 -0
  106. {astrbot-4.8.0.dist-info → astrbot-4.9.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  import uuid
3
+ from typing import cast
3
4
 
4
5
  from wechatpy import WeChatClient
5
6
  from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
85
86
 
86
87
  async def send(self, message: MessageChain):
87
88
  message_obj = self.message_obj
88
- active_send_mode = message_obj.raw_message.get("active_send_mode", False)
89
+ active_send_mode = cast(dict, message_obj.raw_message).get(
90
+ "active_send_mode", False
91
+ )
89
92
  for comp in message.chain:
90
93
  if isinstance(comp, Plain):
91
94
  # Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
96
99
  else:
97
100
  reply = TextReply(
98
101
  content=chunk,
99
- message=self.message_obj.raw_message["message"],
102
+ message=cast(dict, self.message_obj.raw_message)["message"],
100
103
  )
101
104
  xml = reply.render()
102
- future = self.message_obj.raw_message["future"]
105
+ future = cast(dict, self.message_obj.raw_message)["future"]
103
106
  assert isinstance(future, asyncio.Future)
104
107
  future.set_result(xml)
105
108
  await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
125
128
  else:
126
129
  reply = ImageReply(
127
130
  media_id=response["media_id"],
128
- message=self.message_obj.raw_message["message"],
131
+ message=cast(dict, self.message_obj.raw_message)["message"],
129
132
  )
130
133
  xml = reply.render()
131
- future = self.message_obj.raw_message["future"]
134
+ future = cast(dict, self.message_obj.raw_message)["future"]
132
135
  assert isinstance(future, asyncio.Future)
133
136
  future.set_result(xml)
134
137
 
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
160
163
  else:
161
164
  reply = VoiceReply(
162
165
  media_id=response["media_id"],
163
- message=self.message_obj.raw_message["message"],
166
+ message=cast(dict, self.message_obj.raw_message)["message"],
164
167
  )
165
168
  xml = reply.render()
166
- future = self.message_obj.raw_message["future"]
169
+ future = cast(dict, self.message_obj.raw_message)["future"]
167
170
  assert isinstance(future, asyncio.Future)
168
171
  future.set_result(xml)
169
172
 
@@ -4,7 +4,7 @@ import asyncio
4
4
  import copy
5
5
  import json
6
6
  import os
7
- from collections.abc import Awaitable, Callable
7
+ from collections.abc import AsyncGenerator, Awaitable, Callable
8
8
  from typing import Any
9
9
 
10
10
  import aiohttp
@@ -118,7 +118,7 @@ class FunctionToolManager:
118
118
  name: str,
119
119
  func_args: list[dict],
120
120
  desc: str,
121
- handler: Callable[..., Awaitable[Any]],
121
+ handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
122
122
  ) -> FuncTool:
123
123
  params = {
124
124
  "type": "object", # hard-coded here
@@ -140,7 +140,7 @@ class FunctionToolManager:
140
140
  name: str,
141
141
  func_args: list,
142
142
  desc: str,
143
- handler: Callable[..., Awaitable[Any]],
143
+ handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
144
144
  ) -> None:
145
145
  """添加函数调用工具
146
146
 
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  import traceback
3
+ from typing import Protocol, runtime_checkable
3
4
 
4
5
  from astrbot.core import astrbot_config, logger, sp
5
6
  from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -10,6 +11,7 @@ from .entities import ProviderType
10
11
  from .provider import (
11
12
  EmbeddingProvider,
12
13
  Provider,
14
+ Providers,
13
15
  RerankProvider,
14
16
  STTProvider,
15
17
  TTSProvider,
@@ -17,6 +19,11 @@ from .provider import (
17
19
  from .register import llm_tools, provider_cls_map
18
20
 
19
21
 
22
+ @runtime_checkable
23
+ class HasInitialize(Protocol):
24
+ async def initialize(self) -> None: ...
25
+
26
+
20
27
  class ProviderManager:
21
28
  def __init__(
22
29
  self,
@@ -48,7 +55,7 @@ class ProviderManager:
48
55
  """加载的 Rerank Provider 的实例"""
49
56
  self.inst_map: dict[
50
57
  str,
51
- Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
58
+ Providers,
52
59
  ] = {}
53
60
  """Provider 实例映射. key: provider_id, value: Provider 实例"""
54
61
  self.llm_tools = llm_tools
@@ -123,15 +130,13 @@ class ProviderManager:
123
130
  self.curr_provider_inst = prov
124
131
  sp.put("curr_provider", provider_id, scope="global", scope_id="global")
125
132
 
126
- async def get_provider_by_id(self, provider_id: str) -> Provider | None:
133
+ async def get_provider_by_id(self, provider_id: str) -> Providers | None:
127
134
  """根据提供商 ID 获取提供商实例"""
128
135
  return self.inst_map.get(provider_id)
129
136
 
130
137
  def get_using_provider(
131
- self,
132
- provider_type: ProviderType,
133
- umo=None,
134
- ) -> Provider | STTProvider | TTSProvider | None:
138
+ self, provider_type: ProviderType, umo=None
139
+ ) -> Providers | None:
135
140
  """获取正在使用的提供商实例。
136
141
 
137
142
  Args:
@@ -191,7 +196,6 @@ class ProviderManager:
191
196
  logger.error(traceback.format_exc())
192
197
  logger.error(e)
193
198
 
194
- # 设置默认提供商
195
199
  selected_provider_id = sp.get(
196
200
  "curr_provider",
197
201
  self.provider_settings.get("default_provider_id"),
@@ -210,15 +214,37 @@ class ProviderManager:
210
214
  scope="global",
211
215
  scope_id="global",
212
216
  )
213
- self.curr_provider_inst = self.inst_map.get(selected_provider_id)
217
+
218
+ temp_provider = (
219
+ self.inst_map.get(selected_provider_id)
220
+ if isinstance(selected_provider_id, str)
221
+ else None
222
+ )
223
+ self.curr_provider_inst = (
224
+ temp_provider if isinstance(temp_provider, Provider) else None
225
+ )
214
226
  if not self.curr_provider_inst and self.provider_insts:
215
227
  self.curr_provider_inst = self.provider_insts[0]
216
228
 
217
- self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
229
+ temp_stt = (
230
+ self.inst_map.get(selected_stt_provider_id)
231
+ if isinstance(selected_stt_provider_id, str)
232
+ else None
233
+ )
234
+ self.curr_stt_provider_inst = (
235
+ temp_stt if isinstance(temp_stt, STTProvider) else None
236
+ )
218
237
  if not self.curr_stt_provider_inst and self.stt_provider_insts:
219
238
  self.curr_stt_provider_inst = self.stt_provider_insts[0]
220
239
 
221
- self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
240
+ temp_tts = (
241
+ self.inst_map.get(selected_tts_provider_id)
242
+ if isinstance(selected_tts_provider_id, str)
243
+ else None
244
+ )
245
+ self.curr_tts_provider_inst = (
246
+ temp_tts if isinstance(temp_tts, TTSProvider) else None
247
+ )
222
248
  if not self.curr_tts_provider_inst and self.tts_provider_insts:
223
249
  self.curr_tts_provider_inst = self.tts_provider_insts[0]
224
250
 
@@ -358,73 +384,103 @@ class ProviderManager:
358
384
 
359
385
  provider_metadata.id = provider_config["id"]
360
386
 
361
- if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
362
- # STT 任务
363
- inst = cls_type(provider_config, self.provider_settings)
364
-
365
- if getattr(inst, "initialize", None):
366
- await inst.initialize()
367
-
368
- self.stt_provider_insts.append(inst)
369
- if (
370
- self.provider_stt_settings.get("provider_id")
371
- == provider_config["id"]
372
- ):
373
- self.curr_stt_provider_inst = inst
374
- logger.info(
375
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
376
- )
377
- if not self.curr_stt_provider_inst:
378
- self.curr_stt_provider_inst = inst
379
-
380
- elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
381
- # TTS 任务
382
- inst = cls_type(provider_config, self.provider_settings)
383
-
384
- if getattr(inst, "initialize", None):
385
- await inst.initialize()
386
-
387
- self.tts_provider_insts.append(inst)
388
- if self.provider_settings.get("provider_id") == provider_config["id"]:
389
- self.curr_tts_provider_inst = inst
390
- logger.info(
391
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
387
+ match provider_metadata.provider_type:
388
+ case ProviderType.SPEECH_TO_TEXT:
389
+ # STT 任务
390
+ if not issubclass(cls_type, STTProvider):
391
+ raise TypeError(
392
+ f"Provider class {cls_type} is not a subclass of STTProvider"
393
+ )
394
+ inst = cls_type(provider_config, self.provider_settings)
395
+
396
+ if isinstance(inst, HasInitialize):
397
+ await inst.initialize()
398
+
399
+ self.stt_provider_insts.append(inst)
400
+ if (
401
+ self.provider_stt_settings.get("provider_id")
402
+ == provider_config["id"]
403
+ ):
404
+ self.curr_stt_provider_inst = inst
405
+ logger.info(
406
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
407
+ )
408
+ if not self.curr_stt_provider_inst:
409
+ self.curr_stt_provider_inst = inst
410
+
411
+ case ProviderType.TEXT_TO_SPEECH:
412
+ # TTS 任务
413
+ if not issubclass(cls_type, TTSProvider):
414
+ raise TypeError(
415
+ f"Provider class {cls_type} is not a subclass of TTSProvider"
416
+ )
417
+ inst = cls_type(provider_config, self.provider_settings)
418
+
419
+ if isinstance(inst, HasInitialize):
420
+ await inst.initialize()
421
+
422
+ self.tts_provider_insts.append(inst)
423
+ if (
424
+ self.provider_settings.get("provider_id")
425
+ == provider_config["id"]
426
+ ):
427
+ self.curr_tts_provider_inst = inst
428
+ logger.info(
429
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
430
+ )
431
+ if not self.curr_tts_provider_inst:
432
+ self.curr_tts_provider_inst = inst
433
+
434
+ case ProviderType.CHAT_COMPLETION:
435
+ # 文本生成任务
436
+ if not issubclass(cls_type, Provider):
437
+ raise TypeError(
438
+ f"Provider class {cls_type} is not a subclass of Provider"
439
+ )
440
+ inst = cls_type(
441
+ provider_config,
442
+ self.provider_settings,
392
443
  )
393
- if not self.curr_tts_provider_inst:
394
- self.curr_tts_provider_inst = inst
395
-
396
- elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
397
- # 文本生成任务
398
- inst = cls_type(
399
- provider_config,
400
- self.provider_settings,
401
- )
402
444
 
403
- if getattr(inst, "initialize", None):
404
- await inst.initialize()
405
-
406
- self.provider_insts.append(inst)
407
- if (
408
- self.provider_settings.get("default_provider_id")
409
- == provider_config["id"]
410
- ):
411
- self.curr_provider_inst = inst
412
- logger.info(
413
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
445
+ if isinstance(inst, HasInitialize):
446
+ await inst.initialize()
447
+
448
+ self.provider_insts.append(inst)
449
+ if (
450
+ self.provider_settings.get("default_provider_id")
451
+ == provider_config["id"]
452
+ ):
453
+ self.curr_provider_inst = inst
454
+ logger.info(
455
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
456
+ )
457
+ if not self.curr_provider_inst:
458
+ self.curr_provider_inst = inst
459
+
460
+ case ProviderType.EMBEDDING:
461
+ if not issubclass(cls_type, EmbeddingProvider):
462
+ raise TypeError(
463
+ f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
464
+ )
465
+ inst = cls_type(provider_config, self.provider_settings)
466
+ if isinstance(inst, HasInitialize):
467
+ await inst.initialize()
468
+ self.embedding_provider_insts.append(inst)
469
+ case ProviderType.RERANK:
470
+ if not issubclass(cls_type, RerankProvider):
471
+ raise TypeError(
472
+ f"Provider class {cls_type} is not a subclass of RerankProvider"
473
+ )
474
+ inst = cls_type(provider_config, self.provider_settings)
475
+ if isinstance(inst, HasInitialize):
476
+ await inst.initialize()
477
+ self.rerank_provider_insts.append(inst)
478
+ case _:
479
+ # 未知供应商抛出异常,确保inst初始化
480
+ # Should be unreachable
481
+ raise Exception(
482
+ f"未知的提供商类型:{provider_metadata.provider_type}"
414
483
  )
415
- if not self.curr_provider_inst:
416
- self.curr_provider_inst = inst
417
-
418
- elif provider_metadata.provider_type == ProviderType.EMBEDDING:
419
- inst = cls_type(provider_config, self.provider_settings)
420
- if getattr(inst, "initialize", None):
421
- await inst.initialize()
422
- self.embedding_provider_insts.append(inst)
423
- elif provider_metadata.provider_type == ProviderType.RERANK:
424
- inst = cls_type(provider_config, self.provider_settings)
425
- if getattr(inst, "initialize", None):
426
- await inst.initialize()
427
- self.rerank_provider_insts.append(inst)
428
484
 
429
485
  self.inst_map[provider_config["id"]] = inst
430
486
  except Exception as e:
@@ -2,6 +2,7 @@ import abc
2
2
  import asyncio
3
3
  import os
4
4
  from collections.abc import AsyncGenerator
5
+ from typing import TypeAlias, Union
5
6
 
6
7
  from astrbot.core.agent.message import Message
7
8
  from astrbot.core.agent.tool import ToolSet
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
14
15
  from astrbot.core.provider.register import provider_cls_map
15
16
  from astrbot.core.utils.astrbot_path import get_astrbot_path
16
17
 
18
+ Providers: TypeAlias = Union[
19
+ "Provider",
20
+ "STTProvider",
21
+ "TTSProvider",
22
+ "EmbeddingProvider",
23
+ "RerankProvider",
24
+ ]
25
+
17
26
 
18
27
  class AbstractProvider(abc.ABC):
19
28
  """Provider Abstract Class"""
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
142
151
  - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
143
152
 
144
153
  """
145
- ...
154
+ if False: # pragma: no cover - make this an async generator for typing
155
+ yield None # type: ignore
156
+ raise NotImplementedError()
146
157
 
147
158
  async def pop_record(self, context: list):
148
159
  """弹出 context 第一条非系统提示词对话记录"""
@@ -29,15 +29,24 @@ class OTTSProvider:
29
29
  self.last_sync_time = 0
30
30
  self.timeout = Timeout(10.0)
31
31
  self.retry_count = 3
32
- self.client = None
32
+ self._client: AsyncClient | None = None
33
+
34
+ @property
35
+ def client(self) -> AsyncClient:
36
+ if self._client is None:
37
+ raise RuntimeError(
38
+ "Client not initialized. Please use 'async with' context."
39
+ )
40
+ return self._client
33
41
 
34
42
  async def __aenter__(self):
35
- self.client = AsyncClient(timeout=self.timeout)
43
+ self._client = AsyncClient(timeout=self.timeout)
36
44
  return self
37
45
 
38
46
  async def __aexit__(self, exc_type, exc_val, exc_tb):
39
- if self.client:
40
- await self.client.aclose()
47
+ if self._client:
48
+ await self._client.aclose()
49
+ self._client = None
41
50
 
42
51
  async def _sync_time(self):
43
52
  try:
@@ -90,6 +99,7 @@ class OTTSProvider:
90
99
  if attempt == self.retry_count - 1:
91
100
  raise RuntimeError(f"OTTS请求失败: {e!s}") from e
92
101
  await asyncio.sleep(0.5 * (attempt + 1))
102
+ raise RuntimeError("OTTS未返回音频文件")
93
103
 
94
104
 
95
105
  class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
105
115
  self.endpoint = (
106
116
  f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
107
117
  )
108
- self.client = None
118
+ self._client: AsyncClient | None = None
109
119
  self.token = None
110
120
  self.token_expire = 0
111
121
  self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
116
126
  "volume": provider_config.get("azure_tts_volume", "100"),
117
127
  }
118
128
 
129
+ @property
130
+ def client(self) -> AsyncClient:
131
+ if self._client is None:
132
+ raise RuntimeError(
133
+ "Client not initialized. Please use 'async with' context."
134
+ )
135
+ return self._client
136
+
119
137
  async def __aenter__(self):
120
- self.client = AsyncClient(
138
+ self._client = AsyncClient(
121
139
  headers={
122
140
  "User-Agent": f"AstrBot/{VERSION}",
123
141
  "Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
127
145
  return self
128
146
 
129
147
  async def __aexit__(self, exc_type, exc_val, exc_tb):
130
- if self.client:
131
- await self.client.aclose()
148
+ if self._client:
149
+ await self._client.aclose()
150
+ self._client = None
132
151
 
133
152
  async def _refresh_token(self):
134
153
  token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
181
200
  key_value = provider_config.get("azure_tts_subscription_key", "")
182
201
  self.provider = self._parse_provider(key_value, provider_config)
183
202
 
184
- def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
203
+ def _parse_provider(
204
+ self, key_value: str, config: dict
205
+ ) -> OTTSProvider | AzureNativeProvider:
185
206
  if key_value.lower().startswith("other["):
207
+ json_str = ""
186
208
  try:
187
209
  match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
188
210
  if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
177
177
  Returns:
178
178
  重排序结果列表
179
179
  """
180
+ if not self.client:
181
+ logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
182
+ return []
183
+
180
184
  if not documents:
181
185
  logger.warning("文档列表为空,返回空结果")
182
186
  return []
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
36
36
  super().__init__(provider_config, provider_settings)
37
37
  self.chosen_api_key: str = provider_config.get("api_key", "")
38
38
  self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
39
- self.set_model(provider_config.get("model"))
39
+ self.set_model(provider_config["model"])
40
40
  self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
41
41
  dashscope.api_key = self.chosen_api_key
42
42
 
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
71
71
 
72
72
  kwargs = {
73
73
  "model": model,
74
- "text": text,
74
+ "messages": None,
75
75
  "api_key": self.chosen_api_key,
76
76
  "voice": self.voice or "Cherry",
77
+ "text": text,
77
78
  }
78
79
  if not self.voice:
79
80
  logging.warning(
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
67
67
  from pyffmpeg import FFmpeg
68
68
 
69
69
  ff = FFmpeg()
70
- ff.convert(input=mp3_path, output=wav_path)
70
+ ff.convert(input_file=mp3_path, output_file=wav_path)
71
71
  except Exception as e:
72
72
  logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
73
73
  # use ffmpeg command line
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
59
59
  self.headers = {
60
60
  "Authorization": f"Bearer {self.chosen_api_key}",
61
61
  }
62
- self.set_model(provider_config.get("model"))
62
+ self.set_model(provider_config["model"])
63
63
 
64
- async def _get_reference_id_by_character(self, character: str) -> str:
64
+ async def _get_reference_id_by_character(self, character: str) -> str | None:
65
65
  """获取角色的reference_id
66
66
 
67
67
  Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
109
109
  pattern = r"^[a-fA-F0-9]{32}$"
110
110
  return bool(re.match(pattern, reference_id.strip()))
111
111
 
112
- async def _generate_request(self, text: str) -> dict:
112
+ async def _generate_request(self, text: str) -> ServeTTSRequest:
113
113
  # 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
114
114
  if self.reference_id and self.reference_id.strip():
115
115
  # 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
146
146
  async for chunk in response.aiter_bytes():
147
147
  f.write(chunk)
148
148
  return path
149
- text = await response.aread()
149
+ body = await response.aread()
150
+ text = body.decode("utf-8", errors="replace")
150
151
  raise Exception(f"Fish Audio API请求失败: {text}")
@@ -1,3 +1,5 @@
1
+ from typing import cast
2
+
1
3
  from google import genai
2
4
  from google.genai import types
3
5
  from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
18
20
  self.provider_config = provider_config
19
21
  self.provider_settings = provider_settings
20
22
 
21
- api_key: str = provider_config.get("embedding_api_key")
22
- api_base: str = provider_config.get("embedding_api_base")
23
+ api_key: str = provider_config["embedding_api_key"]
24
+ api_base: str = provider_config["embedding_api_base"]
23
25
  timeout: int = int(provider_config.get("timeout", 20))
24
26
 
25
27
  http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
41
43
  model=self.model,
42
44
  contents=text,
43
45
  )
46
+ assert result.embeddings is not None
47
+ assert result.embeddings[0].values is not None
44
48
  return result.embeddings[0].values
45
49
  except APIError as e:
46
50
  raise Exception(f"Gemini Embedding API请求失败: {e.message}")
47
51
 
48
- async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
52
+ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
49
53
  """批量获取文本的嵌入"""
50
54
  try:
51
55
  result = await self.client.models.embed_content(
52
56
  model=self.model,
53
- contents=texts,
57
+ contents=cast(types.ContentListUnion, text),
54
58
  )
55
- return [embedding.values for embedding in result.embeddings]
59
+ assert result.embeddings is not None
60
+
61
+ embeddings: list[list[float]] = []
62
+ for embedding in result.embeddings:
63
+ assert embedding.values is not None
64
+ embeddings.append(embedding.values)
65
+ return embeddings
56
66
  except APIError as e:
57
67
  raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
58
68