AstrBot 4.1.3__py3-none-any.whl → 4.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40):
  1. astrbot/core/agent/agent.py +1 -1
  2. astrbot/core/agent/mcp_client.py +3 -1
  3. astrbot/core/agent/runners/tool_loop_agent_runner.py +6 -27
  4. astrbot/core/agent/tool.py +28 -17
  5. astrbot/core/config/default.py +50 -14
  6. astrbot/core/db/sqlite.py +15 -1
  7. astrbot/core/pipeline/content_safety_check/stage.py +1 -1
  8. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +1 -1
  9. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -1
  10. astrbot/core/pipeline/context_utils.py +4 -1
  11. astrbot/core/pipeline/process_stage/method/llm_request.py +23 -4
  12. astrbot/core/pipeline/process_stage/method/star_request.py +8 -6
  13. astrbot/core/platform/manager.py +4 -0
  14. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -1
  15. astrbot/core/platform/sources/misskey/misskey_adapter.py +391 -0
  16. astrbot/core/platform/sources/misskey/misskey_api.py +404 -0
  17. astrbot/core/platform/sources/misskey/misskey_event.py +123 -0
  18. astrbot/core/platform/sources/misskey/misskey_utils.py +327 -0
  19. astrbot/core/platform/sources/satori/satori_adapter.py +290 -24
  20. astrbot/core/platform/sources/satori/satori_event.py +9 -0
  21. astrbot/core/platform/sources/telegram/tg_event.py +0 -1
  22. astrbot/core/provider/entities.py +13 -3
  23. astrbot/core/provider/func_tool_manager.py +4 -4
  24. astrbot/core/provider/manager.py +35 -19
  25. astrbot/core/star/context.py +26 -12
  26. astrbot/core/star/filter/command.py +3 -4
  27. astrbot/core/star/filter/command_group.py +4 -4
  28. astrbot/core/star/filter/platform_adapter_type.py +10 -5
  29. astrbot/core/star/register/star.py +3 -1
  30. astrbot/core/star/register/star_handler.py +65 -36
  31. astrbot/core/star/session_plugin_manager.py +3 -0
  32. astrbot/core/star/star_handler.py +4 -4
  33. astrbot/core/star/star_manager.py +10 -4
  34. astrbot/core/star/star_tools.py +6 -2
  35. astrbot/core/star/updator.py +3 -0
  36. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/METADATA +6 -7
  37. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/RECORD +40 -36
  38. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/WHEEL +0 -0
  39. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/entry_points.txt +0 -0
  40. {astrbot-4.1.3.dist-info → astrbot-4.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -17,7 +17,14 @@ from astrbot.api.platform import (
17
17
  register_platform_adapter,
18
18
  )
19
19
  from astrbot.core.platform.astr_message_event import MessageSession
20
- from astrbot.api.message_components import Plain, Image, At, File, Record
20
+ from astrbot.api.message_components import (
21
+ Plain,
22
+ Image,
23
+ At,
24
+ File,
25
+ Record,
26
+ Reply,
27
+ )
21
28
  from xml.etree import ElementTree as ET
22
29
 
23
30
 
@@ -38,12 +45,18 @@ class SatoriPlatformAdapter(Platform):
38
45
  )
39
46
  self.token = self.config.get("satori_token", "")
40
47
  self.endpoint = self.config.get(
41
- "satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
48
+ "satori_endpoint", "ws://localhost:5140/satori/v1/events"
42
49
  )
43
50
  self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
44
51
  self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
45
52
  self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
46
53
 
54
+ self.metadata = PlatformMetadata(
55
+ name="satori",
56
+ description="Satori 通用协议适配器",
57
+ id=self.config["id"],
58
+ )
59
+
47
60
  self.ws: Optional[ClientConnection] = None
48
61
  self.session: Optional[ClientSession] = None
49
62
  self.sequence = 0
@@ -63,7 +76,7 @@ class SatoriPlatformAdapter(Platform):
63
76
  await super().send_by_session(session, message_chain)
64
77
 
65
78
  def meta(self) -> PlatformMetadata:
66
- return PlatformMetadata(name="satori", description="Satori 通用协议适配器")
79
+ return self.metadata
67
80
 
68
81
  def _is_websocket_closed(self, ws) -> bool:
69
82
  """检查WebSocket连接是否已关闭"""
