AstrBot 4.7.4__py3-none-any.whl → 4.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. astrbot/cli/__init__.py +1 -1
  2. astrbot/core/agent/runners/tool_loop_agent_runner.py +0 -1
  3. astrbot/core/agent/tool.py +7 -2
  4. astrbot/core/astr_agent_run_util.py +15 -1
  5. astrbot/core/astr_agent_tool_exec.py +5 -1
  6. astrbot/core/config/astrbot_config.py +4 -0
  7. astrbot/core/config/default.py +116 -1
  8. astrbot/core/core_lifecycle.py +1 -1
  9. astrbot/core/db/__init__.py +32 -4
  10. astrbot/core/db/migration/migra_3_to_4.py +2 -0
  11. astrbot/core/db/migration/sqlite_v3.py +6 -4
  12. astrbot/core/db/po.py +16 -15
  13. astrbot/core/db/sqlite.py +56 -1
  14. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +2 -0
  15. astrbot/core/event_bus.py +6 -1
  16. astrbot/core/knowledge_base/retrieval/manager.py +5 -1
  17. astrbot/core/log.py +2 -1
  18. astrbot/core/message/components.py +9 -3
  19. astrbot/core/persona_mgr.py +2 -2
  20. astrbot/core/pipeline/content_safety_check/stage.py +1 -1
  21. astrbot/core/pipeline/context_utils.py +2 -1
  22. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +1 -1
  23. astrbot/core/pipeline/process_stage/method/star_request.py +1 -2
  24. astrbot/core/pipeline/process_stage/stage.py +1 -1
  25. astrbot/core/pipeline/respond/stage.py +4 -2
  26. astrbot/core/pipeline/result_decorate/stage.py +68 -21
  27. astrbot/core/pipeline/scheduler.py +5 -1
  28. astrbot/core/pipeline/waking_check/stage.py +10 -0
  29. astrbot/core/platform/astr_message_event.py +5 -3
  30. astrbot/core/platform/astrbot_message.py +2 -2
  31. astrbot/core/platform/manager.py +71 -9
  32. astrbot/core/platform/platform.py +109 -4
  33. astrbot/core/platform/platform_metadata.py +1 -1
  34. astrbot/core/platform/register.py +1 -0
  35. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +8 -6
  36. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +13 -8
  37. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +28 -22
  38. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +5 -2
  39. astrbot/core/platform/sources/discord/client.py +16 -4
  40. astrbot/core/platform/sources/discord/components.py +2 -2
  41. astrbot/core/platform/sources/discord/discord_platform_adapter.py +53 -26
  42. astrbot/core/platform/sources/discord/discord_platform_event.py +29 -8
  43. astrbot/core/platform/sources/lark/lark_adapter.py +178 -22
  44. astrbot/core/platform/sources/lark/lark_event.py +39 -4
  45. astrbot/core/platform/sources/lark/server.py +206 -0
  46. astrbot/core/platform/sources/misskey/misskey_adapter.py +3 -5
  47. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +64 -18
  48. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +14 -10
  49. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -11
  50. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +15 -2
  51. astrbot/core/platform/sources/satori/satori_adapter.py +1 -2
  52. astrbot/core/platform/sources/slack/client.py +58 -40
  53. astrbot/core/platform/sources/slack/slack_adapter.py +36 -16
  54. astrbot/core/platform/sources/slack/slack_event.py +11 -10
  55. astrbot/core/platform/sources/telegram/tg_adapter.py +2 -3
  56. astrbot/core/platform/sources/telegram/tg_event.py +23 -27
  57. astrbot/core/platform/sources/webchat/webchat_adapter.py +97 -31
  58. astrbot/core/platform/sources/webchat/webchat_event.py +35 -35
  59. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +27 -11
  60. astrbot/core/platform/sources/wecom/wecom_adapter.py +75 -36
  61. astrbot/core/platform/sources/wecom/wecom_event.py +3 -3
  62. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +26 -9
  63. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +3 -3
  64. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +27 -5
  65. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +81 -35
  66. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +11 -8
  67. astrbot/core/platform_message_history_mgr.py +3 -3
  68. astrbot/core/provider/func_tool_manager.py +3 -3
  69. astrbot/core/provider/manager.py +130 -74
  70. astrbot/core/provider/provider.py +12 -1
  71. astrbot/core/provider/sources/azure_tts_source.py +31 -9
  72. astrbot/core/provider/sources/bailian_rerank_source.py +4 -0
  73. astrbot/core/provider/sources/dashscope_tts.py +3 -2
  74. astrbot/core/provider/sources/edge_tts_source.py +1 -1
  75. astrbot/core/provider/sources/fishaudio_tts_api_source.py +5 -4
  76. astrbot/core/provider/sources/gemini_embedding_source.py +15 -5
  77. astrbot/core/provider/sources/gemini_source.py +12 -10
  78. astrbot/core/provider/sources/minimax_tts_api_source.py +4 -2
  79. astrbot/core/provider/sources/openai_embedding_source.py +2 -2
  80. astrbot/core/provider/sources/openai_source.py +4 -0
  81. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +5 -2
  82. astrbot/core/provider/sources/vllm_rerank_source.py +1 -0
  83. astrbot/core/provider/sources/whisper_api_source.py +44 -12
  84. astrbot/core/provider/sources/whisper_selfhosted_source.py +6 -2
  85. astrbot/core/provider/sources/xinference_rerank_source.py +10 -2
  86. astrbot/core/star/context.py +2 -2
  87. astrbot/core/star/register/star_handler.py +22 -5
  88. astrbot/core/star/star_handler.py +85 -4
  89. astrbot/core/updator.py +3 -3
  90. astrbot/core/utils/io.py +1 -1
  91. astrbot/core/utils/session_waiter.py +17 -10
  92. astrbot/core/utils/shared_preferences.py +32 -0
  93. astrbot/core/utils/t2i/__init__.py +2 -2
  94. astrbot/core/utils/t2i/local_strategy.py +25 -31
  95. astrbot/core/utils/tencent_record_helper.py +2 -2
  96. astrbot/core/utils/version_comparator.py +6 -3
  97. astrbot/core/utils/webhook_utils.py +66 -0
  98. astrbot/dashboard/routes/__init__.py +2 -0
  99. astrbot/dashboard/routes/chat.py +311 -76
  100. astrbot/dashboard/routes/config.py +14 -5
  101. astrbot/dashboard/routes/knowledge_base.py +254 -79
  102. astrbot/dashboard/routes/log.py +13 -8
  103. astrbot/dashboard/routes/platform.py +100 -0
  104. astrbot/dashboard/routes/plugin.py +108 -51
  105. astrbot/dashboard/routes/route.py +2 -0
  106. astrbot/dashboard/server.py +9 -4
  107. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/METADATA +50 -37
  108. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/RECORD +111 -108
  109. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/WHEEL +0 -0
  110. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/entry_points.txt +0 -0
  111. {astrbot-4.7.4.dist-info → astrbot-4.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  import traceback
3
+ from typing import Protocol, runtime_checkable
3
4
 
4
5
  from astrbot.core import astrbot_config, logger, sp
5
6
  from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -10,6 +11,7 @@ from .entities import ProviderType
10
11
  from .provider import (
11
12
  EmbeddingProvider,
12
13
  Provider,
14
+ Providers,
13
15
  RerankProvider,
14
16
  STTProvider,
15
17
  TTSProvider,
@@ -17,6 +19,11 @@ from .provider import (
17
19
  from .register import llm_tools, provider_cls_map
18
20
 
19
21
 
22
+ @runtime_checkable
23
+ class HasInitialize(Protocol):
24
+ async def initialize(self) -> None: ...
25
+
26
+
20
27
  class ProviderManager:
21
28
  def __init__(
22
29
  self,
@@ -48,7 +55,7 @@ class ProviderManager:
48
55
  """加载的 Rerank Provider 的实例"""
49
56
  self.inst_map: dict[
50
57
  str,
51
- Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
58
+ Providers,
52
59
  ] = {}
53
60
  """Provider 实例映射. key: provider_id, value: Provider 实例"""
54
61
  self.llm_tools = llm_tools
@@ -123,15 +130,13 @@ class ProviderManager:
123
130
  self.curr_provider_inst = prov
124
131
  sp.put("curr_provider", provider_id, scope="global", scope_id="global")
125
132
 
126
- async def get_provider_by_id(self, provider_id: str) -> Provider | None:
133
+ async def get_provider_by_id(self, provider_id: str) -> Providers | None:
127
134
  """根据提供商 ID 获取提供商实例"""
128
135
  return self.inst_map.get(provider_id)
129
136
 
130
137
  def get_using_provider(
131
- self,
132
- provider_type: ProviderType,
133
- umo=None,
134
- ) -> Provider | STTProvider | TTSProvider | None:
138
+ self, provider_type: ProviderType, umo=None
139
+ ) -> Providers | None:
135
140
  """获取正在使用的提供商实例。
136
141
 
137
142
  Args:
@@ -191,7 +196,6 @@ class ProviderManager:
191
196
  logger.error(traceback.format_exc())
192
197
  logger.error(e)
193
198
 
194
- # 设置默认提供商
195
199
  selected_provider_id = sp.get(
196
200
  "curr_provider",
197
201
  self.provider_settings.get("default_provider_id"),
@@ -210,15 +214,37 @@ class ProviderManager:
210
214
  scope="global",
211
215
  scope_id="global",
212
216
  )
213
- self.curr_provider_inst = self.inst_map.get(selected_provider_id)
217
+
218
+ temp_provider = (
219
+ self.inst_map.get(selected_provider_id)
220
+ if isinstance(selected_provider_id, str)
221
+ else None
222
+ )
223
+ self.curr_provider_inst = (
224
+ temp_provider if isinstance(temp_provider, Provider) else None
225
+ )
214
226
  if not self.curr_provider_inst and self.provider_insts:
215
227
  self.curr_provider_inst = self.provider_insts[0]
216
228
 
217
- self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
229
+ temp_stt = (
230
+ self.inst_map.get(selected_stt_provider_id)
231
+ if isinstance(selected_stt_provider_id, str)
232
+ else None
233
+ )
234
+ self.curr_stt_provider_inst = (
235
+ temp_stt if isinstance(temp_stt, STTProvider) else None
236
+ )
218
237
  if not self.curr_stt_provider_inst and self.stt_provider_insts:
219
238
  self.curr_stt_provider_inst = self.stt_provider_insts[0]
220
239
 
221
- self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
240
+ temp_tts = (
241
+ self.inst_map.get(selected_tts_provider_id)
242
+ if isinstance(selected_tts_provider_id, str)
243
+ else None
244
+ )
245
+ self.curr_tts_provider_inst = (
246
+ temp_tts if isinstance(temp_tts, TTSProvider) else None
247
+ )
222
248
  if not self.curr_tts_provider_inst and self.tts_provider_insts:
223
249
  self.curr_tts_provider_inst = self.tts_provider_insts[0]
224
250
 
@@ -358,73 +384,103 @@ class ProviderManager:
358
384
 
359
385
  provider_metadata.id = provider_config["id"]
360
386
 
361
- if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
362
- # STT 任务
363
- inst = cls_type(provider_config, self.provider_settings)
364
-
365
- if getattr(inst, "initialize", None):
366
- await inst.initialize()
367
-
368
- self.stt_provider_insts.append(inst)
369
- if (
370
- self.provider_stt_settings.get("provider_id")
371
- == provider_config["id"]
372
- ):
373
- self.curr_stt_provider_inst = inst
374
- logger.info(
375
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
376
- )
377
- if not self.curr_stt_provider_inst:
378
- self.curr_stt_provider_inst = inst
379
-
380
- elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
381
- # TTS 任务
382
- inst = cls_type(provider_config, self.provider_settings)
383
-
384
- if getattr(inst, "initialize", None):
385
- await inst.initialize()
386
-
387
- self.tts_provider_insts.append(inst)
388
- if self.provider_settings.get("provider_id") == provider_config["id"]:
389
- self.curr_tts_provider_inst = inst
390
- logger.info(
391
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
387
+ match provider_metadata.provider_type:
388
+ case ProviderType.SPEECH_TO_TEXT:
389
+ # STT 任务
390
+ if not issubclass(cls_type, STTProvider):
391
+ raise TypeError(
392
+ f"Provider class {cls_type} is not a subclass of STTProvider"
393
+ )
394
+ inst = cls_type(provider_config, self.provider_settings)
395
+
396
+ if isinstance(inst, HasInitialize):
397
+ await inst.initialize()
398
+
399
+ self.stt_provider_insts.append(inst)
400
+ if (
401
+ self.provider_stt_settings.get("provider_id")
402
+ == provider_config["id"]
403
+ ):
404
+ self.curr_stt_provider_inst = inst
405
+ logger.info(
406
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
407
+ )
408
+ if not self.curr_stt_provider_inst:
409
+ self.curr_stt_provider_inst = inst
410
+
411
+ case ProviderType.TEXT_TO_SPEECH:
412
+ # TTS 任务
413
+ if not issubclass(cls_type, TTSProvider):
414
+ raise TypeError(
415
+ f"Provider class {cls_type} is not a subclass of TTSProvider"
416
+ )
417
+ inst = cls_type(provider_config, self.provider_settings)
418
+
419
+ if isinstance(inst, HasInitialize):
420
+ await inst.initialize()
421
+
422
+ self.tts_provider_insts.append(inst)
423
+ if (
424
+ self.provider_settings.get("provider_id")
425
+ == provider_config["id"]
426
+ ):
427
+ self.curr_tts_provider_inst = inst
428
+ logger.info(
429
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
430
+ )
431
+ if not self.curr_tts_provider_inst:
432
+ self.curr_tts_provider_inst = inst
433
+
434
+ case ProviderType.CHAT_COMPLETION:
435
+ # 文本生成任务
436
+ if not issubclass(cls_type, Provider):
437
+ raise TypeError(
438
+ f"Provider class {cls_type} is not a subclass of Provider"
439
+ )
440
+ inst = cls_type(
441
+ provider_config,
442
+ self.provider_settings,
392
443
  )
393
- if not self.curr_tts_provider_inst:
394
- self.curr_tts_provider_inst = inst
395
-
396
- elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
397
- # 文本生成任务
398
- inst = cls_type(
399
- provider_config,
400
- self.provider_settings,
401
- )
402
444
 
403
- if getattr(inst, "initialize", None):
404
- await inst.initialize()
405
-
406
- self.provider_insts.append(inst)
407
- if (
408
- self.provider_settings.get("default_provider_id")
409
- == provider_config["id"]
410
- ):
411
- self.curr_provider_inst = inst
412
- logger.info(
413
- f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
445
+ if isinstance(inst, HasInitialize):
446
+ await inst.initialize()
447
+
448
+ self.provider_insts.append(inst)
449
+ if (
450
+ self.provider_settings.get("default_provider_id")
451
+ == provider_config["id"]
452
+ ):
453
+ self.curr_provider_inst = inst
454
+ logger.info(
455
+ f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
456
+ )
457
+ if not self.curr_provider_inst:
458
+ self.curr_provider_inst = inst
459
+
460
+ case ProviderType.EMBEDDING:
461
+ if not issubclass(cls_type, EmbeddingProvider):
462
+ raise TypeError(
463
+ f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
464
+ )
465
+ inst = cls_type(provider_config, self.provider_settings)
466
+ if isinstance(inst, HasInitialize):
467
+ await inst.initialize()
468
+ self.embedding_provider_insts.append(inst)
469
+ case ProviderType.RERANK:
470
+ if not issubclass(cls_type, RerankProvider):
471
+ raise TypeError(
472
+ f"Provider class {cls_type} is not a subclass of RerankProvider"
473
+ )
474
+ inst = cls_type(provider_config, self.provider_settings)
475
+ if isinstance(inst, HasInitialize):
476
+ await inst.initialize()
477
+ self.rerank_provider_insts.append(inst)
478
+ case _:
479
+ # 未知供应商抛出异常,确保inst初始化
480
+ # Should be unreachable
481
+ raise Exception(
482
+ f"未知的提供商类型:{provider_metadata.provider_type}"
414
483
  )
415
- if not self.curr_provider_inst:
416
- self.curr_provider_inst = inst
417
-
418
- elif provider_metadata.provider_type == ProviderType.EMBEDDING:
419
- inst = cls_type(provider_config, self.provider_settings)
420
- if getattr(inst, "initialize", None):
421
- await inst.initialize()
422
- self.embedding_provider_insts.append(inst)
423
- elif provider_metadata.provider_type == ProviderType.RERANK:
424
- inst = cls_type(provider_config, self.provider_settings)
425
- if getattr(inst, "initialize", None):
426
- await inst.initialize()
427
- self.rerank_provider_insts.append(inst)
428
484
 
429
485
  self.inst_map[provider_config["id"]] = inst
430
486
  except Exception as e:
@@ -2,6 +2,7 @@ import abc
2
2
  import asyncio
3
3
  import os
4
4
  from collections.abc import AsyncGenerator
5
+ from typing import TypeAlias, Union
5
6
 
6
7
  from astrbot.core.agent.message import Message
7
8
  from astrbot.core.agent.tool import ToolSet
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
14
15
  from astrbot.core.provider.register import provider_cls_map
15
16
  from astrbot.core.utils.astrbot_path import get_astrbot_path
16
17
 
18
+ Providers: TypeAlias = Union[
19
+ "Provider",
20
+ "STTProvider",
21
+ "TTSProvider",
22
+ "EmbeddingProvider",
23
+ "RerankProvider",
24
+ ]
25
+
17
26
 
18
27
  class AbstractProvider(abc.ABC):
19
28
  """Provider Abstract Class"""
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
142
151
  - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
143
152
 
144
153
  """
145
- ...
154
+ if False: # pragma: no cover - make this an async generator for typing
155
+ yield None # type: ignore
156
+ raise NotImplementedError()
146
157
 
147
158
  async def pop_record(self, context: list):
148
159
  """弹出 context 第一条非系统提示词对话记录"""
@@ -29,15 +29,24 @@ class OTTSProvider:
29
29
  self.last_sync_time = 0
30
30
  self.timeout = Timeout(10.0)
31
31
  self.retry_count = 3
32
- self.client = None
32
+ self._client: AsyncClient | None = None
33
+
34
+ @property
35
+ def client(self) -> AsyncClient:
36
+ if self._client is None:
37
+ raise RuntimeError(
38
+ "Client not initialized. Please use 'async with' context."
39
+ )
40
+ return self._client
33
41
 
34
42
  async def __aenter__(self):
35
- self.client = AsyncClient(timeout=self.timeout)
43
+ self._client = AsyncClient(timeout=self.timeout)
36
44
  return self
37
45
 
38
46
  async def __aexit__(self, exc_type, exc_val, exc_tb):
39
- if self.client:
40
- await self.client.aclose()
47
+ if self._client:
48
+ await self._client.aclose()
49
+ self._client = None
41
50
 
42
51
  async def _sync_time(self):
43
52
  try:
@@ -90,6 +99,7 @@ class OTTSProvider:
90
99
  if attempt == self.retry_count - 1:
91
100
  raise RuntimeError(f"OTTS请求失败: {e!s}") from e
92
101
  await asyncio.sleep(0.5 * (attempt + 1))
102
+ raise RuntimeError("OTTS未返回音频文件")
93
103
 
94
104
 
95
105
  class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
105
115
  self.endpoint = (
106
116
  f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
107
117
  )
108
- self.client = None
118
+ self._client: AsyncClient | None = None
109
119
  self.token = None
110
120
  self.token_expire = 0
111
121
  self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
116
126
  "volume": provider_config.get("azure_tts_volume", "100"),
117
127
  }
118
128
 
129
+ @property
130
+ def client(self) -> AsyncClient:
131
+ if self._client is None:
132
+ raise RuntimeError(
133
+ "Client not initialized. Please use 'async with' context."
134
+ )
135
+ return self._client
136
+
119
137
  async def __aenter__(self):
120
- self.client = AsyncClient(
138
+ self._client = AsyncClient(
121
139
  headers={
122
140
  "User-Agent": f"AstrBot/{VERSION}",
123
141
  "Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
127
145
  return self
128
146
 
129
147
  async def __aexit__(self, exc_type, exc_val, exc_tb):
130
- if self.client:
131
- await self.client.aclose()
148
+ if self._client:
149
+ await self._client.aclose()
150
+ self._client = None
132
151
 
133
152
  async def _refresh_token(self):
134
153
  token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
181
200
  key_value = provider_config.get("azure_tts_subscription_key", "")
182
201
  self.provider = self._parse_provider(key_value, provider_config)
183
202
 
184
- def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
203
+ def _parse_provider(
204
+ self, key_value: str, config: dict
205
+ ) -> OTTSProvider | AzureNativeProvider:
185
206
  if key_value.lower().startswith("other["):
207
+ json_str = ""
186
208
  try:
187
209
  match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
188
210
  if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
177
177
  Returns:
178
178
  重排序结果列表
179
179
  """
180
+ if not self.client:
181
+ logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
182
+ return []
183
+
180
184
  if not documents:
181
185
  logger.warning("文档列表为空,返回空结果")
182
186
  return []
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
36
36
  super().__init__(provider_config, provider_settings)
37
37
  self.chosen_api_key: str = provider_config.get("api_key", "")
38
38
  self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
39
- self.set_model(provider_config.get("model"))
39
+ self.set_model(provider_config["model"])
40
40
  self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
41
41
  dashscope.api_key = self.chosen_api_key
42
42
 
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
71
71
 
72
72
  kwargs = {
73
73
  "model": model,
74
- "text": text,
74
+ "messages": None,
75
75
  "api_key": self.chosen_api_key,
76
76
  "voice": self.voice or "Cherry",
77
+ "text": text,
77
78
  }
78
79
  if not self.voice:
79
80
  logging.warning(
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
67
67
  from pyffmpeg import FFmpeg
68
68
 
69
69
  ff = FFmpeg()
70
- ff.convert(input=mp3_path, output=wav_path)
70
+ ff.convert(input_file=mp3_path, output_file=wav_path)
71
71
  except Exception as e:
72
72
  logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
73
73
  # use ffmpeg command line
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
59
59
  self.headers = {
60
60
  "Authorization": f"Bearer {self.chosen_api_key}",
61
61
  }
62
- self.set_model(provider_config.get("model"))
62
+ self.set_model(provider_config["model"])
63
63
 
64
- async def _get_reference_id_by_character(self, character: str) -> str:
64
+ async def _get_reference_id_by_character(self, character: str) -> str | None:
65
65
  """获取角色的reference_id
66
66
 
67
67
  Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
109
109
  pattern = r"^[a-fA-F0-9]{32}$"
110
110
  return bool(re.match(pattern, reference_id.strip()))
111
111
 
112
- async def _generate_request(self, text: str) -> dict:
112
+ async def _generate_request(self, text: str) -> ServeTTSRequest:
113
113
  # 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
114
114
  if self.reference_id and self.reference_id.strip():
115
115
  # 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
146
146
  async for chunk in response.aiter_bytes():
147
147
  f.write(chunk)
148
148
  return path
149
- text = await response.aread()
149
+ body = await response.aread()
150
+ text = body.decode("utf-8", errors="replace")
150
151
  raise Exception(f"Fish Audio API请求失败: {text}")
@@ -1,3 +1,5 @@
1
+ from typing import cast
2
+
1
3
  from google import genai
2
4
  from google.genai import types
3
5
  from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
18
20
  self.provider_config = provider_config
19
21
  self.provider_settings = provider_settings
20
22
 
21
- api_key: str = provider_config.get("embedding_api_key")
22
- api_base: str = provider_config.get("embedding_api_base")
23
+ api_key: str = provider_config["embedding_api_key"]
24
+ api_base: str = provider_config["embedding_api_base"]
23
25
  timeout: int = int(provider_config.get("timeout", 20))
24
26
 
25
27
  http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
41
43
  model=self.model,
42
44
  contents=text,
43
45
  )
46
+ assert result.embeddings is not None
47
+ assert result.embeddings[0].values is not None
44
48
  return result.embeddings[0].values
45
49
  except APIError as e:
46
50
  raise Exception(f"Gemini Embedding API请求失败: {e.message}")
47
51
 
48
- async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
52
+ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
49
53
  """批量获取文本的嵌入"""
50
54
  try:
51
55
  result = await self.client.models.embed_content(
52
56
  model=self.model,
53
- contents=texts,
57
+ contents=cast(types.ContentListUnion, text),
54
58
  )
55
- return [embedding.values for embedding in result.embeddings]
59
+ assert result.embeddings is not None
60
+
61
+ embeddings: list[list[float]] = []
62
+ for embedding in result.embeddings:
63
+ assert embedding.values is not None
64
+ embeddings.append(embedding.values)
65
+ return embeddings
56
66
  except APIError as e:
57
67
  raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
58
68
 
@@ -4,6 +4,7 @@ import json
4
4
  import logging
5
5
  import random
6
6
  from collections.abc import AsyncGenerator
7
+ from typing import cast
7
8
 
8
9
  from google import genai
9
10
  from google.genai import types
@@ -126,17 +127,17 @@ class ProviderGoogleGenAI(Provider):
126
127
  ) -> types.GenerateContentConfig:
127
128
  """准备查询配置"""
128
129
  if not modalities:
129
- modalities = ["Text"]
130
+ modalities = ["TEXT"]
130
131
 
131
132
  # 流式输出不支持图片模态
132
133
  if (
133
134
  self.provider_settings.get("streaming_response", False)
134
- and "Image" in modalities
135
+ and "IMAGE" in modalities
135
136
  ):
136
137
  logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
137
- modalities = ["Text"]
138
+ modalities = ["TEXT"]
138
139
 
139
- tool_list = []
140
+ tool_list: list[types.Tool] | None = []
140
141
  model_name = self.get_model()
141
142
  native_coderunner = self.provider_config.get("gm_native_coderunner", False)
142
143
  native_search = self.provider_config.get("gm_native_search", False)
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
213
214
  logprobs=payloads.get("logprobs"),
214
215
  seed=payloads.get("seed"),
215
216
  response_modalities=modalities,
216
- tools=tool_list,
217
+ tools=cast(types.ToolListUnion | None, tool_list),
217
218
  safety_settings=self.safety_settings if self.safety_settings else None,
218
219
  thinking_config=(
219
220
  types.ThinkingConfig(
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
257
258
  content_cls: type[types.Content],
258
259
  ) -> None:
259
260
  if contents and isinstance(contents[-1], content_cls):
261
+ assert contents[-1].parts is not None
260
262
  contents[-1].parts.extend(part)
261
263
  else:
262
264
  contents.append(content_cls(parts=part))
@@ -429,9 +431,9 @@ class ProviderGoogleGenAI(Provider):
429
431
  None,
430
432
  )
431
433
 
432
- modalities = ["Text"]
434
+ modalities = ["TEXT"]
433
435
  if self.provider_config.get("gm_resp_image_modal", False):
434
- modalities.append("Image")
436
+ modalities.append("IMAGE")
435
437
 
436
438
  conversation = self._prepare_conversation(payloads)
437
439
  temperature = payloads.get("temperature", 0.7)
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
448
450
  )
449
451
  result = await self.client.models.generate_content(
450
452
  model=self.get_model(),
451
- contents=conversation,
453
+ contents=cast(types.ContentListUnion, conversation),
452
454
  config=config,
453
455
  )
454
456
  logger.debug(f"genai result: {result}")
@@ -488,7 +490,7 @@ class ProviderGoogleGenAI(Provider):
488
490
  logger.warning(
489
491
  f"{self.get_model()} 不支持多模态输出,降级为文本模态",
490
492
  )
491
- modalities = ["Text"]
493
+ modalities = ["TEXT"]
492
494
  else:
493
495
  raise
494
496
  continue
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
524
526
  )
525
527
  result = await self.client.models.generate_content_stream(
526
528
  model=self.get_model(),
527
- contents=conversation,
529
+ contents=cast(types.ContentListUnion, conversation),
528
530
  config=config,
529
531
  )
530
532
  break
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
87
87
 
88
88
  return json.dumps(dict_body)
89
89
 
90
- async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
90
+ async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
91
91
  """进行流式请求"""
92
92
  try:
93
93
  async with (
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
117
117
  data = json.loads(message[6:])
118
118
  if "extra_info" in data:
119
119
  continue
120
- audio = data.get("data", {}).get("audio")
120
+ audio: str | None = data.get("data", {}).get(
121
+ "audio"
122
+ )
121
123
  if audio is not None:
122
124
  yield audio
123
125
  except json.JSONDecodeError:
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
30
30
  embedding = await self.client.embeddings.create(input=text, model=self.model)
31
31
  return embedding.data[0].embedding
32
32
 
33
- async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
33
+ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
34
34
  """批量获取文本的嵌入"""
35
- embeddings = await self.client.embeddings.create(input=texts, model=self.model)
35
+ embeddings = await self.client.embeddings.create(input=text, model=self.model)
36
36
  return [item.embedding for item in embeddings.data]
37
37
 
38
38
  def get_dim(self) -> int:
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
284
284
  if isinstance(tool_call, str):
285
285
  # workaround for #1359
286
286
  tool_call = json.loads(tool_call)
287
+ if tools is None:
288
+ # 工具集未提供
289
+ # Should be unreachable
290
+ raise Exception("工具集未提供")
287
291
  for tool in tools.func_list:
288
292
  if (
289
293
  tool_call.type == "function"