@@ -312,12 +325,52 @@ class SatoriPlatformAdapter(Platform):
312
325
 
313
326
  abm.self_id = login.get("user", {}).get("id", "")
314
327
 
328
+ # 消息链
329
+ abm.message = []
330
+
315
331
  content = message.get("content", "")
316
- abm.message = await self.parse_satori_elements(content)
317
332
 
318
- # parse message_str
333
+ quote = message.get("quote")
334
+ content_for_parsing = content # 副本
335
+
336
+ # 提取<quote>标签
337
+ if "<quote" in content:
338
+ try:
339
+ quote_info = await self._extract_quote_element(content)
340
+ if quote_info:
341
+ quote = quote_info["quote"]
342
+ content_for_parsing = quote_info["content_without_quote"]
343
+ except Exception as e:
344
+ logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
345
+
346
+ if quote:
347
+ # 引用消息
348
+ quote_abm = await self._convert_quote_message(quote)
349
+ if quote_abm:
350
+ sender_id = quote_abm.sender.user_id
351
+ if isinstance(sender_id, str) and sender_id.isdigit():
352
+ sender_id = int(sender_id)
353
+ elif not isinstance(sender_id, int):
354
+ sender_id = 0 # 默认值
355
+
356
+ reply_component = Reply(
357
+ id=quote_abm.message_id,
358
+ chain=quote_abm.message,
359
+ sender_id=quote_abm.sender.user_id,
360
+ sender_nickname=quote_abm.sender.nickname,
361
+ time=quote_abm.timestamp,
362
+ message_str=quote_abm.message_str,
363
+ text=quote_abm.message_str,
364
+ qq=sender_id,
365
+ )
366
+ abm.message.append(reply_component)
367
+
368
+ # 解析消息内容
369
+ content_elements = await self.parse_satori_elements(content_for_parsing)
370
+ abm.message.extend(content_elements)
371
+
319
372
  abm.message_str = ""
320
- for comp in abm.message:
373
+ for comp in content_elements:
321
374
  if isinstance(comp, Plain):
322
375
  abm.message_str += comp.text
323
376
 
@@ -333,6 +386,163 @@ class SatoriPlatformAdapter(Platform):
333
386
  logger.error(f"转换 Satori 消息失败: {e}")
334
387
  return None
335
388
 
389
+ def _extract_namespace_prefixes(self, content: str) -> set:
390
+ """提取XML内容中的命名空间前缀"""
391
+ prefixes = set()
392
+
393
+ # 查找所有标签
394
+ i = 0
395
+ while i < len(content):
396
+ # 查找开始标签
397
+ if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
398
+ # 找到标签结束位置
399
+ tag_end = content.find(">", i)
400
+ if tag_end != -1:
401
+ # 提取标签内容
402
+ tag_content = content[i + 1 : tag_end]
403
+ # 检查是否有命名空间前缀
404
+ if ":" in tag_content and "xmlns:" not in tag_content:
405
+ # 分割标签名
406
+ parts = tag_content.split()
407
+ if parts:
408
+ tag_name = parts[0]
409
+ if ":" in tag_name:
410
+ prefix = tag_name.split(":")[0]
411
+ # 确保是有效的命名空间前缀
412
+ if (
413
+ prefix.isalnum()
414
+ or prefix.replace("_", "").isalnum()
415
+ ):
416
+ prefixes.add(prefix)
417
+ i = tag_end + 1
418
+ else:
419
+ i += 1
420
+ # 查找结束标签
421
+ elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
422
+ # 找到标签结束位置
423
+ tag_end = content.find(">", i)
424
+ if tag_end != -1:
425
+ # 提取标签内容
426
+ tag_content = content[i + 2 : tag_end]
427
+ # 检查是否有命名空间前缀
428
+ if ":" in tag_content:
429
+ prefix = tag_content.split(":")[0]
430
+ # 确保是有效的命名空间前缀
431
+ if prefix.isalnum() or prefix.replace("_", "").isalnum():
432
+ prefixes.add(prefix)
433
+ i = tag_end + 1
434
+ else:
435
+ i += 1
436
+ else:
437
+ i += 1
438
+
439
+ return prefixes
440
+
441
+ async def _extract_quote_element(self, content: str) -> Optional[dict]:
442
+ """提取<quote>标签信息"""
443
+ try:
444
+ # 处理命名空间前缀问题
445
+ processed_content = content
446
+ if ":" in content and not content.startswith("<root"):
447
+ prefixes = self._extract_namespace_prefixes(content)
448
+
449
+ # 构建命名空间声明
450
+ ns_declarations = " ".join(
451
+ [
452
+ f'xmlns:{prefix}="http://temp.uri/{prefix}"'
453
+ for prefix in prefixes
454
+ ]
455
+ )
456
+
457
+ # 包装内容
458
+ processed_content = f"<root {ns_declarations}>{content}</root>"
459
+ elif not content.startswith("<root"):
460
+ processed_content = f"<root>{content}</root>"
461
+ else:
462
+ processed_content = content
463
+
464
+ root = ET.fromstring(processed_content)
465
+
466
+ # 查找<quote>标签
467
+ quote_element = None
468
+ for elem in root.iter():
469
+ tag_name = elem.tag
470
+ if "}" in tag_name:
471
+ tag_name = tag_name.split("}")[1]
472
+ if tag_name.lower() == "quote":
473
+ quote_element = elem
474
+ break
475
+
476
+ if quote_element is not None:
477
+ # 提取quote标签的属性
478
+ quote_id = quote_element.get("id", "")
479
+
480
+ # 提取<quote>标签内部的内容
481
+ inner_content = ""
482
+ if quote_element.text:
483
+ inner_content += quote_element.text
484
+ for child in quote_element:
485
+ inner_content += ET.tostring(
486
+ child, encoding="unicode", method="xml"
487
+ )
488
+ if child.tail:
489
+ inner_content += child.tail
490
+
491
+ # 构造移除了<quote>标签的内容
492
+ content_without_quote = content.replace(
493
+ ET.tostring(quote_element, encoding="unicode", method="xml"), ""
494
+ )
495
+
496
+ return {
497
+ "quote": {"id": quote_id, "content": inner_content},
498
+ "content_without_quote": content_without_quote,
499
+ }
500
+
501
+ return None
502
+ except Exception as e:
503
+ logger.error(f"提取<quote>标签时发生错误: {e}")
504
+ return None
505
+
506
+ async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
507
+ """转换引用消息"""
508
+ try:
509
+ quote_abm = AstrBotMessage()
510
+ quote_abm.message_id = quote.get("id", "")
511
+
512
+ # 解析引用消息的发送者
513
+ quote_author = quote.get("author", {})
514
+ if quote_author:
515
+ quote_abm.sender = MessageMember(
516
+ user_id=quote_author.get("id", ""),
517
+ nickname=quote_author.get("nick", quote_author.get("name", "")),
518
+ )
519
+ else:
520
+ # 如果没有作者信息,使用默认值
521
+ quote_abm.sender = MessageMember(
522
+ user_id=quote.get("user_id", ""),
523
+ nickname="内容",
524
+ )
525
+
526
+ # 解析引用消息内容
527
+ quote_content = quote.get("content", "")
528
+ quote_abm.message = await self.parse_satori_elements(quote_content)
529
+
530
+ quote_abm.message_str = ""
531
+ for comp in quote_abm.message:
532
+ if isinstance(comp, Plain):
533
+ quote_abm.message_str += comp.text
534
+
535
+ quote_abm.timestamp = int(quote.get("timestamp", time.time()))
536
+
537
+ # 如果没有任何内容,使用默认文本
538
+ if not quote_abm.message_str.strip():
539
+ quote_abm.message_str = "[引用消息]"
540
+
541
+ return quote_abm
542
+ except Exception as e:
543
+ logger.error(f"转换引用消息失败: {e}")
544
+ return None
545
+
336
546
  async def parse_satori_elements(self, content: str) -> list:
337
547
  """解析 Satori 消息元素"""
338
548
  elements = []
@@ -341,12 +551,35 @@ class SatoriPlatformAdapter(Platform):
341
551
  return elements
342
552
 
343
553
  try:
344
- wrapped_content = f"<root>{content}</root>"
345
- root = ET.fromstring(wrapped_content)
554
+ # 处理命名空间前缀问题
555
+ processed_content = content
556
+ if ":" in content and not content.startswith("<root"):
557
+ prefixes = self._extract_namespace_prefixes(content)
558
+
559
+ # 构建命名空间声明
560
+ ns_declarations = " ".join(
561
+ [
562
+ f'xmlns:{prefix}="http://temp.uri/{prefix}"'
563
+ for prefix in prefixes
564
+ ]
565
+ )
566
+
567
+ # 包装内容
568
+ processed_content = f"<root {ns_declarations}>{content}</root>"
569
+ elif not content.startswith("<root"):
570
+ processed_content = f"<root>{content}</root>"
571
+ else:
572
+ processed_content = content
573
+
574
+ root = ET.fromstring(processed_content)
346
575
  await self._parse_xml_node(root, elements)
347
576
  except ET.ParseError as e:
348
- raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
577
+ logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
578
+ # 如果解析失败,将整个内容当作纯文本
579
+ if content.strip():
580
+ elements.append(Plain(text=content))
349
581
  except Exception as e:
582
+ logger.error(f"解析 Satori 元素时发生未知错误: {e}")
350
583
  raise e
351
584
 
352
585
  # 如果没有解析到任何元素,将整个内容当作纯文本
@@ -361,7 +594,12 @@ class SatoriPlatformAdapter(Platform):
361
594
  elements.append(Plain(text=node.text))
362
595
 
363
596
  for child in node:
364
- tag_name = child.tag.lower()
597
+ # 获取标签名,去除命名空间前缀
598
+ tag_name = child.tag
599
+ if "}" in tag_name:
600
+ tag_name = tag_name.split("}")[1]
601
+ tag_name = tag_name.lower()
602
+
365
603
  attrs = child.attrib
366
604
 
367
605
  if tag_name == "at":
@@ -372,31 +610,59 @@ class SatoriPlatformAdapter(Platform):
372
610
  src = attrs.get("src", "")
373
611
  if not src:
374
612
  continue
375
- if src.startswith("data:image/"):
376
- src = src.split(",")[1]
377
- elements.append(Image.fromBase64(src))
378
- elif src.startswith("http"):
379
- elements.append(Image.fromURL(src))
380
- else:
381
- logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
613
+ elements.append(Image(file=src))
382
614
 
383
615
  elif tag_name == "file":
384
616
  src = attrs.get("src", "")
385
617
  name = attrs.get("name", "文件")
386
618
  if src:
387
- elements.append(File(file=src, name=name))
619
+ elements.append(File(name=name, file=src))
388
620
 
389
621
  elif tag_name in ("audio", "record"):
390
622
  src = attrs.get("src", "")
391
623
  if not src:
392
624
  continue
393
- if src.startswith("data:audio/"):
394
- src = src.split(",")[1]
395
- elements.append(Record.fromBase64(src))
396
- elif src.startswith("http"):
397
- elements.append(Record.fromURL(src))
625
+ elements.append(Record(file=src))
626
+
627
+ elif tag_name == "quote":
628
+ # quote标签已经被特殊处理
629
+ pass
630
+
631
+ elif tag_name == "face":
632
+ face_id = attrs.get("id", "")
633
+ face_name = attrs.get("name", "")
634
+ face_type = attrs.get("type", "")
635
+
636
+ if face_name:
637
+ elements.append(Plain(text=f"[表情:{face_name}]"))
638
+ elif face_id and face_type:
639
+ elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
640
+ elif face_id:
641
+ elements.append(Plain(text=f"[表情ID:{face_id}]"))
642
+ else:
643
+ elements.append(Plain(text="[表情]"))
644
+
645
+ elif tag_name == "ark":
646
+ # 作为纯文本添加到消息链中
647
+ data = attrs.get("data", "")
648
+ if data:
649
+ import html
650
+
651
+ decoded_data = html.unescape(data)
652
+ elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
653
+ else:
654
+ elements.append(Plain(text="[ARK卡片]"))
655
+
656
+ elif tag_name == "json":
657
+ # JSON标签 视为ARK卡片消息
658
+ data = attrs.get("data", "")
659
+ if data:
660
+ import html
661
+
662
+ decoded_data = html.unescape(data)
663
+ elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
398
664
  else:
399
- logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
665
+ elements.append(Plain(text="[JSON卡片]"))
400
666
 
401
667
  else:
402
668
  # 未知标签,递归处理其内容
@@ -17,6 +17,15 @@ class SatoriPlatformEvent(AstrMessageEvent):
17
17
  session_id: str,
18
18
  adapter: "SatoriPlatformAdapter",
19
19
  ):
20
+ # 更新平台元数据
21
+ if adapter and hasattr(adapter, "logins") and adapter.logins:
22
+ current_login = adapter.logins[0]
23
+ platform_name = current_login.get("platform", "satori")
24
+ user = current_login.get("user", {})
25
+ user_id = user.get("id", "") if user else ""
26
+ if not platform_meta.id and user_id:
27
+ platform_meta.id = f"{platform_name}({user_id})"
28
+
20
29
  super().__init__(message_str, message_obj, platform_meta, session_id)
21
30
  self.adapter = adapter
22
31
  self.platform = None
@@ -218,7 +218,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
218
218
  try:
219
219
  msg = await self.client.send_message(text=delta, **payload)
220
220
  current_content = delta
221
- delta = ""
222
221
  except Exception as e:
223
222
  logger.warning(f"发送消息失败(streaming): {e!s}")
224
223
  message_id = msg.message_id
@@ -65,13 +65,16 @@ class AssistantMessageSegment:
65
65
  role: str = "assistant"
66
66
 
67
67
  def to_dict(self):
68
- ret = {
68
+ ret: dict[str, str | list[dict]] = {
69
69
  "role": self.role,
70
70
  }
71
71
  if self.content:
72
72
  ret["content"] = self.content
73
73
  if self.tool_calls:
74
- ret["tool_calls"] = self.tool_calls
74
+ tool_calls_dict = [
75
+ tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
76
+ ]
77
+ ret["tool_calls"] = tool_calls_dict
75
78
  return ret
76
79
 
77
80
 
@@ -117,7 +120,14 @@ class ProviderRequest:
117
120
  """模型名称,为 None 时使用提供商的默认模型"""
118
121
 
119
122
  def __repr__(self):
120
- return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
123
+ return (
124
+ f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
125
+ f"image_count={len(self.image_urls or [])}, "
126
+ f"func_tool={self.func_tool}, "
127
+ f"contexts={self._print_friendly_context()}, "
128
+ f"system_prompt={self.system_prompt}, "
129
+ f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
130
+ )
121
131
 
122
132
  def __str__(self):
123
133
  return self.__repr__()
@@ -4,7 +4,7 @@ import os
4
4
  import asyncio
5
5
  import aiohttp
6
6
 
7
- from typing import Dict, List, Awaitable
7
+ from typing import Dict, List, Awaitable, Callable, Any
8
8
  from astrbot import logger
9
9
  from astrbot.core import sp
10
10
 
@@ -109,7 +109,7 @@ class FunctionToolManager:
109
109
  name: str,
110
110
  func_args: list,
111
111
  desc: str,
112
- handler: Awaitable,
112
+ handler: Callable[..., Awaitable[Any]],
113
113
  ) -> FuncTool:
114
114
  params = {
115
115
  "type": "object", # hard-coded here
@@ -132,7 +132,7 @@ class FunctionToolManager:
132
132
  name: str,
133
133
  func_args: list,
134
134
  desc: str,
135
- handler: Awaitable,
135
+ handler: Callable[..., Awaitable[Any]],
136
136
  ) -> None:
137
137
  """添加函数调用工具
138
138
 
@@ -220,7 +220,7 @@ class FunctionToolManager:
220
220
  name: str,
221
221
  cfg: dict,
222
222
  event: asyncio.Event,
223
- ready_future: asyncio.Future = None,
223
+ ready_future: asyncio.Future | None = None,
224
224
  ) -> None:
225
225
  """初始化 MCP 客户端的包装函数,用于捕获异常"""
226
226
  try:
@@ -38,7 +38,7 @@ class ProviderManager:
38
38
  """加载的 Text To Speech Provider 的实例"""
39
39
  self.embedding_provider_insts: List[EmbeddingProvider] = []
40
40
  """加载的 Embedding Provider 的实例"""
41
- self.inst_map: dict[str, Provider] = {}
41
+ self.inst_map: dict[str, Provider | STTProvider | TTSProvider] = {}
42
42
  """Provider 实例映射. key: provider_id, value: Provider 实例"""
43
43
  self.llm_tools = llm_tools
44
44
 
@@ -87,19 +87,31 @@ class ProviderManager:
87
87
  )
88
88
  return
89
89
  # 不启用提供商会话隔离模式的情况
90
- self.curr_provider_inst = self.inst_map[provider_id]
91
- if provider_type == ProviderType.TEXT_TO_SPEECH:
90
+
91
+ prov = self.inst_map[provider_id]
92
+ if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
93
+ prov, TTSProvider
94
+ ):
95
+ self.curr_tts_provider_inst = prov
92
96
  sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
93
- elif provider_type == ProviderType.SPEECH_TO_TEXT:
97
+ elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
98
+ prov, STTProvider
99
+ ):
100
+ self.curr_stt_provider_inst = prov
94
101
  sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
95
- elif provider_type == ProviderType.CHAT_COMPLETION:
102
+ elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
103
+ prov, Provider
104
+ ):
105
+ self.curr_provider_inst = prov
96
106
  sp.put("curr_provider", provider_id, scope="global", scope_id="global")
97
107
 
98
108
  async def get_provider_by_id(self, provider_id: str) -> Provider | None:
99
109
  """根据提供商 ID 获取提供商实例"""
100
110
  return self.inst_map.get(provider_id)
101
111
 
102
- def get_using_provider(self, provider_type: ProviderType, umo=None):
112
+ def get_using_provider(
113
+ self, provider_type: ProviderType, umo=None
114
+ ) -> Provider | STTProvider | TTSProvider | None:
103
115
  """获取正在使用的提供商实例。
104
116
 
105
117
  Args:
@@ -303,12 +315,14 @@ class ProviderManager:
303
315
  provider_metadata = provider_cls_map[provider_config["type"]]
304
316
  try:
305
317
  # 按任务实例化提供商
318
+ cls_type = provider_metadata.cls_type
319
+ if not cls_type:
320
+ logger.error(f"无法找到 {provider_metadata.type} 的类")
321
+ return
306
322
 
307
323
  if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
308
324
  # STT 任务
309
- inst = provider_metadata.cls_type(
310
- provider_config, self.provider_settings
311
- )
325
+ inst = cls_type(provider_config, self.provider_settings)
312
326
 
313
327
  if getattr(inst, "initialize", None):
314
328
  await inst.initialize()
@@ -327,9 +341,7 @@ class ProviderManager:
327
341
 
328
342
  elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
329
343
  # TTS 任务
330
- inst = provider_metadata.cls_type(
331
- provider_config, self.provider_settings
332
- )
344
+ inst = cls_type(provider_config, self.provider_settings)
333
345
 
334
346
  if getattr(inst, "initialize", None):
335
347
  await inst.initialize()
@@ -345,7 +357,7 @@ class ProviderManager:
345
357
 
346
358
  elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
347
359
  # 文本生成任务
348
- inst = provider_metadata.cls_type(
360
+ inst = cls_type(
349
361
  provider_config,
350
362
  self.provider_settings,
351
363
  self.selected_default_persona,
@@ -370,9 +382,7 @@ class ProviderManager:
370
382
  ProviderType.EMBEDDING,
371
383
  ProviderType.RERANK,
372
384
  ]:
373
- inst = provider_metadata.cls_type(
374
- provider_config, self.provider_settings
375
- )
385
+ inst = cls_type(provider_config, self.provider_settings)
376
386
  if getattr(inst, "initialize", None):
377
387
  await inst.initialize()
378
388
  self.embedding_provider_insts.append(inst)
@@ -430,11 +440,17 @@ class ProviderManager:
430
440
  )
431
441
 
432
442
  if self.inst_map[provider_id] in self.provider_insts:
433
- self.provider_insts.remove(self.inst_map[provider_id])
443
+ prov_inst = self.inst_map[provider_id]
444
+ if isinstance(prov_inst, Provider):
445
+ self.provider_insts.remove(prov_inst)
434
446
  if self.inst_map[provider_id] in self.stt_provider_insts:
435
- self.stt_provider_insts.remove(self.inst_map[provider_id])
447
+ prov_inst = self.inst_map[provider_id]
448
+ if isinstance(prov_inst, STTProvider):
449
+ self.stt_provider_insts.remove(prov_inst)
436
450
  if self.inst_map[provider_id] in self.tts_provider_insts:
437
- self.tts_provider_insts.remove(self.inst_map[provider_id])
451
+ prov_inst = self.inst_map[provider_id]
452
+ if isinstance(prov_inst, TTSProvider):
453
+ self.tts_provider_insts.remove(prov_inst)
438
454
 
439
455
  if self.inst_map[provider_id] == self.curr_provider_inst:
440
456
  self.curr_provider_inst = None